diff --git a/test/test_pallas.py b/test/test_pallas.py
index 4e7fb90cbf7..99ff8f222f6 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -1,5 +1,6 @@
 import logging
 import os
+from typing import Callable
 import unittest
 
 import torch
@@ -1023,6 +1024,112 @@ def test_flash_attention_ab_backward_2(self):
 
     jax.config.update("jax_default_matmul_precision", "default")
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_splash_attention_wrapper(self):
+    from torch_xla.experimental.custom_kernel import splash_attention
+    from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes, SegmentIds, _splash_attention_forward
+    from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib
+
+    num_q_heads = 2
+    q_seq_len = 16
+    kv_seq_len = 16
+    head_dim = 64
+    block_q = 4
+    block_kv = 4
+    block_kv_compute = 2
+    dtype = torch.float32
+
+    q = torch.randn(num_q_heads, q_seq_len, head_dim, dtype=dtype)
+    k = torch.randn(num_q_heads, kv_seq_len, head_dim, dtype=dtype)
+    v = torch.randn(num_q_heads, kv_seq_len, head_dim, dtype=dtype)
+
+    data_next_torch = torch.randint(
+        0,
+        kv_seq_len, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
+        dtype=torch.int32)
+    mask_next_torch = torch.randint(
+        0,
+        2, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
+        dtype=torch.int32)
+    block_mask_torch = torch.randint(
+        0,
+        3, (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv),
+        dtype=torch.int32)
+    partial_mask_blocks_torch = torch.randint(
+        0, 2, (10, block_q, block_kv), dtype=torch.int32)  # Example shape
+    q_sequence_torch = torch.arange(q_seq_len, dtype=torch.int32)
+    q_segment_ids = torch.randint(0, 2, (q_seq_len,), dtype=torch.int32)
+    kv_segment_ids = torch.randint(0, 2, (kv_seq_len,), dtype=torch.int32)
+
+    mask_value = -0.7 * float(torch.finfo(torch.float32).max)
+    is_mqa = False
+    residual_checkpoint_name = None
+    save_residuals = False
+    mask_function = None
+    attn_logits_soft_cap = None
+    interpret = False
+
+    # PT/XLA output
+    output_xla = splash_attention(
+        mask_info_data_next=data_next_torch.to("xla"),
+        mask_info_mask_next=mask_next_torch.to("xla"),
+        mask_info_block_mask=block_mask_torch.to("xla"),
+        mask_info_partial_mask_blocks=partial_mask_blocks_torch.to("xla"),
+        mask_info_q_sequence=q_sequence_torch.to("xla"),
+        q=q.to("xla"),
+        k=k.to("xla"),
+        v=v.to("xla"),
+        q_segment_ids=q_segment_ids.to("xla"),
+        kv_segment_ids=kv_segment_ids.to("xla"),
+        mask_value=mask_value,
+        is_mqa=is_mqa,
+        residual_checkpoint_name=residual_checkpoint_name,
+        save_residuals=save_residuals,
+        mask_function=mask_function,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+        interpret=interpret)
+
+    # JAX output
+    jax_fwd_mask_info = mask_info_lib.MaskInfo(
+        data_next=data_next_torch.numpy(),
+        mask_next=mask_next_torch.numpy(),
+        block_mask=block_mask_torch.numpy(),
+        partial_mask_blocks=partial_mask_blocks_torch.numpy(),
+        q_sequence=q_sequence_torch.numpy())
+
+    jax_segment_ids = SegmentIds(
+        q=jnp.array(q_segment_ids.numpy(), dtype=jnp.int32),
+        kv=jnp.array(kv_segment_ids.numpy(), dtype=jnp.int32))
+
+    jax_mask_value = -0.7 * float(np.finfo(np.dtype("float32")).max)
+    jax_block_sizes = BlockSizes.get_default()
+
+    # Call JAX's splash_attention
+    output_jax = _splash_attention_forward(
+        fwd_mask_info=jax_fwd_mask_info,
+        q=jnp.array(q.numpy(), dtype=jnp.float32),
+        k=jnp.array(k.numpy(), dtype=jnp.float32),
+        v=jnp.array(v.numpy(), dtype=jnp.float32),
+        segment_ids=jax_segment_ids,
+        mask_value=jax_mask_value,
+        is_mqa=is_mqa,
+        block_sizes=jax_block_sizes,
+        residual_checkpoint_name=residual_checkpoint_name,
+        save_residuals=save_residuals,
+        mask_function=mask_function,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+        interpret=interpret)
+
+    # Convert JAX output to torch for comparison
+    expected_output = torch.from_numpy(np.array(output_jax))
+
+    # Compare outputs
+    self.assertTrue(
+        torch.allclose(
+            output_xla.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 6c5b2e19466..1ea1757a024 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -512,6 +512,240 @@ def paged_attention(q,
 
   return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)
 
+
+def _splash_attention_forward(
+    # fwd_mask_info: mask_info_lib.MaskInfo,
+    mask_info_data_next: torch.Tensor | None,
+    mask_info_mask_next: torch.Tensor | None,
+    mask_info_block_mask: torch.Tensor | None,
+    mask_info_partial_mask_blocks: torch.Tensor | None,
+    mask_info_q_sequence: torch.Tensor | None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_segment_ids: torch.Tensor | None,  # [q_seq_len]
+    kv_segment_ids: torch.Tensor | None,  # [kv_seq_len]
+    mask_value: float,
+    is_mqa: bool,
+    # TODO do we want to accept custom block_sizes as a parameter?
+    # block_sizes: BlockSizes,
+    residual_checkpoint_name: str | None = None,
+    save_residuals: bool = False,
+    # TODO mask_function is a custom callable type. How can we allow Dynamo to trace it?
+    mask_function: Callable | None = None,
+    attn_logits_soft_cap: float | None = None,
+    interpret: bool = False):
+  # Import JAX within the function such that we don't need to call the jax_import_guard()
+  # in the global scope which could cause problems for xmp.spawn.
+  jax_import_guard()
+  # NOTE: the imported `_splash_attention_forward` shadows this wrapper's own
+  # name inside the function body, so the JAX kernel is what gets traced below.
+  from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes, SegmentIds, QKVLayout, _splash_attention_forward
+  from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib
+
+  # Define default values
+  block_sizes = BlockSizes.get_default(
+  )  # Should we define and optimize block sizes?
+  num_lanes = 128  # From https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L38
+  num_sublanes = 8
+
+  # TODO For mask_info and segment_ids below, do we need the actual values or will the shape suffice?
+  mask_info = mask_info_lib.MaskInfo(
+      data_next=mask_info_data_next.cpu().numpy()
+      if mask_info_data_next is not None else None,
+      mask_next=mask_info_mask_next.cpu().numpy()
+      if mask_info_mask_next is not None else None,
+      block_mask=mask_info_block_mask.cpu().numpy()
+      if mask_info_block_mask is not None else None,
+      partial_mask_blocks=mask_info_partial_mask_blocks.cpu().numpy()
+      if mask_info_partial_mask_blocks is not None else None,
+      q_sequence=mask_info_q_sequence.cpu().numpy()
+      if mask_info_q_sequence is not None else None,
+  )
+
+  segment_ids = None
+  if q_segment_ids is not None or kv_segment_ids is not None:
+    segment_ids = SegmentIds(
+        q=q_segment_ids.cpu().numpy()
+        if q_segment_ids is not None else None,  # [q_seq_len]
+        kv=kv_segment_ids.cpu().numpy()
+        if kv_segment_ids is not None else None,  # [kv_seq_len]
+    )
+
+  num_q_heads, q_seq_len, head_dim = q.shape
+  bq, bkv = block_sizes.block_q, block_sizes.block_kv
+  bkv_compute = block_sizes.block_kv_compute
+
+  if is_mqa:
+    expected_kv_rank = 2
+    kv_head_dimension = 1
+    kv_seq_len_dimension = 0
+    num_kv_heads = 1
+  else:
+    expected_kv_rank = 3
+    kv_head_dimension = 2
+    kv_seq_len_dimension = 1
+    num_kv_heads = k.shape[0]
+
+  if len(k.shape) != expected_kv_rank:
+    raise ValueError(
+        f"Expected {expected_kv_rank}-dim 'key' tensor for MQA. Instead got a"
+        f" {len(k.shape)}-dim one.")
+
+  if k.shape[kv_head_dimension] != head_dim:
+    raise ValueError(
+        f"Expected 'key' head dimension to be: {head_dim}. Instead got:"
+        f" {k.shape[kv_head_dimension]}.")
+
+  if not is_mqa and num_q_heads % num_kv_heads != 0:
+    raise ValueError(
+        f"In MHA, expected number of 'key' heads ({num_kv_heads}) to be a"
+        f" multiple of the number of 'query' heads ({num_q_heads})")
+
+  if k.shape != v.shape:
+    raise ValueError(
+        f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same"
+        " shape.")
+
+  if bkv % bkv_compute:
+    raise ValueError(f"{bkv=} must be a multiple of {bkv_compute=}.")
+  if bkv_compute % num_lanes:
+    raise ValueError(f"{bkv_compute=} must be a multiple of {num_lanes}.")
+
+  kv_seq_len = k.shape[kv_seq_len_dimension]
+
+  q_heads_per_kv_head = num_q_heads // num_kv_heads
+
+  if segment_ids is not None:
+    if segment_ids.q.shape != (q_seq_len,):
+      raise ValueError("Invalid shape for q segment_ids: "
+                       f"{segment_ids.q.shape}. Expected: {(q_seq_len,)}")
+    if segment_ids.kv.shape != (kv_seq_len,):
+      raise ValueError("Invalid shape for kv segment_ids: "
+                       f"{segment_ids.kv.shape}. Expected: {(kv_seq_len,)}")
+    # q_segment_ids = jax.lax.broadcast_in_dim(
+    #     segment_ids.q, (q_seq_len, NUM_LANES), (0,)
+    # )
+    q_segment_ids = q_segment_ids.unsqueeze(1).expand(q_seq_len, num_lanes)
+    # kv_segment_ids = jax.lax.broadcast_in_dim(
+    #     segment_ids.kv, (NUM_SUBLANES, kv_seq_len), (1,)
+    # )
+    kv_segment_ids = kv_segment_ids.unsqueeze(0).expand(num_sublanes,
+                                                        kv_seq_len)
+  else:
+    q_segment_ids = kv_segment_ids = None
+
+  if mask_info_q_sequence is not None:
+    # q_sequence = jax.lax.broadcast_in_dim(
+    #     fwd_mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,)
+    # )
+    q_sequence = mask_info_q_sequence.unsqueeze(1).expand(q_seq_len, num_lanes)
+  else:
+    q_sequence = None
+
+  payload, tensor_args = trace_pallas(
+      _splash_attention_forward,
+      mask_info,
+      q,
+      k,
+      v,
+      segment_ids,
+      mask_value,
+      is_mqa,
+      block_sizes,
+      residual_checkpoint_name,
+      save_residuals,
+      mask_function,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+      static_argnames=[
+          "is_mqa",
+          "block_sizes",
+          "save_residuals",
+          "mask_value",
+          "attn_logits_soft_cap",
+          "residual_checkpoint_name",
+          "mask_function",
+          "interpret",
+      ],
+  )
+
+  q_layout = block_sizes.q_layout
+  k_layout = block_sizes.k_layout
+  v_layout = block_sizes.v_layout
+
+  out_shapes = [
+      # jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32),  # m_scratch
+      torch.Size([bq, num_lanes]),
+      # jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32),  # l_scratch
+      torch.Size([bq, num_lanes]),
+      # jax.ShapeDtypeStruct((bq, head_dim), jnp.float32),  # o_scratch
+      torch.Size([bq, head_dim]),
+      # jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim), q.dtype),
+      torch.Size([num_q_heads, q_seq_len, head_dim]),
+  ]
+
+  out_d_types = [torch.float32, torch.float32, torch.float32, q.dtype]
+
+  # The first three outputs are scratch buffers (m, l, o); the last one is the
+  # attention output.
+  _, _, _, out = torch_xla._XLAC._xla_tpu_custom_call([
+      mask_info.data_next,
+      mask_info.block_mask,
+      mask_info.mask_next,
+      q if q_layout == QKVLayout.HEAD_DIM_MINOR else q.transpose(-1, -2),
+      k if k_layout == QKVLayout.HEAD_DIM_MINOR else k.transpose(-1, -2),
+      v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.transpose(-1, -2),
+      q_segment_ids,
+      kv_segment_ids,
+      mask_info.partial_mask_blocks,
+      q_sequence,
+  ], payload, out_shapes, out_d_types)
+
+  return out
+
+
+def splash_attention(
+    # fwd_mask_info: mask_info_lib.MaskInfo,
+    mask_info_data_next: torch.Tensor | None,
+    mask_info_mask_next: torch.Tensor | None,
+    mask_info_block_mask: torch.Tensor | None,
+    mask_info_partial_mask_blocks: torch.Tensor | None,
+    mask_info_q_sequence: torch.Tensor | None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_segment_ids: torch.Tensor | None,  # [q_seq_len]
+    kv_segment_ids: torch.Tensor | None,  # [kv_seq_len]
+    mask_value: float,
+    is_mqa: bool,
+    # TODO do we want to accept custom block_sizes as a parameter?
+    # block_sizes: BlockSizes,
+    residual_checkpoint_name: str | None = None,
+    save_residuals: bool = False,
+    # TODO mask_function is a custom callable type. How can we allow Dynamo to trace it?
+    mask_function: Callable | None = None,
+    attn_logits_soft_cap: float | None = None,
+    interpret: bool = False):
+  # TODO handle backward
+  return _splash_attention_forward(
+      # fwd_mask_info: mask_info_lib.MaskInfo,
+      mask_info_data_next,
+      mask_info_mask_next,
+      mask_info_block_mask,
+      mask_info_partial_mask_blocks,
+      mask_info_q_sequence,
+      q,
+      k,
+      v,
+      q_segment_ids,  # [q_seq_len]
+      kv_segment_ids,  # [kv_seq_len]
+      mask_value,
+      is_mqa,
+      # block_sizes: BlockSizes,
+      residual_checkpoint_name=residual_checkpoint_name,
+      save_residuals=save_residuals,
+      mask_function=mask_function,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+  )
+
+
 def _calculate_num_tiles(x: int, tx: int) -> int:
   tiles, rem = divmod(x, tx)
   if rem:
@@ -927,6 +1161,73 @@ def paged_attention_non_xla(q: torch.Tensor,
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
 
+XLA_LIB.define(
+    "splash_attention(Tensor mask_info_data_next, Tensor mask_info_mask_next, Tensor mask_info_block_mask, Tensor mask_info_partial_mask_blocks, Tensor mask_info_q_sequence, Tensor q, Tensor k, Tensor v, Tensor q_segment_ids, Tensor kv_segment_ids, float mask_value, bool is_mqa, str? residual_checkpoint_name=None, bool save_residuals=False, float? attn_logits_soft_cap=None, bool interpret=False) -> Tensor",
+)
+
+
+@impl(XLA_LIB, "splash_attention", "XLA")
+def splash_attention_xla(
+    mask_info_data_next: torch.Tensor | None,
+    mask_info_mask_next: torch.Tensor | None,
+    mask_info_block_mask: torch.Tensor | None,
+    mask_info_partial_mask_blocks: torch.Tensor | None,
+    mask_info_q_sequence: torch.Tensor | None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_segment_ids: torch.Tensor | None,  # [q_seq_len]
+    kv_segment_ids: torch.Tensor | None,  # [kv_seq_len]
+    mask_value: float,
+    is_mqa: bool,
+    residual_checkpoint_name: str | None = None,
+    save_residuals: bool = False,
+    # mask_function: Callable | None = None,
+    attn_logits_soft_cap: float | None = None,
+    interpret: bool = False):
+  return splash_attention(
+      mask_info_data_next,
+      mask_info_mask_next,
+      mask_info_block_mask,
+      mask_info_partial_mask_blocks,
+      mask_info_q_sequence,
+      q,
+      k,
+      v,
+      q_segment_ids,
+      kv_segment_ids,
+      mask_value,
+      is_mqa,
+      residual_checkpoint_name=residual_checkpoint_name,
+      save_residuals=save_residuals,
+      mask_function=None,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+  )
+
+
+@impl(XLA_LIB, "splash_attention", "CompositeExplicitAutograd")
+def splash_attention_non_xla(
+    mask_info_data_next: torch.Tensor | None,
+    mask_info_mask_next: torch.Tensor | None,
+    mask_info_block_mask: torch.Tensor | None,
+    mask_info_partial_mask_blocks: torch.Tensor | None,
+    mask_info_q_sequence: torch.Tensor | None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_segment_ids: torch.Tensor | None,  # [q_seq_len]
+    kv_segment_ids: torch.Tensor | None,  # [kv_seq_len]
+    mask_value: float,
+    is_mqa: bool,
+    residual_checkpoint_name: str | None = None,
+    save_residuals: bool = False,
+    # mask_function: Callable | None = None,
+    attn_logits_soft_cap: float | None = None,
+    interpret: bool = False):
+  return non_xla_attetion(q, k, v, "splash")
+
+
 XLA_LIB.define(
     "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
 )
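For reference, a minimal sketch of how the dispatcher-registered op could be invoked once this diff lands. It mirrors the shapes used in test_splash_attention_wrapper above; the mask-info tensors are illustrative placeholders rather than a real processed block-sparse mask, and torch.ops.xla.splash_attention is assumed to resolve to the schema defined here (the file's XLA_LIB uses the "xla" namespace for its other custom kernels).

# Hypothetical usage sketch, not part of the diff; shapes mirror the unit test
# above and the mask-info tensors are placeholders, not a real processed mask.
import torch
import torch_xla
import torch_xla.experimental.custom_kernel  # registers xla::splash_attention

num_q_heads, q_seq_len, kv_seq_len, head_dim = 2, 16, 16, 64
block_q = block_kv = 4
grid = (num_q_heads, q_seq_len // block_q, kv_seq_len // block_kv)

q = torch.randn(num_q_heads, q_seq_len, head_dim).to("xla")
k = torch.randn(num_q_heads, kv_seq_len, head_dim).to("xla")
v = torch.randn(num_q_heads, kv_seq_len, head_dim).to("xla")

data_next = torch.randint(0, kv_seq_len, grid, dtype=torch.int32).to("xla")
mask_next = torch.randint(0, 2, grid, dtype=torch.int32).to("xla")
block_mask = torch.randint(0, 3, grid, dtype=torch.int32).to("xla")
partial_mask_blocks = torch.randint(
    0, 2, (10, block_q, block_kv), dtype=torch.int32).to("xla")
q_sequence = torch.arange(q_seq_len, dtype=torch.int32).to("xla")
q_segment_ids = torch.zeros(q_seq_len, dtype=torch.int32).to("xla")
kv_segment_ids = torch.zeros(kv_seq_len, dtype=torch.int32).to("xla")

mask_value = -0.7 * float(torch.finfo(torch.float32).max)

# Positional arguments follow the operator schema registered in the diff.
out = torch.ops.xla.splash_attention(data_next, mask_next, block_mask,
                                     partial_mask_blocks, q_sequence, q, k, v,
                                     q_segment_ids, kv_segment_ids, mask_value,
                                     False)  # is_mqa
print(out.shape)  # expected: (num_q_heads, q_seq_len, head_dim)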