Commit b796514

[Kernel] Use call_jax to simplify the gmm pallas kernel wrapper
1 parent bbafe35 commit b796514

File tree: 3 files changed, +29 -412 lines
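
The simplification works by delegating the grouped matmul to the JAX Pallas (megablox) kernel through torch_xla's call_jax bridge instead of re-implementing its group-metadata bookkeeping in torch. A minimal sketch of that pattern follows; the wrapper name, the megablox entry point, and the call_jax calling convention are assumptions here, since the changed torch_xla/experimental/custom_kernel.py file is not shown in this excerpt.

# Illustrative sketch only (not the code in this commit): a gmm wrapper that
# defers to the JAX Pallas kernel via call_jax. gmm_via_call_jax and the
# megablox import path are assumptions for illustration.
import torch
import torch_xla.core.xla_builder as xb


def gmm_via_call_jax(lhs: torch.Tensor, rhs: torch.Tensor,
                     group_sizes: torch.Tensor) -> torch.Tensor:
  # Lazy import so environments without Pallas/TPU support can still load this.
  from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm as jax_gmm

  # call_jax traces the JAX function into an XLA computation and embeds it in
  # the torch_xla graph; tensor leaves become computation parameters.
  return xb.call_jax(jax_gmm, (lhs, rhs, group_sizes))

The payoff, visible in the stats above, is that the hand-written _make_group_metadata and _histogram helpers (and their tests) can be dropped.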

test/test_gmm.py

Lines changed: 16 additions & 169 deletions
@@ -7,7 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
+from torch_xla.experimental.custom_kernel import gmm, tgmm, gmm_backward, GMM
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu

@@ -120,10 +120,11 @@ def test_gmm(self):
       # torch.compiled version of the gmm will cache the payload in dynamo layer
       # hence won't trigger the trace_pallas cache
       if test_cache and gmm_func != compiled_gmm:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(old_cnt, new_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
@@ -155,173 +156,16 @@ def test_gmm_bf16(self):
       # torch.compiled version of the gmm will cache the payload in dynamo layer
       # hence won't trigger the trace_pallas cache
       if test_cache and gmm_func != compiled_gmm:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(old_cnt, new_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
     self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

-  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
-  def test_make_group_metadata(self):
-    from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata
-    met.clear_all()
-
-    test_grids = [
-        {
-            'group_sizes': [8, 8, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [2, 14, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [16, 0, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [2, 0, 14, 16],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [8, 12, 0, 12],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [6, 12, 0, 14],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [6, 12, 0, 14],
-            'm': 32,
-            'tm': 4
-        },
-        {
-            'group_sizes': [377, 588, 153, 1638, 3261, 5890, 996, 3481],
-            'm': 16384,
-            'tm': 128
-        },
-    ]
-
-    for test_grid in test_grids:
-      jax_meta, jax_num_tiles = jax_make_group_metadata(
-          group_sizes=jnp.array(test_grid['group_sizes']),
-          m=test_grid['m'],
-          tm=test_grid['tm'],
-          start_group=0,
-          num_nonzero_groups=len(test_grid['group_sizes']),
-      )
-
-      torch_meta = _make_group_metadata(
-          group_sizes=torch.tensor(test_grid['group_sizes']).to(
-              torch.int32).to("xla"),
-          m=test_grid['m'],
-          tm=test_grid['tm'],
-          visit_empty_groups=True,
-      )
-
-      for i in range(len(jax_meta)):
-        self.assertTrue(
-            torch.all(
-                torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i].cpu()))
-      self.assertEqual(jax_num_tiles, torch_meta[-1].cpu().item())
-
-    # Make sure _make_group_metadata doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
-
-  def test_histogram(self):
-    test_grids = [
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 1,
-            'max': 4,
-        },
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 2,
-            'max': 3,
-        },
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 0,
-            'max': 5,
-        },
-        {
-            'input': [],
-            'min': 0,
-            'max': 5,
-        },
-    ]
-
-    for test_grid in test_grids:
-      torch_chart = torch.histc(
-          torch.tensor(test_grid['input'], dtype=torch.float),
-          bins=test_grid['max'] - test_grid['min'] + 1,
-          min=test_grid['min'],
-          max=test_grid['max'],
-      )
-
-      chart = _histogram(
-          torch.tensor(test_grid['input'], dtype=torch.int32).to("xla"),
-          min=test_grid['min'],
-          max=test_grid['max'],
-      )
-
-      self.assertEqual(chart.dtype, torch.int32)
-      self.assertTrue(torch.all(torch_chart == chart.cpu()))
-
-  def test_histogram_raise(self):
-    with self.assertRaisesRegex(AssertionError,
-                                "input must be of torch.int32 dtype."):
-      _histogram(
-          torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.float),
-          min=4,
-          max=5,
-      )
-
-    with self.assertRaisesRegex(AssertionError,
-                                "min must be less than or equal to max."):
-      _histogram(
-          torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32),
-          min=4,
-          max=3,
-      )
-
-  def test_sorting_input(self):
-    met.clear_all()
-    top2 = torch.tensor([[0, 2], [1, 3], [1, 2], [2, 3]]).to("xla")
-
-    # We want to create one big batch of tokens that has all top-k choices in it.
-    # Our tokens will thus be duplicated k-times in the batch. To do this we,
-    # first flatten the expert choices list and argsort it. This gives us an array
-    # of length B * K. We then create a tiled arange of size B * K and index
-    # into the expert choices list. This will give us the set of indices we need
-    # to gather from the xs to create this big batch.
-    top_flat = top2.flatten()
-    lhs_order = top_flat.argsort()
-    lhs_reverse_order = lhs_order.argsort()
-    lhs_indices = torch.arange(
-        top2.shape[0], device="xla").repeat_interleave(2)[lhs_order]
-    group_sizes = _histogram(top_flat.to(torch.int32), 0, 3)
-    xm.mark_step()
-
-    # Make sure it doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
-    self.assertTrue(
-        torch.all(lhs_indices == torch.tensor([0, 1, 2, 0, 3, 2, 1, 3],
-                                              device="xla")))
-    self.assertTrue(
-        torch.all(group_sizes == torch.tensor([1, 2, 3, 2], device="xla")))
-
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tgmm(self):
     met.clear_all()

@@ -343,10 +187,11 @@ def test_tgmm(self):

       out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
       if test_cache:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(new_cnt, old_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure tgmm doesn't fallback.
@@ -373,10 +218,11 @@ def test_tgmm_bf16(self):

       out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
       if test_cache:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(new_cnt, old_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure tgmm doesn't fallback.
@@ -393,7 +239,7 @@ def test_gmm_backward(self):
     lhs_dtype = rhs_dtype = torch.bfloat16

     for test_cache in [False, True]:
-      met.clear_all()
+      old_cnt = xr.get_num_cached_compilation_graph()
       lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
       rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
       group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
@@ -409,8 +255,9 @@ def test_gmm_backward(self):
                                        group_sizes.to("xla"))
       # same gmm/tgmm was run for the `test_cache=False` case so the
       # cache should be populated now
+      new_cnt = xr.get_num_cached_compilation_graph()
       if test_cache:
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 2)
+        self.assertEqual(new_cnt, old_cnt)

       self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
       self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
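
The tests no longer inspect the trace_pallas_cache_hit counter; instead they assert that re-running an identical kernel call does not grow the compiled-graph cache. A standalone sketch of that check, using only calls that appear in the diff (the helper name assert_no_new_graph is made up for illustration):

# Sketch of the recompilation check the updated tests rely on: running the same
# op twice with identical shapes should leave the compilation-graph cache
# count unchanged.
import torch
from torch_xla import runtime as xr
from torch_xla.experimental.custom_kernel import gmm


def assert_no_new_graph(lhs, rhs, group_sizes):
  # Warm-up call compiles once and populates the cache.
  out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
  old_cnt = xr.get_num_cached_compilation_graph()
  # Re-executing the same gmm should hit the cache, so the count stays flat.
  out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
  new_cnt = xr.get_num_cached_compilation_graph()
  assert old_cnt == new_cnt, "gmm triggered an unexpected recompilation"
  return out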

torch_xla/core/xla_builder.py

Lines changed: 8 additions & 0 deletions
@@ -873,6 +873,13 @@ def jax_func_to_xla_computation(jax_func, args, kwargs, name=None):

   flattened_inputs, spec = jax.tree.flatten((args, kwargs))

+  def convert_arg_to_jax(a):
+    if a is None:
+      return None
+    elif isinstance(a, torch.dtype):
+      return tx.ops.mappings.t2j_dtype(a)
+    return a
+
   def abstractify(a): # make a pytree leaf abstract
     if a is None:
       return None
@@ -881,6 +888,7 @@ def abstractify(a): # make a pytree leaf abstract
       return jax.ShapeDtypeStruct(a.shape, tx.ops.mappings.t2j_dtype(a.dtype))
     return a

+  flattened_inputs = list(convert_arg_to_jax(a) for a in flattened_inputs)
   sample_inputs = tuple(abstractify(a) for a in flattened_inputs)

   # Pick out the non-static args.
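
Why the hook matters: a torch.dtype leaf in the flattened (args, kwargs) pytree is now mapped to the matching JAX dtype before abstractification, so dtype-valued keyword arguments can be forwarded through call_jax as static arguments. Below is a hedged sketch of the effect, assuming call_jax takes (jax_func, args, kwargs) and an XLA device is available; the cast_sum function is made up for illustration.

# Sketch (assumptions: call_jax(jax_func, args, kwargs) calling convention and
# an available XLA device). The torch.dtype kwarg is converted by
# convert_arg_to_jax into a JAX dtype before the JAX function is traced.
import jax.numpy as jnp
import torch
import torch_xla.core.xla_builder as xb


def cast_sum(x, out_dtype=jnp.float32):
  # Ordinary JAX function; out_dtype is a non-tensor (static) argument.
  return jnp.sum(x, dtype=out_dtype)


x = torch.ones(4, 4).to("xla")
# torch.bfloat16 is a torch.dtype leaf in kwargs; with this commit it reaches
# cast_sum as the corresponding JAX bfloat16 dtype.
y = xb.call_jax(cast_sum, (x,), {'out_dtype': torch.bfloat16})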

0 commit comments