Skip to content

Commit

Permalink
Avoid calling get_max_length (huggingface#34971)
Browse files Browse the repository at this point in the history
fix

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
2 people authored and BernardZach committed Dec 6, 2024
1 parent 33d51d9 commit a1e320b
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/bloom/modeling_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ def prepare_inputs_for_generation(
# This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
# The only difference is the usage of 2D instead of 4D mask, but the shape will be static
if isinstance(past_key_values, StaticCache) and attention_mask is not None:
- target_length = past_key_values.get_max_length()
+ target_length = past_key_values.get_max_cache_shape()
batch_size, seq_length = attention_mask.shape
diff = target_length - seq_length

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def _update_causal_mask(
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
- target_length = past_key_values.get_max_length()
+ target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/granitemoe/modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ def _update_causal_mask(
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
- target_length = past_key_values.get_max_length()
+ target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/paligemma/modeling_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def _update_causal_mask(
min_dtype = torch.finfo(dtype).min
sequence_length = inputs_embeds.shape[1]
if using_static_cache:
- target_length = past_key_values.get_max_length()
+ target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
Expand Down

0 comments on commit a1e320b

Please sign in to comment.