Commit 4e5b318: Enable dynamo

wonjoolee95 committed Aug 15, 2024
1 parent 168beca
Showing 2 changed files with 68 additions and 1 deletion.
test/test_pallas.py: 1 addition, 1 deletion
@@ -1082,7 +1082,7 @@ def test_splash_attention_wrapper(self):
v=v.to("xla"),
q_segment_ids=q_segment_ids.to("xla"),
kv_segment_ids=kv_segment_ids.to("xla"),
-       mask_value=mask_value.to("xla"),
+       mask_value=mask_value,
is_mqa=is_mqa,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=save_residuals,
torch_xla/experimental/custom_kernel.py: 67 additions, 0 deletions
@@ -1161,6 +1161,73 @@ def paged_attention_non_xla(q: torch.Tensor,
return non_xla_attetion(q, k_pages, v_pages, "paged")


XLA_LIB.define(
"splash_attention(Tensor mask_info_data_next, Tensor mask_info_mask_next, Tensor mask_info_block_mask, Tensor mask_info_partial_mask_blocks, Tensor mask_info_q_sequence, Tensor q, Tensor k, Tensor v, Tensor q_segment_ids, Tensor kv_segment_ids, float mask_value, bool is_mqa, str residual_checkpoint_name=None, bool save_residuals=False, float attn_logits_soft_cap=None, bool interpret=False) -> Tensor",
)


@impl(XLA_LIB, "splash_attention", "XLA")
def splash_attention_xla(
mask_info_data_next: torch.Tensor | None,
mask_info_mask_next: torch.Tensor | None,
mask_info_block_mask: torch.Tensor | None,
mask_info_partial_mask_blocks: torch.Tensor | None,
mask_info_q_sequence: torch.Tensor | None,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_segment_ids: torch.Tensor | None, # [q_seq_len]
kv_segment_ids: torch.Tensor | None, # [kv_seq_len]
mask_value: float,
is_mqa: bool,
residual_checkpoint_name: str | None = None,
save_residuals: bool = False,
# mask_function: Callable | None = None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False):
return splash_attention(
mask_info_data_next,
mask_info_mask_next,
mask_info_block_mask,
mask_info_partial_mask_blocks,
mask_info_q_sequence,
q,
k,
v,
q_segment_ids,
kv_segment_ids,
mask_value,
is_mqa,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=save_residuals,
mask_function=None,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=interpret,
)


@impl(XLA_LIB, "splash_attention", "CompositeExplicitAutograd")
def splash_attention_non_xla(
mask_info_data_next: torch.Tensor | None,
mask_info_mask_next: torch.Tensor | None,
mask_info_block_mask: torch.Tensor | None,
mask_info_partial_mask_blocks: torch.Tensor | None,
mask_info_q_sequence: torch.Tensor | None,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_segment_ids: torch.Tensor | None, # [q_seq_len]
kv_segment_ids: torch.Tensor | None, # [kv_seq_len]
mask_value: float,
is_mqa: bool,
residual_checkpoint_name: str | None = None,
save_residuals: bool = False,
# mask_function: Callable | None = None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False):
return non_xla_attetion(q, k, v, "splash")


XLA_LIB.define(
"gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
)

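For context, a minimal sketch (not part of this commit) of how the newly registered op could be exercised under dynamo once this change lands. The tensor shapes, the None placeholders for the mask-info and segment-id arguments, and the "openxla" torch.compile backend are illustrative assumptions, not the repository's test code.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Illustrative [batch, heads, seq, head_dim] shapes; the real kernel is driven
# by mask-info tensors produced by the splash-attention mask builder.
q = torch.randn(1, 8, 128, 64, device=device)
k = torch.randn(1, 8, 128, 64, device=device)
v = torch.randn(1, 8, 128, 64, device=device)

def attn(q, k, v):
    # The XLA_LIB.define/@impl registration above exposes the kernel as a
    # regular torch op, so dynamo can trace this call instead of graph-breaking.
    return torch.ops.xla.splash_attention(
        None, None, None, None, None,     # mask_info_* tensors (assumed optional here)
        q, k, v,
        None, None,                       # q_segment_ids, kv_segment_ids
        torch.finfo(torch.float32).min,   # mask_value
        False,                            # is_mqa
    )

compiled = torch.compile(attn, backend="openxla")
out = compiled(q, k, v)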