
Commit

Update tests
wonjoolee95 committed Aug 15, 2024
1 parent 3d9261d commit 168beca
Showing 1 changed file with 90 additions and 47 deletions.
137 changes: 90 additions & 47 deletions test/test_pallas.py
@@ -1028,63 +1028,106 @@ def test_flash_attention_ab_backward_2(self):
                    "This test only works on TPUv4+.")
   def test_splash_attention_wrapper(self):
     from torch_xla.experimental.custom_kernel import splash_attention
-    from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes, SegmentIds, QKVLayout, _splash_attention_forward
+    from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes, SegmentIds, _splash_attention_forward
     from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib

     # Example input data for unit test
     num_q_heads = 2
     q_seq_len = 16
     kv_seq_len = 16
     head_dim = 64
+    block_q = 4
+    block_kv = 4
+    block_kv_compute = 2
+    dtype = torch.float32

-    fwd_mask_info: mask_info_lib.MaskInfo = mask_info_lib.MaskInfo(
-        q_sequence=torch.randint(0, 2, (q_seq_len,)),
-        data_next=torch.randn(num_q_heads, q_seq_len, kv_seq_len),
-        block_mask=torch.randint(0, 2, (num_q_heads, q_seq_len, kv_seq_len)),
-        mask_next=torch.randint(0, 2, (num_q_heads, q_seq_len, kv_seq_len)),
-        partial_mask_blocks=torch.randint(0, 2,
-                                          (num_q_heads, q_seq_len, kv_seq_len)),
-    )
-
-    q_xla = torch.randn(num_q_heads, q_seq_len, head_dim).to("xla")
-    k_xla = torch.randn(num_q_heads, q_seq_len, head_dim).to("xla")
-    v_xla = torch.randn(num_q_heads, q_seq_len, head_dim).to("xla")
-    segment_ids: SegmentIds | None = None
-    mask_value: float = -0.7 * float(torch.finfo(torch.float32).max)
-    is_mqa: bool = False
-    block_sizes: BlockSizes = None
-    residual_checkpoint_name: str | None = None
-    save_residuals: bool = False
-    mask_function: Callable | None = None
-    attn_logits_soft_cap: float | None = None
-    interpret: bool = False
-
-    output = splash_attention(
-        mask_info_data_next,
-        mask_info_mask_next,
-        mask_info_block_mask,
-        mask_info_partial_mask_blocks,
-        mask_info_q_sequence,
-        q_xla,
-        k_xla,
-        v_xla,
-        q_segment_ids,  # [q_seq_len]
-        kv_segment_ids,  # [kv_seq_len]
-        mask_value,
-        is_mqa,
-    )
-
-    jax_output = _splash_attention_forward(fwd_mask_info, q_jax, k_jax, v_jax,
-                                           segment_ids, mask_value, is_mqa,
-                                           block_sizes,
-                                           residual_checkpoint_name,
-                                           save_residuals, mask_function,
-                                           attn_logits_soft_cap, interpret)
-
-    self._assert_allclose(output, jax_output, atol=3e-3, rtol=3e-3)
+    q = torch.randn(num_q_heads, q_seq_len, head_dim, dtype=dtype)
+    k = torch.randn(num_q_heads, kv_seq_len, head_dim, dtype=dtype)
+    v = torch.randn(num_q_heads, kv_seq_len, head_dim, dtype=dtype)
+
+    data_next_torch = torch.randint(
+        0,
+        kv_seq_len, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
+        dtype=torch.int32)
+    mask_next_torch = torch.randint(
+        0,
+        2, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
+        dtype=torch.int32)
+    block_mask_torch = torch.randint(
+        0,
+        3, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
+        dtype=torch.int32)
+    partial_mask_blocks_torch = torch.randint(
+        0, 2, (10, block_q, block_kv), dtype=torch.int32)  # Example shape
+    q_sequence_torch = torch.arange(q_seq_len, dtype=torch.int32)
+    q_segment_ids = torch.randint(0, 2, (q_seq_len,), dtype=torch.int32)
+    kv_segment_ids = torch.randint(0, 2, (kv_seq_len,), dtype=torch.int32)
+
+    mask_value = -0.7 * float(torch.finfo(torch.float32).max)
+    is_mqa = False
+    residual_checkpoint_name = None
+    save_residuals = False
+    mask_function = None
+    attn_logits_soft_cap = None
+    interpret = False

+    # PT/XLA output
+    output_xla = splash_attention(
+        mask_info_data_next=data_next_torch.to("xla"),
+        mask_info_mask_next=mask_next_torch.to("xla"),
+        mask_info_block_mask=block_mask_torch.to("xla"),
+        mask_info_partial_mask_blocks=partial_mask_blocks_torch.to("xla"),
+        mask_info_q_sequence=q_sequence_torch.to("xla"),
+        q=q.to("xla"),
+        k=k.to("xla"),
+        v=v.to("xla"),
+        q_segment_ids=q_segment_ids.to("xla"),
+        kv_segment_ids=kv_segment_ids.to("xla"),
+        mask_value=mask_value,  # plain Python float, not a tensor, so no .to("xla")
+        is_mqa=is_mqa,
+        residual_checkpoint_name=residual_checkpoint_name,
+        save_residuals=save_residuals,
+        mask_function=mask_function,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+        interpret=interpret)

+    # JAX output
+    jax_fwd_mask_info = mask_info_lib.MaskInfo(
+        data_next=data_next_torch.numpy(),
+        mask_next=mask_next_torch.numpy(),
+        block_mask=block_mask_torch.numpy(),
+        partial_mask_blocks=partial_mask_blocks_torch.numpy(),
+        q_sequence=q_sequence_torch.numpy())
+
+    jax_segment_ids = SegmentIds(
+        q=jnp.array(q_segment_ids.numpy(), dtype=jnp.int32),
+        kv=jnp.array(kv_segment_ids.numpy(), dtype=jnp.int32))
+
+    jax_mask_value = -0.7 * float(np.finfo(np.dtype("float32")).max)
+    jax_block_sizes = BlockSizes.get_default()
+
+    # Call JAX's splash_attention
+    output_jax = _splash_attention_forward(
+        fwd_mask_info=jax_fwd_mask_info,
+        q=jnp.array(q.numpy(), dtype=jnp.float32),
+        k=jnp.array(k.numpy(), dtype=jnp.float32),
+        v=jnp.array(v.numpy(), dtype=jnp.float32),
+        segment_ids=jax_segment_ids,
+        mask_value=jax_mask_value,
+        is_mqa=is_mqa,
+        block_sizes=jax_block_sizes,
+        residual_checkpoint_name=residual_checkpoint_name,
+        save_residuals=save_residuals,
+        mask_function=mask_function,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+        interpret=interpret)
+
+    # Convert JAX output to torch for comparison
+    expected_output = torch.from_numpy(np.array(output_jax))
+
+    # Compare outputs
+    self.assertTrue(
+        torch.allclose(
+            output_xla.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))


 if __name__ == '__main__':
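
A minimal sketch of the block-grid shape arithmetic behind the data_next, mask_next, and block_mask tensors in the updated test, assuming the same toy dimensions as above; mask_info_shapes is a hypothetical helper for illustration, not an API of torch_xla or JAX.

# Sketch: derive the MaskInfo block-grid shapes used in the test above.
num_q_heads, q_seq_len, kv_seq_len = 2, 16, 16
block_q, block_kv = 4, 4

def mask_info_shapes(num_q_heads, q_seq_len, kv_seq_len, block_q, block_kv):
  # The q_seq_len x kv_seq_len attention matrix is tiled into
  # (q_seq_len // block_q) x (kv_seq_len // block_kv) blocks, and
  # data_next/mask_next/block_mask carry one entry per head per block.
  grid = (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv)
  return {"data_next": grid, "mask_next": grid, "block_mask": grid}

shapes = mask_info_shapes(num_q_heads, q_seq_len, kv_seq_len, block_q, block_kv)
assert shapes["block_mask"] == (2, 4, 4)  # matches the torch.randint shapes above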

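A small sketch of the torch <-> JAX round-trip the comparison in this test relies on: jnp.array(tensor.numpy()) in one direction and torch.from_numpy(np.array(...)) back. This is plain CPU interop, assumes nothing TPU-specific, and float32 values survive both hops bit-exactly.

import numpy as np
import jax.numpy as jnp
import torch

t = torch.randn(2, 16, 64, dtype=torch.float32)
j = jnp.array(t.numpy(), dtype=jnp.float32)  # torch -> JAX via NumPy
back = torch.from_numpy(np.array(j))  # JAX -> torch; np.array copies to writable memory
assert torch.allclose(t, back, atol=0, rtol=0)  # round trip is lossless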