Commit

backward with spmd issue
zpcore committed Jan 30, 2025
1 parent 8e6ca60 commit f82f373
Showing 5 changed files with 586 additions and 253 deletions.
78 changes: 76 additions & 2 deletions test/test_pallas.py
@@ -12,8 +12,9 @@
import numpy as np

if xr.device_type() == 'TPU':
  from torch_xla.experimental.custom_kernel import jax_import_guard
  jax_import_guard()
  # from torch_xla.experimental.custom_kernel import jax_import_guard
  # jax_import_guard()
  torch_xla._XLAC._init_computation_client()
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl
@@ -488,6 +489,79 @@ def test_flash_attention_backward(self):
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
    jax.config.update("jax_default_matmul_precision", "default")


  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "This test only works on TPUv4+.")
  def test_flash_attention_backward_aot_autograd_traceable(self):
    from functorch.compile import aot_function, make_boxed_func
    from torch_xla.experimental.custom_kernel import flash_attention, FlashAttention, flash_attention_compilable
    import torch_xla.core.xla_model as xm
    jax.config.update("jax_default_matmul_precision", "highest")

    def compiler(gm, _):
      print("Got graph:")
      print(gm.code)
      return make_boxed_func(gm)

    torch.manual_seed(42)
    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()
    B, N, SEQ, H = q.size()
    causal = True
    q_segment_ids = None
    kv_segment_ids = None
    sm_scale = 1.0
    mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
    # ab = torch.ones(4, 2, 128, 128).to("xla")
    # ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_(True)
    # ab.retain_grad()
    ab = None
    partition_spec = ('fsdp', 'tensor', None, None)
    # partition_spec = None
    import torch_xla.runtime as xr
    from torch_xla.distributed.spmd import Mesh
    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    mesh_shape = (num_devices // 2, 2)
    device_ids = np.array(range(num_devices))
    mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))

    def flash_attention_wrapper(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, partition_spec, mesh):
      return flash_attention_compilable(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)

    # An AOT-compatible function only accepts the argument types listed in https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34, so we serialize partition_spec and mesh into strings.
    # compiled_flash_attention = aot_function(
    #     flash_attention_wrapper, fw_compiler=compiler)
    # o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, str(partition_spec), str(mesh))
    o_actual = flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)

    print(o_actual.sum())
    o_actual.sum().backward()
    print(q.grad)

    # if causal:
    #   attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
    #   # attention_mask = attention_mask.view(1, 1, SEQ, SEQ)
    #   # attention_mask = attention_mask.expand(q.size(0), q.size(1), -1, -1)
    # else:
    #   attention_mask = None
    # print(attention_mask)
    # assert False
    # import torch_xla.distributed.spmd as xs
    # expected_output = self._attention(q, k, v, attn_mask=attention_mask)
    # print(expected_output)
    # self.assertTrue(
    #     torch.allclose(
    #         expected_output.cpu(),
    #         o_actual.cpu(),
    #         atol=1e-1,
    #         rtol=1e-1))


  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "This test only works on TPUv4+.")
  def test_paged_attention_wrapper(self):
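For illustration only (not part of the commit): a minimal sketch of how the commented-out AOT path above might rebuild the string-serialized arguments inside a traceable wrapper. It assumes flash_attention_compilable accepts the same arguments the test's flash_attention_wrapper passes, and it reuses the compiler function, tensors, and attention settings defined in the test; the name traceable_flash_attention is hypothetical.

import ast

from functorch.compile import aot_function, make_boxed_func
from torch_xla.distributed.spmd import Mesh
from torch_xla.experimental.custom_kernel import flash_attention_compilable


def traceable_flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids,
                              sm_scale, ab, partition_spec_str, mesh_str):
  # Strings are among the argument types AOT autograd accepts, so the real
  # objects are rebuilt here before calling the custom kernel.
  partition_spec = ast.literal_eval(partition_spec_str)
  mesh = Mesh.from_str(mesh_str)  # helper added to xla_sharding.py in this commit
  return flash_attention_compilable(
      q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale,
      ab=ab, partition_spec=partition_spec, mesh=mesh)


compiled_flash_attention = aot_function(
    traceable_flash_attention, fw_compiler=compiler)
o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids,
                                    kv_segment_ids, sm_scale, ab,
                                    str(partition_spec), str(mesh))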
1 change: 1 addition & 0 deletions third_party/xla
Submodule xla added at 6e91ff
23 changes: 21 additions & 2 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -124,8 +124,27 @@ def get_op_sharding(self,
        partition_spec)
    return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                      replication_groups, sharding_type)



  def __str__(self):
    """Convert Mesh to string representation."""
    return (f"{{'device_ids': {self.device_ids.tolist()}, "
            f"'mesh_shape': {self.mesh_shape}, "
            f"'axis_names': {self.axis_names}}}")

  @classmethod
  def from_str(cls, mesh_str: str):
    """Create Mesh from string representation."""
    import ast
    import numpy as np
    # Remove 'Mesh' and parse dict
    dict_str = mesh_str.replace('Mesh', '')
    mesh_dict = ast.literal_eval(dict_str)
    # Convert list back to numpy array for device_ids
    return cls(
        device_ids=np.array(mesh_dict['device_ids']),
        mesh_shape=mesh_dict['mesh_shape'],
        axis_names=mesh_dict['axis_names'])
_GLOBAL_MESH: Mesh = None


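For illustration only (not part of the commit): a short round-trip sketch of the new Mesh string helpers, mirroring the mesh setup used in the test above and assuming an SPMD runtime with an even number of devices.

import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, (num_devices // 2, 2), ('fsdp', 'tensor'))

# __str__ produces a dict-like string; from_str parses it back into a Mesh.
mesh_str = str(mesh)
restored = Mesh.from_str(mesh_str)
assert restored.mesh_shape == mesh.mesh_shape
assert restored.axis_names == mesh.axis_names
assert (restored.device_ids == mesh.device_ids).all()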
