Flash attention support. (keras-team#20152)
* added flash attention support for pytorch

* added a comment explaining why the causal mask is created manually

* added unit tests for flash attention

* added test skipping for flash attention for numpy

* removed flash attn op and added support for flash attention inside dot_product_attention op

* added skip tests for every framework except torch

* formatted files

* added checks for flash attention in pytorch before computing attention and removed flash attention from tests

* added skipping tests for all frameworks except jax

* formatted files

* added conditions to skip tests for jax

* fixed typo
hazemessamm authored Oct 8, 2024
1 parent f52f9f5 commit 8e67e0e
Showing 6 changed files with 176 additions and 20 deletions.
18 changes: 17 additions & 1 deletion keras/src/backend/jax/nn.py
@@ -986,7 +986,14 @@ def _dot_product_attention_core(
 
 
 def dot_product_attention(
-    query, key, value, bias=None, mask=None, scale=None, is_causal=False
+    query,
+    key,
+    value,
+    bias=None,
+    mask=None,
+    scale=None,
+    is_causal=False,
+    flash_attention=False,
 ):
     query = convert_to_tensor(query)
     key = convert_to_tensor(key)

@@ -1000,6 +1007,7 @@ def dot_product_attention(
 
     # `dot_product_attention` is only available in jax>=0.4.31
     if hasattr(jax.nn, "dot_product_attention"):
+        implementation = "cudnn" if flash_attention else "xla"
         return jax.nn.dot_product_attention(
             query,
             key,

@@ -1008,6 +1016,14 @@ def dot_product_attention(
             mask=mask,
             scale=scale,
             is_causal=is_causal,
+            implementation=implementation,
         )
 
+    if flash_attention:
+        raise ValueError(
+            "Flash attention is not supported in your "
+            "current JAX version. Please update it "
+            "using `pip install -U jax jaxlib`."
+        )
+
     # Ref: jax.nn.dot_product_attention

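For readers who want to see the JAX dispatch in isolation, here is a minimal standalone sketch (not part of the commit). It assumes jax >= 0.4.31, where jax.nn.dot_product_attention and its keyword-only implementation argument exist, and an NVIDIA GPU with cuDNN for the "cudnn" path; the function name and the plain-softmax fallback are illustrative only.

# Standalone sketch of the dispatch logic, assuming jax >= 0.4.31.
import jax
import jax.numpy as jnp


def sketch_dot_product_attention(query, key, value, flash_attention=False):
    # Shapes follow jax.nn's convention: (batch, seq_len, num_heads, head_dim).
    if hasattr(jax.nn, "dot_product_attention"):
        # "cudnn" requests the flash attention kernel; "xla" is the default path.
        implementation = "cudnn" if flash_attention else "xla"
        return jax.nn.dot_product_attention(
            query, key, value, implementation=implementation
        )
    if flash_attention:
        raise ValueError("Flash attention requires jax >= 0.4.31.")
    # Plain softmax(QK^T / sqrt(d)) V fallback for older JAX versions.
    scale = query.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)
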
11 changes: 10 additions & 1 deletion keras/src/backend/numpy/nn.py
@@ -1033,8 +1033,17 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
 
 
 def dot_product_attention(
-    query, key, value, bias=None, mask=None, scale=None, is_causal=False
+    query,
+    key,
+    value,
+    bias=None,
+    mask=None,
+    scale=None,
+    is_causal=False,
+    flash_attention=False,
 ):
+    if flash_attention:
+        raise ValueError("Flash attention is not implemented in NumPy.")
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args

14 changes: 13 additions & 1 deletion keras/src/backend/tensorflow/nn.py
@@ -964,8 +964,20 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
 
 
 def dot_product_attention(
-    query, key, value, bias=None, mask=None, scale=None, is_causal=False
+    query,
+    key,
+    value,
+    bias=None,
+    mask=None,
+    scale=None,
+    is_causal=False,
+    flash_attention=False,
 ):
+    if flash_attention:
+        raise ValueError(
+            "Flash attention is not supported yet in TensorFlow backend."
+        )
+
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args

56 changes: 52 additions & 4 deletions keras/src/backend/torch/nn.py
@@ -864,8 +864,28 @@ def _get_large_negative(dtype):
     return convert_to_tensor(val * -0.7, dtype=dtype)
 
 
+def is_flash_attention_enabled(query, key, value, mask=None, is_causal=False):
+    params = torch.backends.cuda.SDPAParams(
+        query,
+        key,
+        value,
+        mask,
+        0.0,
+        is_causal,
+    )
+    is_enabled = torch.backends.cuda.can_use_flash_attention(params, False)
+    return is_enabled
+
+
 def dot_product_attention(
-    query, key, value, bias=None, mask=None, scale=None, is_causal=False
+    query,
+    key,
+    value,
+    bias=None,
+    mask=None,
+    scale=None,
+    is_causal=False,
+    flash_attention=False,
 ):
     if bias is not None:
         raise ValueError(

@@ -891,7 +911,35 @@ def dot_product_attention(
     query = torch.transpose(query, axis0, axis1)
     key = torch.transpose(key, axis0, axis1)
     value = torch.transpose(value, axis0, axis1)
-    attention_output = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
-    )
+
+    if flash_attention:
+        is_enabled = is_flash_attention_enabled(
+            query=query,
+            key=key,
+            value=value,
+            mask=mask,
+            is_causal=is_causal,
+        )
+        if not is_enabled:
+            raise ValueError(
+                "Flash attention is not enabled in `torch` backend. "
+                "The dtype of the inputs should be float16/bfloat16 "
+                "and your GPU should support flash attention implementation."
+            )
+
+        with torch.nn.attention.sdpa_kernel(
+            backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION],
+        ):
+            attention_output = torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=mask,
+                is_causal=is_causal,
+                scale=scale,
+            )
+    else:
+        attention_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
+        )
     return torch.transpose(attention_output, axis1, axis0)

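A hedged standalone sketch of the torch path above (illustrative, not the Keras code): it mirrors the capability check and the sdpa_kernel context manager from the diff, and assumes a recent PyTorch where torch.nn.attention.sdpa_kernel, torch.nn.attention.SDPBackend, and torch.backends.cuda.SDPAParams take the forms used here, plus float16/bfloat16 CUDA inputs on a GPU that supports flash attention.

# Standalone sketch: force the flash attention SDPA backend when available.
import torch


def sketch_flash_sdpa(query, key, value, mask=None, is_causal=False):
    # Mirror the capability check from the diff above.
    params = torch.backends.cuda.SDPAParams(query, key, value, mask, 0.0, is_causal)
    if not torch.backends.cuda.can_use_flash_attention(params, False):
        raise ValueError(
            "Inputs do not meet flash attention requirements "
            "(float16/bfloat16 dtype, supported GPU, supported head size)."
        )
    # Restrict scaled_dot_product_attention to the flash kernel only.
    with torch.nn.attention.sdpa_kernel(
        [torch.nn.attention.SDPBackend.FLASH_ATTENTION]
    ):
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=mask, is_causal=is_causal
        )


# Hypothetical usage with (batch, num_heads, seq_len, head_dim) half-precision tensors:
# q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
# out = sketch_flash_sdpa(q, q, q, is_causal=True)
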
32 changes: 29 additions & 3 deletions keras/src/ops/nn.py
@@ -2131,7 +2131,16 @@ def __init__(self, is_causal=False):
         super().__init__()
         self.is_causal = is_causal
 
-    def call(self, query, key, value, bias=None, mask=None, scale=None):
+    def call(
+        self,
+        query,
+        key,
+        value,
+        bias=None,
+        mask=None,
+        scale=None,
+        flash_attention=False,
+    ):
         return backend.nn.dot_product_attention(
             query,
             key,

@@ -2140,10 +2149,18 @@ def call(self, query, key, value, bias=None, mask=None, scale=None):
             mask=mask,
             scale=scale,
             is_causal=self.is_causal,
+            flash_attention=flash_attention,
         )
 
     def compute_output_spec(
-        self, query, key, value, bias=None, mask=None, scale=None
+        self,
+        query,
+        key,
+        value,
+        bias=None,
+        mask=None,
+        scale=None,
+        flash_attention=False,
     ):
         return KerasTensor(query.shape, dtype=query.dtype)
 

@@ -2152,7 +2169,14 @@ def compute_output_spec(
     ["keras.ops.dot_product_attention", "keras.ops.nn.dot_product_attention"]
 )
 def dot_product_attention(
-    query, key, value, bias=None, mask=None, scale=None, is_causal=False
+    query,
+    key,
+    value,
+    bias=None,
+    mask=None,
+    scale=None,
+    is_causal=False,
+    flash_attention=False,
 ):
     """Scaled dot product attention function.

@@ -2207,6 +2231,7 @@ def dot_product_attention(
             bias=bias,
             mask=mask,
             scale=scale,
+            flash_attention=flash_attention,
         )
     return backend.nn.dot_product_attention(
         query,

@@ -2216,4 +2241,5 @@ def dot_product_attention(
         mask=mask,
         scale=scale,
         is_causal=is_causal,
+        flash_attention=flash_attention,
     )

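A hedged end-to-end usage sketch of the public op after this change (not from the commit). Whether the flash path actually runs depends on the backend: with this commit, JAX (jax >= 0.4.31 with cuDNN) and torch (half-precision inputs on a supported GPU) can honor flash_attention=True, while NumPy and TensorFlow raise a ValueError, which the sketch catches to fall back to the default path.

# Usage sketch for keras.ops.dot_product_attention with the new flag.
import numpy as np
from keras import ops

# (batch, seq_len, num_heads, head_dim), half precision as flash attention expects.
query = np.random.rand(2, 8, 4, 16).astype("float16")
key = np.random.rand(2, 8, 4, 16).astype("float16")
value = np.random.rand(2, 8, 4, 16).astype("float16")

try:
    out = ops.dot_product_attention(
        query, key, value, is_causal=True, flash_attention=True
    )
except (ValueError, RuntimeError):
    # Backend or hardware does not support flash attention: use the default path.
    out = ops.dot_product_attention(query, key, value, is_causal=True)

print(out.shape)  # (2, 8, 4, 16) -- the output keeps the query shape
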
65 changes: 55 additions & 10 deletions keras/src/ops/nn_test.py
@@ -2208,9 +2208,12 @@ def test_psnr(self):
             bias=(None, True),
             scale=(None, 1.0),
             mask_and_is_causal=((None, False), (True, False), (None, True)),
+            flash_attention=(True, False),
         )
     )
-    def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
+    def test_dot_product_attention(
+        self, bias, scale, mask_and_is_causal, flash_attention
+    ):
         mask, is_causal = mask_and_is_causal
         query_shape = (2, 3, 4, 5)
         key_shape = (2, 6, 4, 5)

@@ -2232,6 +2235,57 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
             mask_shape
         )
 
+        if flash_attention and backend.backend() in [
+            "torch",
+            "tensorflow",
+            "numpy",
+        ]:
+            self.skipTest(
+                "Not supported in TF and NumPy and supported for "
+                "PyTorch with specific requirements."
+            )
+
+        if flash_attention and backend.backend() == "jax":
+            try:
+                outputs = knn.dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    bias=bias,
+                    mask=mask,
+                    scale=scale,
+                    is_causal=is_causal,
+                    flash_attention=flash_attention,
+                )
+            except ValueError as e:
+                if e.args[0].startswith(
+                    "Flash attention is not supported in your "
+                    "current JAX version"
+                ):
+                    self.skipTest(
+                        "JAX version does not have "
+                        "`dot_product_attention` function."
+                    )
+            except RuntimeError as e:
+                if e.args[0] == "cuDNN is not detected.":
+                    self.skipTest("No CuDNN to run flash attention for JAX.")
+                elif e.args[0] == "Require at least Ampere arch to run":
+                    self.skipTest(
+                        "Requires at least Ampere arch to run flash attention "
+                        "for JAX."
+                    )
+        else:
+            outputs = knn.dot_product_attention(
+                query,
+                key,
+                value,
+                bias=bias,
+                mask=mask,
+                scale=scale,
+                is_causal=is_causal,
+                flash_attention=flash_attention,
+            )
+
         expected = _dot_product_attention(
             query,
             key,

@@ -2241,15 +2295,6 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
             scale=scale,
             is_causal=is_causal,
         )
-        outputs = knn.dot_product_attention(
-            query,
-            key,
-            value,
-            bias=bias,
-            mask=mask,
-            scale=scale,
-            is_causal=is_causal,
-        )
         self.assertAllClose(outputs, expected)

