Skip to content

Commit

Permalink
Use lower precision in DPA (#20615)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Dec 10, 2024
1 parent 5b6b9b0 commit c6c0720
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
14 changes: 8 additions & 6 deletions keras/src/backend/numpy/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,12 +1096,14 @@ def _apply_masks(logits, mask, is_causal):


def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
original_dtype = key.dtype
logits_dtype = np.promote_types(query.dtype, np.float32)
logits = np.einsum(
"BTNH,BSNH->BNTS",
query.astype(logits_dtype),
key.astype(logits_dtype),
)
if backend.standardize_dtype(key.dtype) == "bfloat16":
# `np.einsum` doesn't support bfloat16
key = key.astype("float32")
value = value.astype("float32")
logits = np.einsum("BTNH,BSNH->BNTS", query, key)
logits = logits.astype(logits_dtype)
logits *= np.array(scale, dtype=logits.dtype)

if bias is not None:
Expand All @@ -1111,7 +1113,7 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):

    # Softmax is always carried out in fp32.
padded_logits = padded_logits.astype(np.float32)
probs = softmax(padded_logits, axis=-1).astype(key.dtype)
probs = softmax(padded_logits, axis=-1).astype(original_dtype)
encoded_dtype = probs.dtype
if backend.standardize_dtype(probs.dtype) == "bfloat16":
# `np.einsum` doesn't support bfloat16
Expand Down
8 changes: 2 additions & 6 deletions keras/src/backend/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,12 +1015,8 @@ def _apply_masks(logits, mask, is_causal):

def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
logits_dtype = backend.result_type(query.dtype, "float32")
logits = tf.einsum(
"BTNH,BSNH->BNTS",
tf.cast(query, dtype=logits_dtype),
tf.cast(key, dtype=logits_dtype),
optimize="optimal",
)
logits = tf.einsum("BTNH,BSNH->BNTS", query, key, optimize="optimal")
logits = tf.cast(logits, logits_dtype)
logits = tf.multiply(logits, tf.cast(scale, logits.dtype))

if bias is not None:
Expand Down

0 comments on commit c6c0720

Please sign in to comment.