This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 653e120

Add utility for filtering out skipped tests in large parametrization groups
ghstack-source-id: 5c80e94 Pull Request resolved: #303
1 parent 3398526 commit 653e120

File tree

3 files changed: +83 additions, −71 deletions


float8_experimental/float8_linear.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def cast_x_to_float8(
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
-            autocast_dtype = torch.get_autocast_gpu_dtype()
+            autocast_dtype = torch.get_autocast_dtype("cuda")
             x = x.to(autocast_dtype)

         if self.scaling_type_x is TensorScalingType.DELAYED:
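
The one-line change above swaps the GPU-specific autocast query for the device-generic one. A minimal sketch of the difference, assuming a PyTorch build new enough to expose torch.get_autocast_dtype (on such builds the older helper still works but is deprecated):

    import torch

    # Both calls read the autocast dtype configured for CUDA; the device-generic
    # form takes the device type as a string argument.
    old_style = torch.get_autocast_gpu_dtype()    # deprecated spelling
    new_style = torch.get_autocast_dtype("cuda")  # device-generic replacement
    assert old_style == new_style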

float8_experimental/float8_linear_utils.py

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         fp8_layers = get_float8_layers(model)

     if len(fp8_layers) == 0:
-        log.warn(
+        log.warning(
             "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
         )
         return
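
The logging change above is behavior-preserving but removes a deprecation: Logger.warn has long been a deprecated alias of Logger.warning in the Python standard library. A minimal sketch (the logger name is hypothetical, for illustration only):

    import logging

    logging.basicConfig(level=logging.WARNING)
    log = logging.getLogger("float8_linear_utils_sketch")  # hypothetical logger name

    # Logger.warning is the supported spelling; log.warn(...) would emit a
    # DeprecationWarning on current CPython versions before delegating to warning().
    log.warning(
        "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
    )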

test/test_base.py

Lines changed: 81 additions & 69 deletions
@@ -9,6 +9,8 @@
 import re
 import unittest
 import warnings
+from itertools import product
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import pytest

@@ -52,6 +54,42 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


+def filtered_parametrize(
+    param_list: List[Tuple[str, List[Any]]],
+    filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
+):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    Args:
+        param_list: A list of tuples, each containing (arg_name, [arg_values])
+        filter_func: A function that takes a dictionary of parameter names and values,
+            and returns True for valid combinations, False otherwise
+
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator
+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._data == b._data).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
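
Unlike stacking pytest.mark.parametrize and calling pytest.skip inside the test body, filtering at decoration time means rejected combinations are never collected as test items at all. A minimal usage sketch of the helper defined above (the parameter names a and b are hypothetical):

    # Only combinations that pass filter_func become pytest parameters; here the
    # (2, 1) combination is dropped up front rather than skipped at runtime.
    @filtered_parametrize(
        [("a", [1, 2]), ("b", [1, 2])],
        filter_func=lambda params: params["a"] <= params["b"],
    )
    def test_ordered_pair(a, b):
        assert a <= b
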
@@ -230,17 +268,35 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    @staticmethod
+    def is_valid_combination(params):
+        if not params["emulate"]:
+            if not torch.cuda.is_available():
+                return False
+            if torch.cuda.get_device_capability() < (9, 0):
+                return False
+
+        if params["linear_type"] == LinearType.DYNAMIC:
+            return all(
+                params[key] == TensorScalingType.DYNAMIC
+                for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"]
+            )
+
+        return True
+
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_nobias(
@@ -252,28 +308,6 @@ def test_linear_nobias(
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
@@ -286,20 +320,20 @@ def test_linear_nobias(
             scaling_type_dL_dY,
         )

-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_bias(
@@ -312,28 +346,6 @@ def test_linear_bias(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(
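
Taken together, these changes prune the parametrization grid up front instead of emitting dozens of skips at run time. A rough sketch of the arithmetic for test_linear_nobias, assuming an H100-class GPU so that both emulate values survive the filter (the strings below stand in for the real enum values):

    from itertools import product

    # 3 shapes x 2 linear types x 2 emulate values x 2^3 scaling types = 96 raw
    # combinations; DYNAMIC linear_type keeps only the all-DYNAMIC scaling triple.
    raw = list(
        product(
            [(16, 16), (2, 16, 16), (3, 2, 16, 16)],  # x_shape
            ["delayed", "dynamic"],                    # linear_type
            [True, False],                             # emulate
            ["delayed", "dynamic"],                    # scaling_type_x
            ["delayed", "dynamic"],                    # scaling_type_w
            ["delayed", "dynamic"],                    # scaling_type_dL_dY
        )
    )
    kept = [c for c in raw if c[1] == "delayed" or all(s == "dynamic" for s in c[3:])]
    print(len(raw), len(kept))  # 96 54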
