diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 2eec624e5065..03984166acc8 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -297,6 +297,7 @@ def reshape_output(tensor):
   return reshape_input, reshape_output
 
 
+@requires_jax
 def _fa_custom_forward_single_device(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
     q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
@@ -484,6 +485,7 @@ def _pad_to_block_size(
   return padded, pad_size
 
 
+@requires_jax
 def _fa_custom_backward_single_device(
     grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
     v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor,
@@ -783,7 +785,6 @@ def prepare_segment_ids(
     return segment_ids, q_segment_ids, kv_segment_ids
 
   @staticmethod
-  @requires_jax
   def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale,
               ab, partition_spec, mesh):
     ctx.q_shape = q.shape
@@ -815,7 +816,6 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     return o
 
   @staticmethod
-  @requires_jax
   def backward(ctx, grad_output):
     q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
     causal = ctx.causal
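
Note: this diff moves the requires_jax decorator off the autograd Function's forward/backward staticmethods and onto the single-device helpers _fa_custom_forward_single_device and _fa_custom_backward_single_device, presumably so the JAX-availability check fires at the point where JAX is actually invoked. For context, below is a minimal sketch of what a lazy-import guard like requires_jax typically looks like; the wrapper body and error message here are assumptions for illustration, not the actual torch_xla implementation.

    import functools


    def requires_jax(fn):
      """Sketch of a guard decorator: check that JAX is importable
      only when the wrapped function is actually called, so merely
      importing this module does not require JAX to be installed."""

      @functools.wraps(fn)
      def wrapper(*args, **kwargs):
        try:
          import jax  # noqa: F401 -- availability check only
        except ImportError as e:
          raise ImportError(
              f'{fn.__name__} requires JAX; install it to use this kernel.'
          ) from e
        return fn(*args, **kwargs)

      return wrapper


    # Usage mirrors the diff above: decorate the helper that calls into JAX.
    @requires_jax
    def _fa_custom_forward_single_device(*args, **kwargs):
      ...

One plausible benefit of this placement is that forward and backward stay plain staticmethods (no custom decorator stacked under @staticmethod), while the check still covers every path that reaches the JAX/Pallas kernel, including calls that bypass the autograd Function and invoke the single-device helpers directly.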