diff --git a/test/test_pallas.py b/test/test_pallas.py
index 4223053573ae..066f7871abe0 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -12,8 +12,9 @@ import numpy as np
 
 if xr.device_type() == 'TPU':
-  from torch_xla.experimental.custom_kernel import jax_import_guard
-  jax_import_guard()
+  # from torch_xla.experimental.custom_kernel import jax_import_guard
+  # jax_import_guard()
+  torch_xla._XLAC._init_computation_client()
   import jax
   import jax.numpy as jnp
   from jax.experimental import pallas as pl
@@ -488,6 +489,79 @@ def test_flash_attention_backward(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update("jax_default_matmul_precision", "default")
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_flash_attention_backward_aot_autograd_traceable(self):
+    from functorch.compile import aot_function, make_boxed_func
+    from torch_xla.experimental.custom_kernel import flash_attention, FlashAttention, flash_attention_compilable
+    import torch_xla.core.xla_model as xm
+    jax.config.update("jax_default_matmul_precision", "highest")
+
+    def compiler(gm, _):
+      print("Got graph:")
+      print(gm.code)
+      return make_boxed_func(gm)
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+    B, N, SEQ, H = q.size()
+    causal = True
+    q_segment_ids = None
+    kv_segment_ids = None
+    sm_scale = 1.0
+    mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
+    # ab = torch.ones(4, 2, 128, 128).to("xla")
+    # ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_(True)
+    # ab.retain_grad()
+    ab = None
+    partition_spec = ('fsdp', 'tensor', None, None)
+    # partition_spec = None
+    import torch_xla.runtime as xr
+    from torch_xla.distributed.spmd import Mesh
+    xr.use_spmd()
+    num_devices = xr.global_runtime_device_count()
+    mesh_shape = (num_devices // 2, 2)
+    device_ids = np.array(range(num_devices))
+    mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+
+    def flash_attention_wrapper(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, partition_spec, mesh):
+      return flash_attention_compilable(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)
+
+
+    # An AOT-compatible function only accepts the argument types listed in https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34, so we serialize partition_spec and mesh into strings.
+    # compiled_flash_attention = aot_function(
+    #     flash_attention_wrapper, fw_compiler=compiler)
+    # o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, str(partition_spec), str(mesh))
+    o_actual = flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)
+
+    print(o_actual.sum())
+    o_actual.sum().backward()
+    print(q.grad)
+
+    # if causal:
+    #   attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
+    #   # attention_mask = attention_mask.view(1, 1, SEQ, SEQ)
+    #   # attention_mask = attention_mask.expand(q.size(0), q.size(1), -1, -1)
+    # else:
+    #   attention_mask = None
+    # print(attention_mask)
+    # assert False
+    # import torch_xla.distributed.spmd as xs
+    # expected_output = self._attention(q, k, v, attn_mask=attention_mask)
+    # print(expected_output)
+    # self.assertTrue(
+    #     torch.allclose(
+    #         expected_output.cpu(),
+    #         o_actual.cpu(),
+    #         atol=1e-1,
+    #         rtol=1e-1))
+
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_wrapper(self):
diff --git a/third_party/xla b/third_party/xla
new file mode 160000
index 000000000000..6e91ff19dad5
--- /dev/null
+++ b/third_party/xla
@@ -0,0 +1 @@
+Subproject commit 6e91ff19dad528ab7d2025a9bb46150618a3bc7d
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 5ea1343bfcdf..0151abd8e716 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -124,8 +124,27 @@ def get_op_sharding(self,
                                          partition_spec)
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
-
-
+
+  def __str__(self):
+    """Convert Mesh to string representation."""
+    return (f"{{'device_ids': {self.device_ids.tolist()}, "
+            f"'mesh_shape': {self.mesh_shape}, "
+            f"'axis_names': {self.axis_names}}}")
+
+  @classmethod
+  def from_str(cls, mesh_str: str):
+    """Create a Mesh from its string representation."""
+    import ast
+    import numpy as np
+    # Remove 'Mesh' and parse the dict.
+    dict_str = mesh_str.replace('Mesh', '')
+    mesh_dict = ast.literal_eval(dict_str)
+    # Convert the list back to a numpy array for device_ids.
+    return cls(
+        device_ids=np.array(mesh_dict['device_ids']),
+        mesh_shape=mesh_dict['mesh_shape'],
+        axis_names=mesh_dict['axis_names']
+    )
 
 
 _GLOBAL_MESH: Mesh = None
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 10036aeb9ca7..4074866cba2e 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -4,17 +4,43 @@
 import numpy as np
 import torch
+from torch.library import impl, custom_op
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 import torch_xla.debug.metrics as met
+from contextlib import contextmanager
 
 from typing import Any, List, Callable, Optional, Tuple, Dict
 from torch.library import impl
 from torch_xla.core.xla_model import XLA_LIB
+import torch_xla.debug.profiler as xp
 
 _XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
+_DEBUG = False
+
+def safe_empty_like(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+  """Returns an empty tensor like the input, or None if the input is None."""
+  return torch.empty_like(tensor) if tensor is not None else None
+
+
+def generate_ctx_need_grad(*args):
+  ctx_need_grad = [False for _ in range(len(args))]
+  for i, arg in enumerate(args):
+    if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad:
+      ctx_need_grad[i] = True
+  return ctx_need_grad
+
+def describe_value(v):
+  if v is not None and isinstance(v, torch.Tensor):
+    print(f"{type(v)}({v.shape}, dtype={v.dtype}, device={v.device})")
+  elif isinstance(v, list):
+    print(f"list({len(v)})")
+  elif v is None:
+    print("None")
+  else:
+    print(type(v))
 
 
 def _extract_backend_config(
     module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> Optional[str]:
@@ -52,15 +78,31 @@ def _extract_backend_config(
   return None
 
 
-def jax_import_guard():
-  # Somehow, we need to grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operations will hang.
-  torch_xla._XLAC._init_computation_client()
-
-
+@contextmanager
+def _jax_env_context():
+  try:
+    os.environ['SKIP_MEGASCALE_PJRT_CLIENT'] = 'true'
+    yield
+  finally:
+    os.environ.pop('SKIP_MEGASCALE_PJRT_CLIENT', None)
+
+def requires_jax(func: Callable) -> Callable:
+  """Decorator that ensures JAX is safely imported before the function executes."""
+
+  @functools.wraps(func)
+  def wrapper(*args, **kwargs) -> Any:
+    try:
+      torch_xla._XLAC._init_computation_client()
+    except ImportError as e:
+      raise ImportError("JAX import guard failed: the PJRT client is unavailable.") from e
+    with _jax_env_context():
+      return func(*args, **kwargs)
+  return wrapper
+
+
+@requires_jax
 def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
-  jax_import_guard()
   import jax.numpy as jnp
   if _XLA_USE_BF16:
     raise RuntimeError(
@@ -87,10 +129,10 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
   raise ValueError(f"Unsupported dtype: {dtype}")
 
 
+@requires_jax
 def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
-  jax_import_guard()
   import jax
 
   return jax.ShapeDtypeStruct(tensor.shape,
@@ -100,6 +142,7 @@ def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
 trace_pallas_arg_to_payload: Dict[Tuple[Any], str] = {}
 
 
+@requires_jax
 def trace_pallas(kernel: Callable,
                  *args,
                  static_argnums=None,
@@ -108,7 +151,6 @@ def trace_pallas(kernel: Callable,
                  **kwargs):
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
-  jax_import_guard()
   import jax
   import jax._src.pallas.mosaic.pallas_call_registration
 
@@ -138,12 +180,10 @@ def trace_pallas(kernel: Callable,
     return trace_pallas_arg_to_payload[hash_key], tensor_args
 
   # Here we ignore the kwargs for execution as most of the time, the kwargs is only used in traced code.
-  os.environ['SKIP_MEGASCALE_PJRT_CLIENT'] = 'true'
   ir = jax.jit(
       kernel, static_argnums=static_argnums,
       static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
   payload = _extract_backend_config(ir)
-  os.environ.pop('SKIP_MEGASCALE_PJRT_CLIENT', None)
 
   if use_cache:
     # if we reach here it means we have a cache miss.
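A short usage sketch of the new requires_jax decorator, assuming it fully replaces explicit jax_import_guard() calls; my_pallas_helper is a hypothetical name, not a function in this PR.

@requires_jax
def my_pallas_helper(t: torch.Tensor):
  # JAX is imported only after the decorator has initialized the PJRT client
  # (grabbing the TPU before JAX can lock it) and _jax_env_context() has set
  # SKIP_MEGASCALE_PJRT_CLIENT for the duration of the call.
  import jax.numpy as jnp
  return jnp.zeros(tuple(t.shape), dtype=jnp.float32)

A minimal sketch of the string round trip that the test comment above relies on, assuming the Mesh.__str__ / Mesh.from_str helpers added in this PR; the values mirror the test and the reconstruction mirrors what fa_custom_forward does below.

import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (num_devices // 2, 2), ('fsdp', 'tensor'))
partition_spec = ('fsdp', 'tensor', None, None)

# Only the plain Python types in the linked KNOWN_TYPES list survive the AOT
# boundary, so both objects are flattened to strings...
mesh_str = str(mesh)            # "{'device_ids': [...], 'mesh_shape': (..., 2), 'axis_names': ('fsdp', 'tensor')}"
spec_str = str(partition_spec)  # "('fsdp', 'tensor', None, None)"

# ...and rebuilt on the other side of the traced boundary.
restored_mesh = Mesh.from_str(mesh_str)
restored_spec = eval(spec_str)
assert restored_spec == partition_spec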
@@ -181,6 +221,374 @@ def wrapped_kernel(kernel: Callable, return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn) +def defeat_alias(v): + return v * 1.0 + +# @custom_op("xla::fa_custom_forward_dummy", mutates_args={}) +# def fa_custom_forward_dummy(q: torch.Tensor)->torch.Tensor: +# print("+++++++") +# describe_value(q) +# print(q.requires_grad) +# x = q.clone() +# return x + +# @fa_custom_forward_dummy.register_fake +# def fa_custom_forward_dummy_fake(q: torch.Tensor): +# print("Inside fake fa_custom_forward") +# return torch.empty_like(q) + + +@custom_op("xla::fa_custom_forward", mutates_args=()) +def fa_custom_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float, ab: Optional[torch.Tensor], partition_spec: str, mesh: str, ctx_grad: List[bool]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Suprisingly, any tensor that is input to the custom_op decorated function will show requires_grad=False. Is this a bug or feature? + from torch_xla.distributed.spmd import Mesh + partition_spec = eval(partition_spec) + mesh = Mesh.from_str(mesh) + assert mesh is not None + + # q_segment_ids = kv_segment_ids = ab = None + + from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl + + q_full_shape = None + kv_full_shape = None + # save_residuals = q.requires_grad or k.requires_grad or v.requires_grad + save_residuals = any(ctx_grad[:3]) + + # SPMD integration. + # mark_sharding is in-placed, and therefore save the full q, k, v for the backward. + # PyTorch tell us clone is necessary: + # + # RuntimeError: Found a custom (non-ATen) operator whose output has alias + # annotations: xla::fa_custom_forward(Tensor(a0!) q, Tensor(a1!) k, + # Tensor(a2!) v) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor). We only + # support functionalizing operators whose outputs do not have alias + # annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas + # 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias + # annotation specifies that the output Tensor shares storage with an input + # that has the same annotation. Please check if (1) the output needs to be an + # output (if not, don't return it), (2) if the output doesn't share storage + # with any inputs, then delete the alias annotation. (3) if the output indeed + # shares storage with an input, then add a .clone() before returning it to + # prevent storage sharing and then delete the alias annotation. Otherwise, + # please file an issue on GitHub. + with xp.Trace('shard'): + full_q = q.clone() + full_k = k.clone() + full_v = v.clone() + if ab is not None: + full_ab = ab.clone() + else: + full_ab = None + if partition_spec is not None: + q_full_shape = q.shape + kv_full_shape = k.shape + ab_full_shape = ab.shape if ab is not None else None + q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor + k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor + v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor + if ab is not None: + ab = xs.enable_manual_sharding( + ab, partition_spec, mesh=mesh).global_tensor + + # It computes the shape and type of o, l, m. 
+ shapes = [q.shape] + dtypes = [q.dtype] + if save_residuals: + res_shape = list(q.shape) + res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE + for _ in range(2): + shapes.append(res_shape) + dtypes.append(torch.float32) + + with torch.no_grad(): + if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None: + # partition_spec is for q,k,v with shape [batch, num_head, seq_len, head_dim], segment id + # is of shape [batch, seq_len], hence we need to tweak it a bit + segment_id_partition_spec = (partition_spec[0], partition_spec[2]) + q_segment_ids = xs.enable_manual_sharding( + q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor + kv_segment_ids = xs.enable_manual_sharding( + kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor + segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids( + q_segment_ids, kv_segment_ids) + + # We can't directly use flash_attention as we need to override the save_residuals flag which returns + # l and m that is needed for the backward. Then we lose all the shape checks. + # TODO: replicate the shape checks on flash_attention. + # Here we seperate the tracing and execution part just to support SegmentIds. + payload, _ = trace_pallas( + _flash_attention_impl, + q, + k, + v, + ab, + segment_ids, + save_residuals, + causal, + sm_scale, + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]), + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]), + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]), + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]), + False, + static_argnums=range(5, 13), + use_cache=True, + ) + + args = [q, k, v] + if ab is not None: + args += [ab] + if segment_ids is not None: + args += [q_segment_ids_fa, kv_segment_ids_fa] + o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes) + + if not save_residuals: + o = o[0] + # SPMD integration + if partition_spec is not None: + o = xs.disable_manual_sharding( + o, partition_spec, q_full_shape, mesh=mesh).global_tensor + # We need to consistently return full_q, full_k, full_v,... even though they are empty to support AOT. + return (o, torch.Tensor(), torch.Tensor(), torch.Tensor(), torch.Tensor(), torch.Tensor(), torch.Tensor(), torch.Tensor(), torch.Tensor()) + + assert isinstance(o, list) + o, *aux = o + + # The fancier slice notation lowers to `aten.take`, which sends a large indexing + # tensor to the device and confuses the XLA compiler when used under scan for some reason. + # See the transfer to device in a trace: http://shortn/_4zOQhGezCS. + # As a result, we get a `!IsManual()` assertion in HLO sharding propgation. + # Therefore, we spell it as a permute + index into the first dim. + # However, that causes NaN loss for some reason. So we'll perform the slicing instead. 
+ # # l = aux[-2][:, :, :, 0] + # # l = aux[-2].permute(3, 0, 1, 2)[0] + # l = aux[-2] + # l = torch.ops.aten.slice(l, -1, 0, 1) + # # print(torch_xla._XLAC._get_xla_tensors_text([l])) + # # m = aux[-1][:, :, :, 0] + # # m = aux[-1].permute(3, 0, 1, 2)[0] + # m = aux[-1] + # m = torch.ops.aten.slice(m, -1, 0, 1) + l, m = (v[..., 0] for v in aux[-2:]) + + # SPMD integration + if partition_spec is not None: + o = xs.disable_manual_sharding( + o, partition_spec, q_full_shape, mesh=mesh).global_tensor + l = xs.disable_manual_sharding( + l, partition_spec[0:3], q_full_shape[0:3], + mesh=mesh).global_tensor + m = xs.disable_manual_sharding( + m, partition_spec[0:3], q_full_shape[0:3], + mesh=mesh).global_tensor + + # l = l.squeeze(-1) + # m = m.squeeze(-1) + + + # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided + # but it should be OK as the backward will use the same partition_spec + outs = [o] + [full_q, full_k, full_v, l, m, q_segment_ids_fa, kv_segment_ids_fa, full_ab] + + # outs = [o] + [full_q, full_k, full_v, l, m] + if _DEBUG: + print("Outs") + for t in outs: + describe_value(t) + return tuple(outs) + + +@custom_op("xla::fa_custom_backward", mutates_args=()) +# def fa_custom_backward(grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor, ab: torch.Tensor, causal: bool, sm_scale: float, partition_spec: str, mesh: str, q_full_shape: str, kv_full_shape: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +def fa_custom_backward(grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor, q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, ab: Optional[torch.Tensor], causal: bool, sm_scale: float, partition_spec: str, mesh: str, q_full_shape: List[int], kv_full_shape: List[int], ab_full_shape: Optional[List[int]], ctx_grad: List[bool]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv + grad_q = grad_k = grad_v = grad_ab = None + partition_spec = eval(partition_spec) + mesh = xs.Mesh.from_str(mesh) + q_full_shape = torch.Size(q_full_shape) + kv_full_shape = torch.Size(kv_full_shape) + ab_full_shape = torch.Size(ab_full_shape) if ab_full_shape is not None else None + # reconstruct the segment_ids from q_segment_ids and kv_segment_ids + segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(q_segment_ids, kv_segment_ids) + grad_i = torch.sum( + o.to(torch.float32) * grad_output.to(torch.float32), + axis=-1) # [batch_size, num_heads, q_seq_len] + + expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] + + [FlashAttention.MIN_BLOCK_SIZE]) + expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] + + [FlashAttention.MIN_BLOCK_SIZE]) + expanded_grad_i = grad_i.unsqueeze(-1).expand( + [-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE]) + + # SPMD integration + if partition_spec is not None: + q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor + k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor + v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor + expanded_l = xs.enable_manual_sharding( + expanded_l, partition_spec, mesh=mesh).global_tensor + expanded_m = xs.enable_manual_sharding( + expanded_m, partition_spec, mesh=mesh).global_tensor + grad_output = 
xs.enable_manual_sharding( + grad_output, partition_spec, mesh=mesh).global_tensor + expanded_grad_i = xs.enable_manual_sharding( + expanded_grad_i, partition_spec, mesh=mesh).global_tensor + if ab is not None: + ab = xs.enable_manual_sharding( + ab, partition_spec, mesh=mesh).global_tensor + + if ctx_grad[0]: + payload, _ = trace_pallas( + _flash_attention_bwd_dq, + q, + k, + v, + ab, + segment_ids, + l, + m, + grad_output, + grad_i, + block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"], + q.shape[2]), + block_k_major=min( + FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"], + k.shape[2]), + block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"], + k.shape[2]), + sm_scale=sm_scale, + causal=causal, + mask_value=FlashAttention.DEFAULT_MASK_VALUE, + debug=False, + static_argnames=[ + "block_q_major", "block_k_major", "block_k", "sm_scale", "causal", + "mask_value", "debug" + ], + use_cache=True, + ) + + args = [q, k, v] + if ab is not None: + args += [ab] + if segment_ids is not None: + args += [q_segment_ids_fa, kv_segment_ids_fa] + args += [expanded_l, expanded_m, grad_output, expanded_grad_i] + + outputs = [q] + if ab is not None: + outputs += [ab] + grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload, + [i.shape for i in outputs], + [i.dtype for i in outputs]) + if ctx_grad[0]: + grad_q = grads[0] + + if ctx_grad[-3]: + grad_ab = grads[1] + if ctx_grad[1] or ctx_grad[2]: + payload, _ = trace_pallas( + _flash_attention_bwd_dkv, + q, + k, + v, + ab, + segment_ids, + l, + m, + grad_output, + grad_i, + block_q_major=min( + FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"], + q.shape[2]), + block_k_major=min( + FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"], + k.shape[2]), + block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"], + k.shape[2]), + block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"], + q.shape[2]), + sm_scale=sm_scale, + causal=causal, + mask_value=FlashAttention.DEFAULT_MASK_VALUE, + debug=False, + static_argnames=[ + "block_q_major", "block_k_major", "block_k", "block_q", + "sm_scale", "causal", "mask_value", "debug" + ], + use_cache=True) + + grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload, + [k.shape, v.shape], + [k.dtype, v.dtype]) + + if ctx_grad[1]: + grad_k = grads[0] + if ctx_grad[2]: + grad_v = grads[1] + + # SPMD integration + if partition_spec is not None: + grad_q = xs.disable_manual_sharding( + grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor + grad_k = xs.disable_manual_sharding( + grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor + grad_v = xs.disable_manual_sharding( + grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor + if ab is not None: + grad_ab = xs.disable_manual_sharding( + grad_ab, partition_spec, ab_full_shape, mesh=mesh).global_tensor + # print(grad_q, grad_k, grad_v, grad_ab) + return grad_q, grad_k, grad_v, grad_ab + + +@fa_custom_forward.register_fake +# def fa_custom_forward_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causa: bool): +def fa_custom_forward_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float, ab: torch.Tensor, partition_spec: str, mesh: str, ctx_grad: List[bool]): + # q_segment_ids=None + # kv_segment_ids=None + # sm_scale=1.0 + # ab=None, + # partition_spec="('fsdp', 'tensor', None, None)" + # mesh="{'device_ids': [0, 1, 2, 3], 'mesh_shape': (2, 2), 'axis_names': ('fsdp', 'tensor')}" + if _DEBUG: + print("Inside 
fake fa_custom_forward") + + assert q.shape == k.shape + assert k.shape == v.shape + + # full_q, full_k, full_v, o, l, m + full_q = torch.empty_like(q) + full_k = torch.empty_like(k) + full_v = torch.empty_like(v) + full_ab = safe_empty_like(ab) + o = torch.empty_like(v) + l = torch.empty_like(v, dtype=torch.float32)[..., 0] + m = torch.empty_like(v, dtype=torch.float32)[..., 0] + + return tuple([torch.empty_like(o)] + + [safe_empty_like(t) for t in ( + full_q, + full_k, + full_v, + l, + m, + q_segment_ids, + kv_segment_ids, + full_ab, + )]) + +@fa_custom_backward.register_fake +def fa_custom_backward_fake(grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale, partition_spec, mesh, q_full_shape, kv_full_shape, ab_full_shape, ctx_grad): + if _DEBUG: + print("Inside fake fa_custom_backward") + return safe_empty_like(q), safe_empty_like( + grad_output), safe_empty_like(grad_output), safe_empty_like(ab) + + + class FlashAttention(torch.autograd.Function): """ @@ -209,9 +617,9 @@ class FlashAttention(torch.autograd.Function): NUM_SUBLANES = 8 @staticmethod - def prepare_segment_ids(q_segment_ids, kv_segment_ids): + def prepare_segment_ids(q_segment_ids, kv_segment_ids) -> Tuple["SegmentIds", torch.Tensor, torch.Tensor]: from jax.experimental.pallas.ops.tpu.flash_attention import SegmentIds - if q_segment_ids is None or kv_segment_ids is None: + if q_segment_ids is None and kv_segment_ids is None: return None, None, None assert q_segment_ids is not None and kv_segment_ids is not None, "Both q_segment_ids and kv_segment_ids should be provided." @@ -227,262 +635,65 @@ def prepare_segment_ids(q_segment_ids, kv_segment_ids): return segment_ids, q_segment_ids, kv_segment_ids @staticmethod + @requires_jax def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, partition_spec, mesh): - # Import JAX within the function such that we don't need to call the jax_import_guard() - # in the global scope which could cause problems for xmp.spawn. - jax_import_guard() - import jax - from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl - + # with torch.no_grad(): + ctx.q_shape = q.shape + ctx.k_shape = k.shape ctx.causal = causal ctx.sm_scale = sm_scale ctx.partition_spec = partition_spec ctx.mesh = mesh - ctx.q_full_shape = None - ctx.kv_full_shape = None - save_residuals = q.requires_grad or k.requires_grad or v.requires_grad - - # SPMD integration. - # mark_sharding is in-placed, and therefore save the full q, k, v for the backward. - full_q = q - full_k = k - full_v = v - full_ab = ab - if partition_spec is not None: - ctx.q_full_shape = q.shape - ctx.kv_full_shape = k.shape - q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor - k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor - v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor - if ab is not None: - ab = xs.enable_manual_sharding( - ab, partition_spec, mesh=mesh).global_tensor - - # It computes the shape and type of o, l, m. 
- shapes = [q.shape] - dtypes = [q.dtype] - if save_residuals: - res_shape = list(q.shape) - res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE - for _ in range(2): - shapes.append(res_shape) - dtypes.append(torch.float32) - - with torch.no_grad(): - if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None: - # partition_spec is for q,k,v with shape [batch, num_head, seq_len, head_dim], segment id - # is of shape [batch, seq_len], hence we need to tweak it a bit - segment_id_partition_spec = (partition_spec[0], partition_spec[2]) - q_segment_ids = xs.enable_manual_sharding( - q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor - kv_segment_ids = xs.enable_manual_sharding( - kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor - segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids( - q_segment_ids, kv_segment_ids) - ctx.segment_ids = segment_ids - - # We can't directly use flash_attention as we need to override the save_residuals flag which returns - # l and m that is needed for the backward. Then we lose all the shape checks. - # TODO: replicate the shape checks on flash_attention. - # Here we seperate the tracing and execution part just to support SegmentIds. - payload, _ = trace_pallas( - _flash_attention_impl, - q, - k, - v, - ab, - segment_ids, - save_residuals, - causal, - sm_scale, - min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]), - min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]), - min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]), - min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]), - False, - static_argnums=range(5, 13), - use_cache=True, - ) - - args = [q, k, v] - if ab is not None: - args += [ab] - if segment_ids is not None: - args += [q_segment_ids_fa, kv_segment_ids_fa] - o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes) - - if not save_residuals: - o = o[0] - # SPMD integration - if partition_spec is not None: - o = xs.disable_manual_sharding( - o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor - return o - o, *aux = o - l, m = (v[..., 0] for v in aux[-2:]) - - # SPMD integration - if partition_spec is not None: - o = xs.disable_manual_sharding( - o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor - l = xs.disable_manual_sharding( - l, partition_spec[0:3], ctx.q_full_shape[0:3], - mesh=mesh).global_tensor - m = xs.disable_manual_sharding( - m, partition_spec[0:3], ctx.q_full_shape[0:3], - mesh=mesh).global_tensor + ctx.q_full_shape = q.shape + ctx.kv_full_shape = k.shape + ctx.ab_full_shape = ab.shape if ab is not None else None + partition_spec = str(partition_spec) + mesh = str(mesh) + # x = fa_custom_forward_dummy(q) + custom_op_arg = [q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, partition_spec, mesh] + ctx_grads = generate_ctx_need_grad(*custom_op_arg) + outs = fa_custom_forward(*custom_op_arg, ctx_grads) + + o = outs[0] + full_q, full_k, full_v, l, m, q_segment_ids_fa, kv_segment_ids_fa, full_ab = [x for x in outs[1:]] + # full_ab = ab # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided # but it should be OK as the backward will use the same partition_spec - ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa, - kv_segment_ids_fa, full_ab) + # ctx.save_for_backward(full_q, full_k, full_v, o, l, m, full_ab) + ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids, kv_segment_ids, full_ab) return 
o @staticmethod + @requires_jax def backward(ctx, grad_output): - from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv - - q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors + q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors + # q, k, v, o, l, m, ab = ctx.saved_tensors causal = ctx.causal sm_scale = ctx.sm_scale partition_spec = ctx.partition_spec mesh = ctx.mesh q_full_shape = ctx.q_full_shape kv_full_shape = ctx.kv_full_shape + ab_full_shape = ctx.ab_full_shape + grad_output = grad_output.contiguous() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + o = o.contiguous() + l = l.contiguous() + m = m.contiguous() # this segment_ids only reflects the local shape of segment_ids - segment_ids = ctx.segment_ids - grad_q = grad_k = grad_v = grad_ab = None - - grad_i = torch.sum( - o.to(torch.float32) * grad_output.to(torch.float32), - axis=-1) # [batch_size, num_heads, q_seq_len] - - expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] + - [FlashAttention.MIN_BLOCK_SIZE]) - expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] + - [FlashAttention.MIN_BLOCK_SIZE]) - expanded_grad_i = grad_i.unsqueeze(-1).expand( - [-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE]) - - # SPMD integration - if partition_spec is not None: - q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor - k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor - v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor - expanded_l = xs.enable_manual_sharding( - expanded_l, partition_spec, mesh=mesh).global_tensor - expanded_m = xs.enable_manual_sharding( - expanded_m, partition_spec, mesh=mesh).global_tensor - grad_output = xs.enable_manual_sharding( - grad_output, partition_spec, mesh=mesh).global_tensor - expanded_grad_i = xs.enable_manual_sharding( - expanded_grad_i, partition_spec, mesh=mesh).global_tensor - if ab is not None: - ab = xs.enable_manual_sharding( - ab, partition_spec, mesh=mesh).global_tensor - - if ctx.needs_input_grad[0]: - payload, _ = trace_pallas( - _flash_attention_bwd_dq, - q, - k, - v, - ab, - segment_ids, - l, - m, - grad_output, - grad_i, - block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"], - q.shape[2]), - block_k_major=min( - FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"], - k.shape[2]), - block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"], - k.shape[2]), - sm_scale=sm_scale, - causal=causal, - mask_value=FlashAttention.DEFAULT_MASK_VALUE, - debug=False, - static_argnames=[ - "block_q_major", "block_k_major", "block_k", "sm_scale", "causal", - "mask_value", "debug" - ], - use_cache=True, - ) - - args = [q, k, v] - if ab is not None: - args += [ab] - if segment_ids is not None: - args += [q_segment_ids_fa, kv_segment_ids_fa] - args += [expanded_l, expanded_m, grad_output, expanded_grad_i] - - outputs = [q] - if ab is not None: - outputs += [ab] - grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload, - [i.shape for i in outputs], - [i.dtype for i in outputs]) - if ctx.needs_input_grad[0]: - grad_q = grads[0] - if ctx.needs_input_grad[-3]: - grad_ab = grads[1] - - if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: - payload, _ = trace_pallas( - _flash_attention_bwd_dkv, - q, - k, - v, - ab, - segment_ids, - l, - m, - grad_output, - grad_i, - block_q_major=min( - FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"], - q.shape[2]), - block_k_major=min( - 
FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"], - k.shape[2]), - block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"], - k.shape[2]), - block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"], - q.shape[2]), - sm_scale=sm_scale, - causal=causal, - mask_value=FlashAttention.DEFAULT_MASK_VALUE, - debug=False, - static_argnames=[ - "block_q_major", "block_k_major", "block_k", "block_q", - "sm_scale", "causal", "mask_value", "debug" - ], - use_cache=True) - - grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload, - [k.shape, v.shape], - [k.dtype, v.dtype]) - - if ctx.needs_input_grad[1]: - grad_k = grads[0] - if ctx.needs_input_grad[2]: - grad_v = grads[1] - - # SPMD integration - if partition_spec is not None: - grad_q = xs.disable_manual_sharding( - grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor - grad_k = xs.disable_manual_sharding( - grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor - grad_v = xs.disable_manual_sharding( - grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor - + # grad_q, grad_k, grad_v, grad_ab = fa_custom_backward(grad_output, q, k, v, o, l, m, ab, causal, sm_scale, str(partition_spec), str(mesh), str(q_full_shape), str(kv_full_shape)) + custom_op_arg = [grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale, str(partition_spec), str(mesh), q_full_shape, kv_full_shape, ab_full_shape] + # ctx_grads = generate_ctx_need_grad(*custom_op_arg) + ctx_grads = ctx.needs_input_grad + grad_q, grad_k, grad_v, grad_ab = fa_custom_backward(*custom_op_arg, ctx_grads) return grad_q, grad_k, grad_v, None, None, None, None, grad_ab, None, None + def flash_attention( q, # [batch_size, num_heads, q_seq_len, d_model] k, # [batch_size, num_heads, kv_seq_len, d_model] @@ -501,6 +712,34 @@ def flash_attention( sm_scale, ab, partition_spec, mesh) +def flash_attention_compilable( + q, # [batch_size, num_heads, q_seq_len, d_model] + k, # [batch_size, num_heads, kv_seq_len, d_model] + v, # [batch_size, num_heads, kv_seq_len, d_model] + causal=False, + q_segment_ids=None, # [batch_size, q_seq_len] + kv_segment_ids=None, # [batch_size, kv_seq_len] + sm_scale=1.0, + *, + ab=None, # [batch_size, num_heads, q_seq_len, kv_seq_len] + partition_spec: Optional[str], + mesh=Optional[str], +): + return flash_attention( + q, + k, + v, + causal, + q_segment_ids, + kv_segment_ids, + sm_scale, + ab=ab, + partition_spec=eval(partition_spec), + mesh=xs.Mesh.from_str(mesh), + ) + + + def _multi_queries_paged_attention_nonkernel( q, # [batch_size, query_len, num_heads, head_size] k_pages, # [num_kv_heads, total_num_pages, page_size, head_size] @@ -565,6 +804,7 @@ def _multi_queries_paged_attention_nonkernel( return output +@requires_jax def multi_queries_paged_attention( q, # [batch_size, query_len, num_heads, head_size] k_pages, # [num_kv_heads, total_num_pages, page_size, head_size] @@ -591,7 +831,6 @@ def multi_queries_paged_attention( # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. 
- jax_import_guard() from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention payload, tensor_args = trace_pallas( paged_attention, @@ -634,6 +873,7 @@ def multi_queries_paged_attention( return output.permute(0, 2, 1, 3).to(q_dtype_for_kernel_launch) +@requires_jax def paged_attention(q, k_pages, v_pages, @@ -644,7 +884,6 @@ def paged_attention(q, attn_logits_soft_cap: float = None): # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. - jax_import_guard() from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention assert megacore_mode in [ @@ -921,6 +1160,7 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor, return res +@requires_jax def gmm( lhs: torch.Tensor, rhs: torch.Tensor, @@ -940,7 +1180,6 @@ def gmm( """ # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. - jax_import_guard() from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2] @@ -973,6 +1212,7 @@ def gmm( ], payload, [torch.Size([m, n])], [preferred_element_type])[0] +@requires_jax def tgmm( lhs: torch.Tensor, rhs: torch.Tensor, @@ -992,7 +1232,6 @@ def tgmm( """ # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. - jax_import_guard() from jax.experimental.pallas.ops.tpu.megablox.gmm import tgmm k, m, n, num_groups = lhs.shape[0], lhs.shape[1], rhs.shape[ diff --git a/torch_xla/experimental/pallas_kernels/flash_attention_kernel.py b/torch_xla/experimental/pallas_kernels/flash_attention_kernel.py new file mode 100644 index 000000000000..e69de29bb2d1
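A hedged usage sketch for one of the newly decorated kernels: with @requires_jax on gmm, callers no longer need to call jax_import_guard() before use. The shapes, dtypes, and group sizes below are illustrative assumptions only.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import gmm

device = xm.xla_device()
lhs = torch.randn(512, 128, dtype=torch.bfloat16, device=device)     # [m, k]
rhs = torch.randn(4, 128, 256, dtype=torch.bfloat16, device=device)  # [num_groups, k, n]
group_sizes = torch.tensor([128, 128, 128, 128],
                           dtype=torch.int32, device=device)         # sums to m
out = gmm(lhs, rhs, group_sizes)  # [m, n]; JAX setup is handled by the decorator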