diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index b549b3517e2..cba73918976 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -986,7 +986,14 @@ def _dot_product_attention_core( def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): query = convert_to_tensor(query) key = convert_to_tensor(key) @@ -1000,6 +1007,7 @@ def dot_product_attention( # `dot_product_attention` is only available in jax>=0.4.31 if hasattr(jax.nn, "dot_product_attention"): + implementation = "cudnn" if flash_attention else "xla" return jax.nn.dot_product_attention( query, key, @@ -1008,6 +1016,14 @@ def dot_product_attention( mask=mask, scale=scale, is_causal=is_causal, + implementation=implementation, + ) + + if flash_attention: + raise ValueError( + "Flash attention is not supported in your " + "current JAX version. Please update it " + "using `pip install -U jax jaxlib`." ) # Ref: jax.nn.dot_product_attention diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index f3e02d6d5a9..eea127e554a 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -1033,8 +1033,17 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): + if flash_attention: + raise ValueError("Flash attention is not implemented in NumPy.") # Ref: jax.nn.dot_product_attention # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828 # Not support `query_seq_lengths` and `key_value_seq_lengths` args diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index bc7c1e61486..01a1aca26d0 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -964,8 +964,20 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): + if flash_attention: + raise ValueError( + "Flash attention is not supported yet in TensorFlow backend." + ) + # Ref: jax.nn.dot_product_attention # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828 # Not support `query_seq_lengths` and `key_value_seq_lengths` args diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index 449c0976aff..e4291f6b84c 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -864,8 +864,28 @@ def _get_large_negative(dtype): return convert_to_tensor(val * -0.7, dtype=dtype) +def is_flash_attention_enabled(query, key, value, mask=None, is_causal=False): + params = torch.backends.cuda.SDPAParams( + query, + key, + value, + mask, + 0.0, + is_causal, + ) + is_enabled = torch.backends.cuda.can_use_flash_attention(params, False) + return is_enabled + + def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): if bias is not None: raise ValueError( @@ -891,7 +911,35 @@ def dot_product_attention( query = torch.transpose(query, axis0, axis1) key = torch.transpose(key, axis0, axis1) value = torch.transpose(value, axis0, axis1) - attention_output = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale - ) + + if flash_attention: + is_enabled = is_flash_attention_enabled( + query=query, + key=key, + value=value, + mask=mask, + is_causal=is_causal, + ) + if not is_enabled: + raise ValueError( + "Flash attention is not enabled in `torch` backend. " + "The dtype of the inputs should be float16/bfloat16 " + "and your GPU should support flash attention implementation." + ) + + with torch.nn.attention.sdpa_kernel( + backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION], + ): + attention_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=mask, + is_causal=is_causal, + scale=scale, + ) + else: + attention_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale + ) return torch.transpose(attention_output, axis1, axis0) diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index c0f65dc87cc..2d779582a5b 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -2131,7 +2131,16 @@ def __init__(self, is_causal=False): super().__init__() self.is_causal = is_causal - def call(self, query, key, value, bias=None, mask=None, scale=None): + def call( + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + flash_attention=False, + ): return backend.nn.dot_product_attention( query, key, @@ -2140,10 +2149,18 @@ def call(self, query, key, value, bias=None, mask=None, scale=None): mask=mask, scale=scale, is_causal=self.is_causal, + flash_attention=flash_attention, ) def compute_output_spec( - self, query, key, value, bias=None, mask=None, scale=None + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + flash_attention=False, ): return KerasTensor(query.shape, dtype=query.dtype) @@ -2152,7 +2169,14 @@ def compute_output_spec( ["keras.ops.dot_product_attention", "keras.ops.nn.dot_product_attention"] ) def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): """Scaled dot product attention function. @@ -2207,6 +2231,7 @@ def dot_product_attention( bias=bias, mask=mask, scale=scale, + flash_attention=flash_attention, ) return backend.nn.dot_product_attention( query, @@ -2216,4 +2241,5 @@ def dot_product_attention( mask=mask, scale=scale, is_causal=is_causal, + flash_attention=flash_attention, ) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index fe8d34fc656..4d75760d894 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2208,9 +2208,12 @@ def test_psnr(self): bias=(None, True), scale=(None, 1.0), mask_and_is_causal=((None, False), (True, False), (None, True)), + flash_attention=(True, False), ) ) - def test_dot_product_attention(self, bias, scale, mask_and_is_causal): + def test_dot_product_attention( + self, bias, scale, mask_and_is_causal, flash_attention + ): mask, is_causal = mask_and_is_causal query_shape = (2, 3, 4, 5) key_shape = (2, 6, 4, 5) @@ -2232,6 +2235,57 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal): mask_shape ) + if flash_attention and backend.backend() in [ + "torch", + "tensorflow", + "numpy", + ]: + self.skipTest( + "Not supported in TF and NumPy and supported for " + "PyTorch with specific requirements." + ) + + if flash_attention and backend.backend() == "jax": + try: + outputs = knn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + flash_attention=flash_attention, + ) + except ValueError as e: + if e.args[0].startswith( + "Flash attention is not supported in your " + "current JAX version" + ): + self.skipTest( + "JAX version does not have " + "`dot_product_attention` function." + ) + except RuntimeError as e: + if e.args[0] == "cuDNN is not detected.": + self.skipTest("No CuDNN to run flash attention for JAX.") + elif e.args[0] == "Require at least Ampere arch to run": + self.skipTest( + "Requires at least Ampere arch to run flash attention " + "for JAX." + ) + else: + outputs = knn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + flash_attention=flash_attention, + ) + expected = _dot_product_attention( query, key, @@ -2241,15 +2295,6 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal): scale=scale, is_causal=is_causal, ) - outputs = knn.dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - scale=scale, - is_causal=is_causal, - ) self.assertAllClose(outputs, expected)