Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SplashAttention in PyTorch/XLA #7798

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from typing import Callable
import unittest

import torch
Expand Down Expand Up @@ -1023,6 +1024,112 @@ def test_flash_attention_ab_backward_2(self):
jax.config.update("jax_default_matmul_precision", "default")


@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_splash_attention_wrapper(self):
from torch_xla.experimental.custom_kernel import splash_attention
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes, SegmentIds, _splash_attention_forward
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib

num_q_heads = 2
q_seq_len = 16
kv_seq_len = 16
head_dim = 64
block_q = 4
block_kv = 4
block_kv_compute = 2
dtype = torch.float32

q = torch.randn(num_q_heads, q_seq_len, head_dim, dtype=dtype)
k = torch.randn(num_q_heads, kv_seq_len, head_dim, dtype=dtype)
v = torch.randn(num_q_heads, kv_seq_len, head_dim, dtype=dtype)

data_next_torch = torch.randint(
0,
kv_seq_len, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
dtype=torch.int32)
mask_next_torch = torch.randint(
0,
2, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
dtype=torch.int32)
block_mask_torch = torch.randint(
0,
3, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
dtype=torch.int32)
partial_mask_blocks_torch = torch.randint(
0, 2, (10, block_q, block_kv), dtype=torch.int32) # Example shape
q_sequence_torch = torch.arange(q_seq_len, dtype=torch.int32)
q_segment_ids = torch.randint(0, 2, (q_seq_len,), dtype=torch.int32)
kv_segment_ids = torch.randint(0, 2, (kv_seq_len,), dtype=torch.int32)

mask_value = -0.7 * float(torch.finfo(torch.float32).max)
is_mqa = False
residual_checkpoint_name = None
save_residuals = False
mask_function = None
attn_logits_soft_cap = None
interpret = False

# PT/XLA output
output_xla = splash_attention(
mask_info_data_next=data_next_torch.to("xla"),
mask_info_mask_next=mask_next_torch.to("xla"),
mask_info_block_mask=block_mask_torch.to("xla"),
mask_info_partial_mask_blocks=partial_mask_blocks_torch.to("xla"),
mask_info_q_sequence=q_sequence_torch.to("xla"),
q=q.to("xla"),
k=k.to("xla"),
v=v.to("xla"),
q_segment_ids=q_segment_ids.to("xla"),
kv_segment_ids=kv_segment_ids.to("xla"),
mask_value=mask_value,
is_mqa=is_mqa,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=save_residuals,
mask_function=mask_function,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=interpret)

# JAX output
jax_fwd_mask_info = mask_info_lib.MaskInfo(
data_next=data_next_torch.numpy(),
mask_next=mask_next_torch.numpy(),
block_mask=block_mask_torch.numpy(),
partial_mask_blocks=partial_mask_blocks_torch.numpy(),
q_sequence=q_sequence_torch.numpy())

jax_segment_ids = SegmentIds(
q=jnp.array(q_segment_ids.numpy(), dtype=jnp.int32),
kv=jnp.array(kv_segment_ids.numpy(), dtype=jnp.int32))

jax_mask_value = -0.7 * float(np.finfo(np.dtype("float32")).max)
jax_block_sizes = BlockSizes.get_default()

# Call JAX's splash_attention
output_jax = _splash_attention_forward(
fwd_mask_info=jax_fwd_mask_info,
q=jnp.array(q.numpy(), dtype=jnp.float32),
k=jnp.array(k.numpy(), dtype=jnp.float32),
v=jnp.array(v.numpy(), dtype=jnp.float32),
segment_ids=jax_segment_ids,
mask_value=jax_mask_value,
is_mqa=is_mqa,
block_sizes=jax_block_sizes,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=save_residuals,
mask_function=mask_function,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=interpret)

# Convert JAX output to torch for comparison
expected_output = torch.from_numpy(np.array(output_jax))

# Compare outputs
self.assertTrue(
torch.allclose(
output_xla.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
Expand Down
Loading
Loading