From fcf8a207a32ad72710390760f2034716cec20ab6 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 3 Jul 2024 11:53:40 -0700 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/test_base.py | 144 ++++++++++++++++++++++++---------------------- 1 file changed, 75 insertions(+), 69 deletions(-) diff --git a/test/test_base.py b/test/test_base.py index 1fee3bc..86c1462 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -9,6 +9,7 @@ import re import unittest import warnings +from itertools import product import pytest @@ -52,6 +53,37 @@ is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) +def filtered_parametrize(param_list, filter_func=None): + """ + A decorator that works like pytest.mark.parametrize but filters out + unwanted parameter combinations. + + :param param_list: A list of tuples, each containing (arg_name, [arg_values]) + :param filter_func: A function that takes a dictionary of parameter names and values, + and returns True for valid combinations, False otherwise + """ + + def decorator(func): + arg_names = [param[0] for param in param_list] + arg_values = [param[1] for param in param_list] + + all_combinations = product(*arg_values) + if filter_func: + valid_combinations = [ + combo + for combo in all_combinations + if filter_func(dict(zip(arg_names, combo))) + ] + else: + valid_combinations = list(all_combinations) + + return pytest.mark.parametrize( + argnames=arg_names, argvalues=valid_combinations + )(func) + + return decorator + + def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._data == b._data).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -230,17 +262,35 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) - @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) - @pytest.mark.parametrize( - "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] - ) - @pytest.mark.parametrize( - "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] - ) - @pytest.mark.parametrize( - "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + @staticmethod + def is_valid_combination(params): + if not params["emulate"]: + if not torch.cuda.is_available(): + return False + if torch.cuda.get_device_capability() < (9, 0): + return False + + if params["linear_type"] == LinearType.DYNAMIC: + return all( + params[key] == TensorScalingType.DYNAMIC + for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"] + ) + + return True + + @filtered_parametrize( + [ + ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]), + ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]), + ("emulate", [True, False] if is_H100 else [True]), + ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]), + ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]), + ( + "scaling_type_dL_dY", + [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC], + ), + ], + filter_func=is_valid_combination, ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_nobias( @@ -252,28 +302,6 @@ def test_linear_nobias( scaling_type_w: TensorScalingType, scaling_type_dL_dY: TensorScalingType, ): - if not emulate: 
- if not torch.cuda.is_available(): - warnings.warn("CUDA not available") - pytest.skip() - elif torch.cuda.get_device_capability() < (9, 0): - warnings.warn( - f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" - ) - pytest.skip() - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() - x = torch.randn(*x_shape, device="cuda") m_ref = nn.Linear(16, 32, bias=False, device="cuda") self._test_linear_impl( @@ -286,20 +314,20 @@ def test_linear_nobias( scaling_type_dL_dY, ) - @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) - @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) - @pytest.mark.parametrize( - "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] - ) - @pytest.mark.parametrize( - "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] - ) - @pytest.mark.parametrize( - "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] - ) - @pytest.mark.parametrize( - "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] + @filtered_parametrize( + [ + ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]), + ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]), + ("emulate", [True, False] if is_H100 else [True]), + ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]), + ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]), + ( + "scaling_type_dL_dY", + [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC], + ), + ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]), + ], + filter_func=is_valid_combination, ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_bias( @@ -312,28 +340,6 @@ def test_linear_bias( emulate: bool, linear_dtype: torch.dtype, ): - if not emulate: - if not torch.cuda.is_available(): - warnings.warn("CUDA not available") - pytest.skip() - elif torch.cuda.get_device_capability() < (9, 0): - warnings.warn( - f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" - ) - pytest.skip() - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. 
It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) self._test_linear_impl( From a9bae50fd117df5fc95e036f87e7017ebeef77b4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 3 Jul 2024 11:56:43 -0700 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- float8_experimental/float8_linear.py | 2 +- float8_experimental/float8_linear_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 4be0c27..cca9cab 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -296,7 +296,7 @@ def cast_x_to_float8( if torch.is_autocast_enabled(): # For now, hardcode to GPU's autocast dtype # if we need CPU support in the future, we can add it - autocast_dtype = torch.get_autocast_gpu_dtype() + autocast_dtype = torch.get_autocast_dtype("cuda") x = x.to(autocast_dtype) if self.scaling_type_x is TensorScalingType.DELAYED: diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index b1a17e4..28f3483 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -274,7 +274,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) fp8_layers = get_float8_layers(model) if len(fp8_layers) == 0: - log.warn( + log.warning( "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers" ) return From bdb858623a4d5e55cc374176f55b6f4afc6a02be Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 3 Jul 2024 12:00:44 -0700 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- test/test_base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_base.py b/test/test_base.py index 86c1462..2fc24bb 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -10,6 +10,7 @@ import unittest import warnings from itertools import product +from typing import Any, Callable, Dict, List, Optional, Tuple import pytest @@ -53,14 +54,19 @@ is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -def filtered_parametrize(param_list, filter_func=None): +def filtered_parametrize( + param_list: List[Tuple[str, List[Any]]], + filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None, +): """ A decorator that works like pytest.mark.parametrize but filters out unwanted parameter combinations. - :param param_list: A list of tuples, each containing (arg_name, [arg_values]) - :param filter_func: A function that takes a dictionary of parameter names and values, - and returns True for valid combinations, False otherwise + Args: + param_list: A list of tuples, each containing (arg_name, [arg_values]) + filter_func: A function that takes a dictionary of parameter names and values, + and returns True for valid combinations, False otherwise + """ def decorator(func):
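
Editor's note (not part of the patches above): PATCH 1/3 introduces a `filtered_parametrize` helper that behaves like `pytest.mark.parametrize` but drops parameter combinations rejected by a `filter_func`. The sketch below shows how it might be used; it assumes the helper is importable from the test module, and the test name, parameter names, values, and filter rule are illustrative assumptions, not taken from the patch.

    import pytest

    # Hypothetical import path; the helper lives in test/test_base.py in this PR.
    from test.test_base import filtered_parametrize


    def _only_emulated_fp32(params):
        # Illustrative rule: skip non-emulated runs for float32.
        if params["dtype"] == "float32" and not params["emulate"]:
            return False
        return True


    @filtered_parametrize(
        [
            ("dtype", ["bfloat16", "float32"]),
            ("emulate", [True, False]),
        ],
        filter_func=_only_emulated_fp32,
    )
    def test_example(dtype, emulate):
        # Only combinations accepted by _only_emulated_fp32 are generated,
        # so the rejected (float32, emulate=False) case never runs.
        assert not (dtype == "float32" and not emulate)

Compared with stacking several `pytest.mark.parametrize` decorators and calling `pytest.skip()` inside the test (the pattern removed by PATCH 1/3), this keeps the skip logic in one place and avoids generating test IDs for combinations that would always be skipped.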