Skip to content

[Kernel] Use call_jax to simplify the gmm pallas kernel wrapper #9180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 16 additions & 169 deletions test/test_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
from torch_xla.experimental.custom_kernel import gmm, tgmm, gmm_backward, GMM
from torch_xla import runtime as xr
from torch_xla._internal import tpu

Expand Down Expand Up @@ -120,10 +120,11 @@ def test_gmm(self):
# torch.compiled version of the gmm will cache the payload in dynamo layer
# hence won't trigger the trace_pallas cache
if test_cache and gmm_func != compiled_gmm:
met.clear_counters()
old_cnt = xr.get_num_cached_compilation_graph()
# execute the same gmm func, expected to hit the cache
out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
new_cnt = xr.get_num_cached_compilation_graph()
self.assertEqual(old_cnt, new_cnt)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure gmm doesn't fallback.
Expand Down Expand Up @@ -155,173 +156,16 @@ def test_gmm_bf16(self):
# torch.compiled version of the gmm will cache the payload in dynamo layer
# hence won't trigger the trace_pallas cache
if test_cache and gmm_func != compiled_gmm:
met.clear_counters()
old_cnt = xr.get_num_cached_compilation_graph()
# execute the same gmm func, expected to hit the cache
out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
new_cnt = xr.get_num_cached_compilation_graph()
self.assertEqual(old_cnt, new_cnt)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure gmm doesn't fallback.
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_make_group_metadata(self):
from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata
met.clear_all()

test_grids = [
{
'group_sizes': [8, 8, 8, 8],
'm': 32,
'tm': 8
},
{
'group_sizes': [2, 14, 8, 8],
'm': 32,
'tm': 8
},
{
'group_sizes': [16, 0, 8, 8],
'm': 32,
'tm': 8
},
{
'group_sizes': [2, 0, 14, 16],
'm': 32,
'tm': 8
},
{
'group_sizes': [8, 12, 0, 12],
'm': 32,
'tm': 8
},
{
'group_sizes': [6, 12, 0, 14],
'm': 32,
'tm': 8
},
{
'group_sizes': [6, 12, 0, 14],
'm': 32,
'tm': 4
},
{
'group_sizes': [377, 588, 153, 1638, 3261, 5890, 996, 3481],
'm': 16384,
'tm': 128
},
]

for test_grid in test_grids:
jax_meta, jax_num_tiles = jax_make_group_metadata(
group_sizes=jnp.array(test_grid['group_sizes']),
m=test_grid['m'],
tm=test_grid['tm'],
start_group=0,
num_nonzero_groups=len(test_grid['group_sizes']),
)

torch_meta = _make_group_metadata(
group_sizes=torch.tensor(test_grid['group_sizes']).to(
torch.int32).to("xla"),
m=test_grid['m'],
tm=test_grid['tm'],
visit_empty_groups=True,
)

for i in range(len(jax_meta)):
self.assertTrue(
torch.all(
torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i].cpu()))
self.assertEqual(jax_num_tiles, torch_meta[-1].cpu().item())

# Make sure _make_group_metadata doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())

def test_histogram(self):
test_grids = [
{
'input': [1, 4, 4, 1, 2, 3],
'min': 1,
'max': 4,
},
{
'input': [1, 4, 4, 1, 2, 3],
'min': 2,
'max': 3,
},
{
'input': [1, 4, 4, 1, 2, 3],
'min': 0,
'max': 5,
},
{
'input': [],
'min': 0,
'max': 5,
},
]

for test_grid in test_grids:
torch_chart = torch.histc(
torch.tensor(test_grid['input'], dtype=torch.float),
bins=test_grid['max'] - test_grid['min'] + 1,
min=test_grid['min'],
max=test_grid['max'],
)

chart = _histogram(
torch.tensor(test_grid['input'], dtype=torch.int32).to("xla"),
min=test_grid['min'],
max=test_grid['max'],
)

