Update attention.py (#1416)
* Update attention.py

Modify the GPT-BigCode attention code.
This change makes attention with a KV cache and multiple new tokens work correctly (a short illustration of that case follows the commit message).

* consider batch size = 1

* Update attention.py

* def kv_seq_len
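
To make the fixed case concrete, below is a minimal sketch (not part of the commit; all sizes are made up) of why `is_causal=True` mis-handles a KV cache: PyTorch's SDPA aligns its causal mask to the top-left corner, so when the keys/values are longer than the queries, the new tokens are blocked from attending to most of the cached positions.

# Minimal sketch, not from the commit: illustrates why is_causal=True is wrong
# once the KV cache makes key/value longer than the query (kv_seq_len > query_length).
import torch

batch_size, num_heads, head_dim = 1, 4, 8   # arbitrary illustrative sizes
cached_len, new_tokens = 5, 3               # 5 cached positions, 3 new tokens in one step
kv_seq_len = cached_len + new_tokens

query = torch.randn(batch_size, num_heads, new_tokens, head_dim)
key = torch.randn(batch_size, num_heads, kv_seq_len, head_dim)
value = torch.randn(batch_size, num_heads, kv_seq_len, head_dim)

# Correct mask: each new token sees the whole cache plus itself and earlier new tokens.
correct_mask = torch.ones(new_tokens, kv_seq_len, dtype=torch.bool)
correct_mask[:, cached_len:] = torch.ones(new_tokens, new_tokens, dtype=torch.bool).tril()

good = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=correct_mask)

# is_causal=True instead builds a top-left-aligned tril(query_length, kv_seq_len) mask,
# which hides most of the cached positions from the new tokens.
bad = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True)

print(torch.allclose(good, bad))  # False: the two masks disagree whenever kv_seq_len > query_length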
DongHande authored Oct 9, 2023
1 parent c98cb87 commit c8cf353
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions optimum/bettertransformer/models/attention.py
@@ -699,6 +699,7 @@ def gpt_bigcode_wrapped_scaled_dot_product(
     # MHA models: (batch_size, num_heads, query_length, head_dim)
     query_shape = query.shape
     batch_size = query_shape[0]
+    kv_seq_len = key.shape[-2]
 
     if self.multi_query:
         query_length = query_shape[1]
@@ -725,30 +726,34 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         key = key.expand(-1, self.num_heads, -1, -1)
         value = value.expand(-1, self.num_heads, -1, -1)
 
-    if batch_size == 1 or self.training:
-        if query_length > 1:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
-            )
-        else:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
-            )
+    # We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
+    if self.training:
+        is_causal = True
+        attn_mask = None
+    elif batch_size == 1 and query_length == 1:
+        is_causal = False
+        attn_mask = None
+    elif batch_size == 1 and kv_seq_len == query_length:
+        is_causal = True
+        attn_mask = None
+    elif attention_mask is not None:
+        mask_value = self._get_mask_value(query.device, query.dtype)
+
+        # gpt_bigcode has the bad taste to use a causal mask a
+        # [batch_size, target_length, 1, source_length] which is different from
+        # **all** other architectures and not compatible with SDPA.
+        # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
+        # but it is probably not worth it.
+        attention_mask = attention_mask.transpose(1, 2)
+        attn_mask = torch.where(attention_mask, 0.0, mask_value)
+        is_causal = False
     else:
-        if attention_mask is not None:
-            mask_value = self._get_mask_value(query.device, query.dtype)
+        attn_mask = None
+        is_causal = True
 
-            # gpt_bigcode has the bad taste to use a causal mask a
-            # [batch_size, target_length, 1, source_length] which is different from
-            # **all** other architectures and not compatible with SDPA.
-            # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
-            # but it is probably not worth it.
-            attention_mask = attention_mask.transpose(1, 2)
-            attention_mask = torch.where(attention_mask, 0.0, mask_value)
-
-        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
-        )
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+    )
 
     if self.multi_query:
         # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
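
For completeness, a hedged usage sketch of the path this diff changes: a GPT-BigCode checkpoint converted with BetterTransformer, first prefilled and then fed several new tokens on top of the returned KV cache. The model id, prompt, and token counts are illustrative and not taken from the commit.

# Usage sketch (illustrative, not from the commit): exercises the fixed path, where
# the SDPA-based attention sees query_length > 1 together with an existing KV cache.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

model_id = "bigcode/gpt_bigcode-santacoder"  # assumed checkpoint; any GPT-BigCode model should behave the same
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model = BetterTransformer.transform(model)
model.eval()

inputs = tokenizer("def hello_world():", return_tensors="pt")
with torch.no_grad():
    # Prefill: no cache yet, so kv_seq_len == query_length and the plain causal path is taken.
    out = model(**inputs, use_cache=True)

    # Feed several new tokens at once on top of the cache: query_length > 1 while
    # kv_seq_len > query_length, the case that previously hit the mis-aligned is_causal=True branch.
    new_ids = tokenizer("\n    print(", return_tensors="pt", add_special_tokens=False).input_ids
    out = model(input_ids=new_ids, past_key_values=out.past_key_values, use_cache=True)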
