 import re
 import unittest
 import warnings
+from itertools import product
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import pytest

 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


+def filtered_parametrize(
+    param_list: List[Tuple[str, List[Any]]],
+    filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
+):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    Args:
+        param_list: A list of tuples, each containing (arg_name, [arg_values])
+        filter_func: A function that takes a dictionary of parameter names and values,
+            and returns True for valid combinations, False otherwise
+
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator
+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
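
A minimal standalone sketch, not taken from this diff, of how the filtered_parametrize decorator added above is meant to be used; the parameter names, values, and filter rule below are toy placeholders:

    # Toy usage: parametrize over device/dtype pairs, but drop the
    # ("cpu", "bfloat16") combination before pytest ever generates a test id.
    @filtered_parametrize(
        [
            ("device", ["cpu", "cuda"]),
            ("dtype", ["float32", "bfloat16"]),
        ],
        filter_func=lambda p: not (p["device"] == "cpu" and p["dtype"] == "bfloat16"),
    )
    def test_toy(device, dtype):
        assert (device, dtype) != ("cpu", "bfloat16")
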
@@ -230,17 +268,35 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    @staticmethod
+    def is_valid_combination(params):
+        if not params["emulate"]:
+            if not torch.cuda.is_available():
+                return False
+            if torch.cuda.get_device_capability() < (9, 0):
+                return False
+
+        if params["linear_type"] == LinearType.DYNAMIC:
+            return all(
+                params[key] == TensorScalingType.DYNAMIC
+                for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"]
+            )
+
+        return True
+
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_nobias(
@@ -252,28 +308,6 @@ def test_linear_nobias(
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
@@ -286,20 +320,20 @@ def test_linear_nobias(
             scaling_type_dL_dY,
         )

-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_bias(
@@ -312,28 +346,6 @@ def test_linear_bias(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(
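
A minimal sketch, not part of the diff, of how the LinearType.DYNAMIC rule in is_valid_combination above thins the scaling-type grid; string stand-ins replace the real LinearType / TensorScalingType enums, and the emulate / CUDA capability checks are left out:

    from itertools import product

    DELAYED, DYNAMIC = "delayed", "dynamic"
    keys = ("linear_type", "scaling_type_x", "scaling_type_w", "scaling_type_dL_dY")

    def keep(params):
        # Mirrors only the DYNAMIC branch of is_valid_combination.
        if params["linear_type"] == DYNAMIC:
            return all(params[k] == DYNAMIC for k in keys[1:])
        return True

    grid = [dict(zip(keys, combo)) for combo in product([DELAYED, DYNAMIC], repeat=4)]
    kept = [p for p in grid if keep(p)]
    print(len(grid), len(kept))  # 16 combinations generated, 9 survive the filter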