System Info
- transformers: main
- pytorch, cuda: any version
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
from transformers import AutoModelForCausalLM
import numpy as np
import torch
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
input_ids = torch.randint(0, 100, (1, 8192))
with torch.no_grad():
    output_raw = model(input_ids)

# correct situation: should be the same as the original model, since use_sliding_window is False
model_no_sliding = AutoModelForCausalLM.from_pretrained(MODEL_NAME, sliding_window=None)
with torch.no_grad():
    output_non_sliding = model_no_sliding(input_ids)

np.testing.assert_allclose(output_raw.logits[:, :4096], output_non_sliding.logits[:, :4096])
# wrong: the logits are unexpectedly different when sliding_window=4096 is kept
np.testing.assert_allclose(output_raw.logits[:, 4096:], output_non_sliding.logits[:, 4096:])
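As an optional follow-up check (a minimal sketch that reuses the variables from the script above, not part of the report itself), the per-position divergence can be printed to confirm that the logits only differ for positions beyond the 4096-token window, which is exactly what a sliding-window mask would cause:

# max absolute logit difference per position (batch size 1)
diff = (output_raw.logits - output_non_sliding.logits).abs().amax(dim=-1)[0]
print("max diff in first 4096 positions:", diff[:4096].max().item())
print("max diff after position 4096:    ", diff[4096:].max().item())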
Expected behavior
What is expected:
"sliding_window": 4096,
"use_sliding_window": false,
use_sliding_window is set to false in the deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B config, so we expect the sliding window to be disabled. In other words, the outputs should be identical regardless of the sliding_window value.
However, the results differ in the repro script above.
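As a quick sanity check (a minimal sketch; the exact attribute values may depend on the installed transformers version), the published checkpoint config can be inspected directly:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
print(config.sliding_window)      # 4096 per the published config
print(config.use_sliding_window)  # False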
Root cause
The attention mask is modified according to sliding_window without checking use_sliding_window.
transformers/src/transformers/models/qwen2/modeling_qwen2.py
Lines 708 to 715 in 3c0796a
if config.get_text_config().sliding_window is not None:
    # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
    # the check is needed to verify is current checkpoint was trained with sliding window or not
    if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
        sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
            cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
        )
        diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
If we add some printing inside this conditional block, we can clearly see that the attention mask is changed even when use_sliding_window is false.
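For illustration only, here is a minimal sketch of one possible guard. It is not a proposed patch and is not runnable on its own: it reuses the local variables from the quoted block above and assumes the text config exposes use_sliding_window (as Qwen2Config does).

text_config = config.get_text_config()
# only apply the sliding-window mask when the checkpoint actually enables it;
# otherwise keep the plain causal mask (use_sliding_window is assumed to exist)
if text_config.sliding_window is not None and getattr(text_config, "use_sliding_window", True):
    if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
        sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
            cache_position.reshape(-1, 1) - text_config.sliding_window
        )
        diagonal_attend_mask.bitwise_or_(sliding_attend_mask)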