@@ -7,7 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
+from torch_xla.experimental.custom_kernel import gmm, tgmm, gmm_backward, GMM
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
 
@@ -120,10 +120,11 @@ def test_gmm(self):
           # torch.compiled version of the gmm will cache the payload in dynamo layer
           # hence won't trigger the trace_pallas cache
           if test_cache and gmm_func != compiled_gmm:
-            met.clear_counters()
+            old_cnt = xr.get_num_cached_compilation_graph()
             # execute the same gmm func, expected to hit the cache
             out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-            self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+            new_cnt = xr.get_num_cached_compilation_graph()
+            self.assertEqual(old_cnt, new_cnt)
           self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
@@ -155,173 +156,16 @@ def test_gmm_bf16(self):
           # torch.compiled version of the gmm will cache the payload in dynamo layer
           # hence won't trigger the trace_pallas cache
           if test_cache and gmm_func != compiled_gmm:
-            met.clear_counters()
+            old_cnt = xr.get_num_cached_compilation_graph()
             # execute the same gmm func, expected to hit the cache
             out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-            self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+            new_cnt = xr.get_num_cached_compilation_graph()
+            self.assertEqual(old_cnt, new_cnt)
           self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
     self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
 
-  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
-  def test_make_group_metadata(self):
-    from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata
-    met.clear_all()
-
-    test_grids = [
-        {
-            'group_sizes': [8, 8, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [2, 14, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [16, 0, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [2, 0, 14, 16],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [8, 12, 0, 12],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [6, 12, 0, 14],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [6, 12, 0, 14],
-            'm': 32,
-            'tm': 4
-        },
-        {
-            'group_sizes': [377, 588, 153, 1638, 3261, 5890, 996, 3481],
-            'm': 16384,
-            'tm': 128
-        },
-    ]
-
-    for test_grid in test_grids:
-      jax_meta, jax_num_tiles = jax_make_group_metadata(
-          group_sizes=jnp.array(test_grid['group_sizes']),
-          m=test_grid['m'],
-          tm=test_grid['tm'],
-          start_group=0,
-          num_nonzero_groups=len(test_grid['group_sizes']),
-      )
-
-      torch_meta = _make_group_metadata(
-          group_sizes=torch.tensor(test_grid['group_sizes']).to(
-              torch.int32).to("xla"),
-          m=test_grid['m'],
-          tm=test_grid['tm'],
-          visit_empty_groups=True,
-      )
-
-      for i in range(len(jax_meta)):
-        self.assertTrue(
-            torch.all(
-                torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i].cpu()))
-      self.assertEqual(jax_num_tiles, torch_meta[-1].cpu().item())
-
-    # Make sure _make_group_metadata doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
-
-  def test_histogram(self):
-    test_grids = [
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 1,
-            'max': 4,
-        },
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 2,
-            'max': 3,
-        },
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 0,
-            'max': 5,
-        },
-        {
-            'input': [],
-            'min': 0,
-            'max': 5,
-        },
-    ]
-
-    for test_grid in test_grids:
-      torch_chart = torch.histc(
-          torch.tensor(test_grid['input'], dtype=torch.float),
-          bins=test_grid['max'] - test_grid['min'] + 1,
-          min=test_grid['min'],
-          max=test_grid['max'],
-      )
-
-      chart = _histogram(
-          torch.tensor(test_grid['input'], dtype=torch.int32).to("xla"),
-          min=test_grid['min'],
-          max=test_grid['max'],
-      )
-
-      self.assertEqual(chart.dtype, torch.int32)
-      self.assertTrue(torch.all(torch_chart == chart.cpu()))
-
-  def test_histogram_raise(self):
-    with self.assertRaisesRegex(AssertionError,
-                                "input must be of torch.int32 dtype."):
-      _histogram(
-          torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.float),
-          min=4,
-          max=5,
-      )
-
-    with self.assertRaisesRegex(AssertionError,
-                                "min must be less than or equal to max."):
-      _histogram(
-          torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32),
-          min=4,
-          max=3,
-      )
-
-  def test_sorting_input(self):
-    met.clear_all()
-    top2 = torch.tensor([[0, 2], [1, 3], [1, 2], [2, 3]]).to("xla")
-
-    # We want to create one big batch of tokens that has all top-k choices in it.
-    # Our tokens will thus be duplicated k-times in the batch. To do this we,
-    # first flatten the expert choices list and argsort it. This gives us an array
-    # of length B * K. We then create a tiled arange of size B * K and index
-    # into the expert choices list. This will give us the set of indices we need
-    # to gather from the xs to create this big batch.
-    top_flat = top2.flatten()
-    lhs_order = top_flat.argsort()
-    lhs_reverse_order = lhs_order.argsort()
-    lhs_indices = torch.arange(
-        top2.shape[0], device="xla").repeat_interleave(2)[lhs_order]
-    group_sizes = _histogram(top_flat.to(torch.int32), 0, 3)
-    xm.mark_step()
-
-    # Make sure it doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
-    self.assertTrue(
-        torch.all(lhs_indices == torch.tensor([0, 1, 2, 0, 3, 2, 1, 3],
-                                              device="xla")))
-    self.assertTrue(
-        torch.all(group_sizes == torch.tensor([1, 2, 3, 2], device="xla")))
-
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tgmm(self):
     met.clear_all()
@@ -343,10 +187,11 @@ def test_tgmm(self):
 
         out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
         if test_cache:
-          met.clear_counters()
+          old_cnt = xr.get_num_cached_compilation_graph()
           # execute the same gmm func, expected to hit the cache
           out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+          new_cnt = xr.get_num_cached_compilation_graph()
+          self.assertEqual(new_cnt, old_cnt)
         self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure tgmm doesn't fallback.
@@ -373,10 +218,11 @@ def test_tgmm_bf16(self):
 
         out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
         if test_cache:
-          met.clear_counters()
+          old_cnt = xr.get_num_cached_compilation_graph()
           # execute the same gmm func, expected to hit the cache
           out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+          new_cnt = xr.get_num_cached_compilation_graph()
+          self.assertEqual(new_cnt, old_cnt)
         self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure tgmm doesn't fallback.
@@ -393,7 +239,7 @@ def test_gmm_backward(self):
     lhs_dtype = rhs_dtype = torch.bfloat16
 
     for test_cache in [False, True]:
-      met.clear_all()
+      old_cnt = xr.get_num_cached_compilation_graph()
       lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
       rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
       group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
@@ -409,8 +255,9 @@ def test_gmm_backward(self):
           group_sizes.to("xla"))
       # same gmm/tgmm was run for the `test_cache=False` case so the
       # cache should be populated now
+      new_cnt = xr.get_num_cached_compilation_graph()
       if test_cache:
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 2)
+        self.assertEqual(new_cnt, old_cnt)
 
       self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
       self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
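Note for reviewers: the snippet below is a minimal sketch of the before/after cache check this diff switches the tests to, pulled out of the test harness. It assumes a TPU-backed `xla` device and uses only names that appear in the diff (`gmm`, `xr.get_num_cached_compilation_graph()`); the shapes and group sizes are made up for illustration.

```python
import torch
import torch_xla  # registers the "xla" device
from torch_xla import runtime as xr
from torch_xla.experimental.custom_kernel import gmm

# Illustrative problem size: m rows split across num_groups groups.
m, k, n, num_groups = 128, 64, 32, 4
lhs = torch.rand(m, k)
rhs = torch.rand(num_groups, k, n)
# group_sizes is int32 and sums to m, matching what the tests construct.
group_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32)

# First call traces and compiles the Pallas gmm kernel.
out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))

# Snapshot the compilation-graph cache, rerun the identical call, and check
# that the cache did not grow -- i.e. the second run was served from cache.
old_cnt = xr.get_num_cached_compilation_graph()
out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
new_cnt = xr.get_num_cached_compilation_graph()
assert new_cnt == old_cnt
print(out.cpu().shape)  # materialize the result, as the allclose checks do
```

This is the same pattern the updated tests use in place of the removed `trace_pallas_cache_hit` counter assertions: instead of counting Pallas trace-cache hits, they assert that repeating an identical call adds no new cached compilation graph.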