self.assertEqual(chart.dtype, torch.int32)
self.assertTrue(torch.all(torch_chart == chart.cpu()))

def test_histogram_raise(self):
with self.assertRaisesRegex(AssertionError,
"input must be of torch.int32 dtype."):
_histogram(
torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.float),
min=4,
max=5,
)

with self.assertRaisesRegex(AssertionError,
"min must be less than or equal to max."):
_histogram(
torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32),
min=4,
max=3,
)

def test_sorting_input(self):
met.clear_all()
top2 = torch.tensor([[0, 2], [1, 3], [1, 2], [2, 3]]).to("xla")

# We want to create one big batch of tokens that has all top-k choices in it.
# Our tokens will thus be duplicated k-times in the batch. To do this we,
# first flatten the expert choices list and argsort it. This gives us an array
# of length B * K. We then create a tiled arange of size B * K and index
# into the expert choices list. This will give us the set of indices we need
# to gather from the xs to create this big batch.
top_flat = top2.flatten()
lhs_order = top_flat.argsort()
lhs_reverse_order = lhs_order.argsort()
lhs_indices = torch.arange(
top2.shape[0], device="xla").repeat_interleave(2)[lhs_order]
group_sizes = _histogram(top_flat.to(torch.int32), 0, 3)
torch_xla.sync()

# Make sure it doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertTrue(
torch.all(lhs_indices == torch.tensor([0, 1, 2, 0, 3, 2, 1, 3],
device="xla")))
self.assertTrue(
torch.all(group_sizes == torch.tensor([1, 2, 3, 2], device="xla")))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tgmm(self):
met.clear_all()
Expand All @@ -343,10 +187,11 @@ def test_tgmm(self):

out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
if test_cache:
met.clear_counters()
old_cnt = xr.get_num_cached_compilation_graph()
# execute the same gmm func, expected to hit the cache
out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
new_cnt = xr.get_num_cached_compilation_graph()
self.assertEqual(new_cnt, old_cnt)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure tgmm doesn't fallback.
Expand All @@ -373,10 +218,11 @@ def test_tgmm_bf16(self):

out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
if test_cache:
met.clear_counters()
old_cnt = xr.get_num_cached_compilation_graph()
# execute the same gmm func, expected to hit the cache
out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
new_cnt = xr.get_num_cached_compilation_graph()
self.assertEqual(new_cnt, old_cnt)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure tgmm doesn't fallback.
Expand All @@ -393,7 +239,7 @@ def test_gmm_backward(self):
lhs_dtype = rhs_dtype = torch.bfloat16

for test_cache in [False, True]:
met.clear_all()
old_cnt = xr.get_num_cached_compilation_graph()
lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
Expand All @@ -409,8 +255,9 @@ def test_gmm_backward(self):
group_sizes.to("xla"))
# same gmm/tgmm was run for the `test_cache=False` case so the
# cache should be populated now
new_cnt = xr.get_num_cached_compilation_graph()
if test_cache:
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 2)
self.assertEqual(new_cnt, old_cnt)

self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/core/xla_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,13 @@ def jax_func_to_xla_computation(jax_func, args, kwargs, name=None):

flattened_inputs, spec = jax.tree.flatten((args, kwargs))

def convert_arg_to_jax(a):
if a is None:
return None
elif isinstance(a, torch.dtype):
return tx.ops.mappings.t2j_dtype(a)
return a

def abstractify(a): # make a pytree leaf abstract
if a is None:
return None
Expand All @@ -881,6 +888,7 @@ def abstractify(a): # make a pytree leaf abstract
return jax.ShapeDtypeStruct(a.shape, tx.ops.mappings.t2j_dtype(a.dtype))
return a

flattened_inputs = list(convert_arg_to_jax(a) for a in flattened_inputs)
sample_inputs = tuple(abstractify(a) for a in flattened_inputs)

# Pick out the non-static args.
Expand Down
Loading
Loading