From 4e5b3184db0f4f66952e75e2bff767eee2855c4e Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 15 Aug 2024 18:01:25 +0000
Subject: [PATCH] Enable dynamo

---
 test/test_pallas.py                     |  2 +-
 torch_xla/experimental/custom_kernel.py | 67 +++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index 9d31f6784f0..99ff8f222f6 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -1082,7 +1082,7 @@ def test_splash_attention_wrapper(self):
         v=v.to("xla"),
         q_segment_ids=q_segment_ids.to("xla"),
         kv_segment_ids=kv_segment_ids.to("xla"),
-        mask_value=mask_value.to("xla"),
+        mask_value=mask_value,
         is_mqa=is_mqa,
         residual_checkpoint_name=residual_checkpoint_name,
         save_residuals=save_residuals,
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 31ddc9bd8bc..1ea1757a024 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -1161,6 +1161,73 @@ def paged_attention_non_xla(q: torch.Tensor,
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
 
+XLA_LIB.define(
+    "splash_attention(Tensor? mask_info_data_next, Tensor? mask_info_mask_next, Tensor? mask_info_block_mask, Tensor? mask_info_partial_mask_blocks, Tensor? mask_info_q_sequence, Tensor q, Tensor k, Tensor v, Tensor? q_segment_ids, Tensor? kv_segment_ids, float mask_value, bool is_mqa, str? residual_checkpoint_name=None, bool save_residuals=False, float? attn_logits_soft_cap=None, bool interpret=False) -> Tensor",
+)
+
+
+@impl(XLA_LIB, "splash_attention", "XLA")
+def splash_attention_xla(
+    mask_info_data_next: torch.Tensor | None,
+    mask_info_mask_next: torch.Tensor | None,
+    mask_info_block_mask: torch.Tensor | None,
+    mask_info_partial_mask_blocks: torch.Tensor | None,
+    mask_info_q_sequence: torch.Tensor | None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_segment_ids: torch.Tensor | None,  # [q_seq_len]
+    kv_segment_ids: torch.Tensor | None,  # [kv_seq_len]
+    mask_value: float,
+    is_mqa: bool,
+    residual_checkpoint_name: str | None = None,
+    save_residuals: bool = False,
+    # mask_function: Callable | None = None,
+    attn_logits_soft_cap: float | None = None,
+    interpret: bool = False):
+  return splash_attention(
+      mask_info_data_next,
+      mask_info_mask_next,
+      mask_info_block_mask,
+      mask_info_partial_mask_blocks,
+      mask_info_q_sequence,
+      q,
+      k,
+      v,
+      q_segment_ids,
+      kv_segment_ids,
+      mask_value,
+      is_mqa,
+      residual_checkpoint_name=residual_checkpoint_name,
+      save_residuals=save_residuals,
+      mask_function=None,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+  )
+
+
+@impl(XLA_LIB, "splash_attention", "CompositeExplicitAutograd")
+def splash_attention_non_xla(
+    mask_info_data_next: torch.Tensor | None,
+    mask_info_mask_next: torch.Tensor | None,
+    mask_info_block_mask: torch.Tensor | None,
+    mask_info_partial_mask_blocks: torch.Tensor | None,
+    mask_info_q_sequence: torch.Tensor | None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_segment_ids: torch.Tensor | None,  # [q_seq_len]
+    kv_segment_ids: torch.Tensor | None,  # [kv_seq_len]
+    mask_value: float,
+    is_mqa: bool,
+    residual_checkpoint_name: str | None = None,
+    save_residuals: bool = False,
+    # mask_function: Callable | None = None,
+    attn_logits_soft_cap: float | None = None,
+    interpret: bool = False):
+  return non_xla_attetion(q, k, v, "splash")
+
+
 XLA_LIB.define(
     "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
 )
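
For reviewers, a minimal sketch of how the op registered above might be exercised through dynamo once this patch lands. The tensor shapes, the all-None mask-info tensors, the mask_value constant, and the "openxla" backend choice are illustrative assumptions, not part of the patch:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel  # noqa: F401, registers xla::splash_attention

device = xm.xla_device()
# Assumed layout [num_heads, seq_len, head_dim]; check the kernel's
# expected shapes before relying on this.
q = torch.randn(8, 1024, 128, device=device)
k = torch.randn(8, 1024, 128, device=device)
v = torch.randn(8, 1024, 128, device=device)

@torch.compile(backend="openxla")
def attn(q, k, v):
  # Mask-info tensors and segment ids may be None when no sparsity
  # metadata or segment masking is precomputed (hence Tensor? in the schema).
  return torch.ops.xla.splash_attention(
      None, None, None, None, None,  # mask_info_* tensors
      q, k, v,
      None, None,  # q_segment_ids, kv_segment_ids
      -0.7 * torch.finfo(torch.float32).max,  # mask_value (assumed default)
      False,  # is_mqa
  )

out = attn(q, k, v)

Note that the CompositeExplicitAutograd overload only forwards q, k, and v to non_xla_attetion, so tracing on a non-XLA device yields a correctly shaped placeholder rather than real attention output.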