diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 60a5a1ea99..010050baea 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -13,24 +13,25 @@ LayerNormLinear, LayerNormMLP, Linear, - make_graphed_callables, MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init, + make_graphed_callables, ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible +import transformer_engine.pytorch.ops as te_ops -# Only run FP8 tests on H100. +# Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# Record initial RNG state. seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -# Record initial RNG state from script run. _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() @@ -48,17 +49,14 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} -modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"] - -all_boolean = [True, False] - -dtypes = [torch.float32, torch.float16] +# Supported data types +dtypes: List[torch.dtype] = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher dtypes.append(torch.bfloat16) def reset_rng_states() -> None: - """revert back to initial RNG state.""" + """Revert to initial RNG state.""" torch.set_rng_state(_cpu_rng_state) torch.cuda.set_rng_state(_cuda_rng_state) @@ -70,64 +68,40 @@ def reset_global_fp8_state(): def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: - """Ensures two lists are equal.""" + """Check that two lists of tensors match exactly.""" assert len(l1) == len(l2), "Unequal number of outputs." - failed = False - failed_tensors = "" + failure_message = "Output mismatches in:" + failed_tensors = [] for i, (t1, t2) in enumerate(zip(l1, l2)): if not torch.equal(t1, t2): - failed = True - failed_tensors += ( - f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" - ) - assert not failed, "Output mismatches in:\n" + failed_tensors + failure_message += "\n " + if names is None: + failure_message += f"tensor at idx={i}" + else: + failure_message += names[i] + failed_tensors.append((t1, t2)) + if failed_tensors: + print(failure_message) + t1, t2 = failed_tensors[0] + torch.testing.assert_close(t1, t2, rtol=0, atol=0) def generate_data( - config: ModelConfig, + model_config: ModelConfig, dtype: torch.dtype, - dpa: bool = False, warmup: bool = False, - return_grad_output: bool = False, -) -> Tuple[List[torch.Tensor], torch.Tensor]: + requires_grad: bool = True, +) -> torch.Tensor: """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn - if dpa: - inputs = [ - gen_func( - config.sequence_length, - config.batch_size, - config.num_heads, - config.kv_channels, - device="cuda", - requires_grad=True, - dtype=dtype, - ) - for _ in range(3) - ] - else: - inputs = [ - gen_func( - config.sequence_length, - config.batch_size, - config.hidden_size, - device="cuda", - requires_grad=True, - dtype=dtype, - ) - ] - - if not return_grad_output: - return inputs - - grad_output = torch.randn( - config.sequence_length, - config.batch_size, - config.hidden_size, + return gen_func( + model_config.sequence_length, + model_config.batch_size, + model_config.hidden_size, device="cuda", + requires_grad=requires_grad, dtype=dtype, ) - return inputs, grad_output def get_outputs( @@ -157,30 +131,44 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: return x +# Supported modules +_test_cuda_graphs_modules: List[str] = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", + "linear_op", +] + + def _test_cuda_graphs( *, - config: ModelConfig, + graph_mode: str, + module: str, + model_config: ModelConfig, num_layers: int, dtype: torch.dtype, fp8: bool, fp8_params: bool, fp8_weight_caching: bool, - module: str, - graph_mode: str, ) -> List[torch.Tensor]: """Helper function for CUDA graph test.""" reset_rng_states() FP8GlobalStateManager.reset() - dpa = module == "dpa" + # Operation-based API does not support FP8 weight caching. + if module == "linear_op": + fp8_weight_caching = False + + # Create modules. with fp8_model_init(enabled=fp8_params): - # Create modules. if module == "transformer": modules = [ TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.hidden_size, + model_config.num_heads, hidden_dropout=0.0, attention_dropout=0.0, fuse_qkv_params=True, @@ -190,37 +178,56 @@ def _test_cuda_graphs( ] elif module == "layernorm_mlp": modules = [ - LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormMLP( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "layernorm_linear": modules = [ - LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormLinear( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "mha": modules = [ MultiheadAttention( - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.num_heads, attention_dropout=0.0, params_dtype=dtype, fuse_qkv_params=True, ) for _ in range(num_layers) ] - elif dpa: - assert config.hidden_size % config.num_heads == 0, "Err." - assert num_layers == 1, "Err." + elif module == "linear": modules = [ - DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0) + Linear( + model_config.hidden_size, + model_config.hidden_size, + device="cuda", + params_dtype=dtype, + ) for _ in range(num_layers) ] - else: + elif module == "linear_op": modules = [ - Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype) + te_ops.Sequential( + te_ops.Linear( + model_config.hidden_size, + model_config.hidden_size, + dtype=dtype, + ), + ) for _ in range(num_layers) ] + else: + raise ValueError(f"Unknown module type ({module})") # Initialize gradient buffers. for module in modules: @@ -230,111 +237,208 @@ def _test_cuda_graphs( # Generate model and wrap API to return graphed version. if graph_mode == "full": # Graph entire model at once. - model = modules[0] if dpa else torch.nn.Sequential(*modules) + model = torch.nn.Sequential(*modules) model = make_graphed_callables( model, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) elif graph_mode == "individual": - # Graph individual modules + # Graph individual modules. modules = [ make_graphed_callables( module, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) for module in modules ] - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) else: - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) # Optimizer. - if not dpa: - optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) - # Launch. + # Training steps. for _ in range(3): - if not dpa: - optimizer.zero_grad(set_to_none=False) + optimizer.zero_grad(set_to_none=False) for grad_accumulation_step in range(2): - inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True) + input_ = generate_data(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) with fp8_autocast(enabled=fp8): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 - output = model(*inputs, **kwargs) + output = model(input_, **kwargs) output.backward(grad_output) - if not dpa: - optimizer.step() + optimizer.step() return get_outputs(model, output) +@pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("model", model_configs.keys()) -@pytest.mark.parametrize("num_layers", [1, 3]) -@pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("fp8_params", all_boolean) -@pytest.mark.parametrize("fp8_weight_caching", all_boolean) -@pytest.mark.parametrize("module", modules) -def test_gpt_make_graphed_callables( +@pytest.mark.parametrize("fp8", (False, True)) +@pytest.mark.parametrize("fp8_params", (False, True)) +def test_make_graphed_callables( + *, + module: str, + model_config: str = "small", + num_layers: int = 3, dtype: torch.dtype, - model: str, - num_layers: int, fp8: bool, fp8_params: bool, - fp8_weight_caching: bool, - module: str, + fp8_weight_caching: bool = False, ) -> None: + + # Skip invalid configurations. if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if module == "dpa" and num_layers > 1: - pytest.skip("Max 1 layer for DPA.") - - config = model_configs[model] + # Run model with different CUDA graph settings. + model_config = model_configs[model_config] kwargs = dict( - config=config, + module=module, + model_config=model_config, num_layers=num_layers, dtype=dtype, fp8=fp8, fp8_params=fp8_params, fp8_weight_caching=fp8_weight_caching, - module=module, ) outputs = _test_cuda_graphs(graph_mode="none", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs) - # Check that results match + # Check that results match. assert_all_equal(outputs, graph_outputs_mode1) assert_all_equal(outputs, graph_outputs_mode2) -def _test_cuda_graphs_with_kwargs( +_test_make_graphed_callables_with_fp8_weight_caching_modules = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", +] + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize( + "module", + _test_make_graphed_callables_with_fp8_weight_caching_modules, +) +@pytest.mark.parametrize("fp8_params", (False, True)) +def test_make_graphed_callables_with_fp8_weight_caching( *, - config: ModelConfig, + module: str, + fp8_params: bool, +) -> None: + test_make_graphed_callables( + module=module, + dtype=torch.float32, + fp8=True, + fp8_params=fp8_params, + fp8_weight_caching=True, + ) + + +def generate_data_for_dot_product_attention( + model_config: ModelConfig, dtype: torch.dtype, + warmup: bool = False, +) -> List[torch.Tensor]: + """Generate synthetic data for dot product attention.""" + gen_func = torch.ones if warmup else torch.randn + return [ + gen_func( + model_config.sequence_length, + model_config.batch_size, + model_config.num_heads, + model_config.kv_channels, + device="cuda", + requires_grad=True, + dtype=dtype, + ) + for _ in range(3) + ] + + +def _test_cuda_graphs_with_dot_product_attention( + *, with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, ) -> List[torch.Tensor]: - """Simulate Megatron-LM interleaved pipeline parallelism.""" + """Helper function for CUDA graph test.""" + reset_rng_states() + FP8GlobalStateManager.reset() + + # Create dot product attention module. + assert model_config.hidden_size % model_config.num_heads == 0 + model = DotProductAttention( + model_config.num_heads, + model_config.kv_channels, + attention_dropout=0.0, + ) + + # Graph model if needed. + if with_graph: + model = make_graphed_callables( + model, + generate_data_for_dot_product_attention(model_config, dtype, warmup=True), + num_warmup_iters=10, + fp8_enabled=False, + ) + + # Forward and backward passes. + for _ in range(3): + inputs = generate_data_for_dot_product_attention(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) + output = model(*inputs) + output.backward(grad_output) + + return get_outputs(model, output) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_make_graphed_callables_with_dot_product_attention( + *, + model_config: str = "small", + dtype: torch.dtype, +) -> None: + """Test CUDA graphs with dot product attention.""" + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) + outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs) + graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs) + assert_all_equal(outputs, graph_outputs) + + +def _test_cuda_graphs_with_kwargs( + *, + with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, +) -> List[torch.Tensor]: + """Helper function for CUDA graph test with keyword arguments.""" reset_rng_states() # Initialize model. model = TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.hidden_size, + model_config.num_heads, hidden_dropout=0.0, attention_dropout=0.0, self_attn_mask_type="arbitrary", @@ -349,13 +453,18 @@ def _test_cuda_graphs_with_kwargs( # Make graphed version of model if needed. if with_graph: attn_mask = torch.zeros( - (config.batch_size, 1, config.sequence_length, config.sequence_length), + ( + model_config.batch_size, + 1, + model_config.sequence_length, + model_config.sequence_length, + ), dtype=torch.bool, device="cuda", ) model = make_graphed_callables( model, - generate_data(config, dtype, warmup=True), + (generate_data(model_config, dtype, warmup=True),), sample_kwargs=dict(attention_mask=attn_mask), allow_unused_input=True, ) @@ -367,14 +476,20 @@ def _test_cuda_graphs_with_kwargs( for _ in range(3): optimizer.zero_grad(set_to_none=False) for grad_accumulation_step in range(2): - inputs, grad_output = generate_data(config, dtype, return_grad_output=True) + input_ = generate_data(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) attn_mask = torch.randint( 2, - (config.batch_size, 1, config.sequence_length, config.sequence_length), + ( + model_config.batch_size, + 1, + model_config.sequence_length, + model_config.sequence_length, + ), dtype=torch.bool, device="cuda", ) - output = model(*inputs, attention_mask=attn_mask) + output = model(input_, attention_mask=attn_mask) output.backward(grad_output) optimizer.step() @@ -382,12 +497,13 @@ def _test_cuda_graphs_with_kwargs( def test_make_graphed_callables_with_kwargs( + *, + model_config: str = "small", dtype: torch.dtype = torch.float32, - model: str = "small", ) -> None: """Test CUDA graphs with keyword arguments.""" - config = model_configs[model] - kwargs = dict(config=config, dtype=dtype) + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs) graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs) assert_all_equal(outputs, graph_outputs) @@ -395,9 +511,9 @@ def test_make_graphed_callables_with_kwargs( def _test_cuda_graphs_with_interleaved_pipeline_parallelism( *, - config: ModelConfig, - dtype: torch.dtype, with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, ) -> List[torch.Tensor]: """Simulate Megatron-LM interleaved pipeline parallelism.""" reset_rng_states() @@ -411,8 +527,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( model = torch.nn.ModuleList( [ Linear( - config.hidden_size, - config.hidden_size, + model_config.hidden_size, + model_config.hidden_size, params_dtype=dtype, ) for _ in range(num_layers) @@ -430,7 +546,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( } if with_graph: sample_args = tuple( - generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches) + (generate_data(model_config, dtype, warmup=True),) + for _ in range(num_layers * num_microbatches) ) layer_forwards = make_graphed_callables( tuple(model), @@ -455,9 +572,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( grad_outputs = {} for layer_idx in range(num_layers): for microbatch_idx in range(num_microbatches): - x, dy = generate_data(config, dtype, return_grad_output=True) + x = generate_data(model_config, dtype) + dy = generate_data(model_config, dtype, requires_grad=False) idxs = (layer_idx, microbatch_idx) - inputs[idxs] = x[0] + inputs[idxs] = x grad_outputs[idxs] = dy # Cache for layer outputs. @@ -494,12 +612,13 @@ def backward(layer_idx: int, microbatch_idx: int): def test_make_graphed_callables_with_interleaved_pipeline_parallelism( + *, + model_config: str = "small", dtype: torch.dtype = torch.float16, - model: str = "small", ) -> None: """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.""" - config = model_configs[model] - kwargs = dict(config=config, dtype=dtype) + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( with_graph=False, **kwargs, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 76679eb064..f95ba515cb 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -277,7 +277,9 @@ def is_first_fp8_module(cls): @classmethod def get_fp8_recipe(cls) -> DelayedScaling: """Return the fp8 recipe""" - return cls.FP8_RECIPE + if cls.FP8_RECIPE is not None: + return cls.FP8_RECIPE + return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index cba71e1326..ed0ed1c008 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Functions for CUDA Graphs support in FP8""" +from collections.abc import Iterable from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -18,7 +19,7 @@ ) from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule - +from .ops.op import BasicOperation __all__ = ["make_graphed_callables"] @@ -486,28 +487,46 @@ def new_fwd(*user_args, **user_kwargs): return tuple(ret) -def save_fp8_tensors(modules, amax_history_len): +def save_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_recipe: DelayedScaling, +) -> List[Any]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. """ - saved_fp8_meta_tensors = [] + fp8_tensors = [] for module in modules: for m in module.modules(): + module_tensors = None if isinstance(m, TransformerEngineBaseModule): if m.primary_weights_in_fp8: - m.adjust_amax_history_length(amax_history_len) - saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) - return saved_fp8_meta_tensors - - -def restore_fp8_tensors(modules, fp8_tensors): + m.adjust_amax_history_length(fp8_recipe.amax_history_len) + module_tensors = m.get_fp8_meta_tensors() + elif isinstance(m, BasicOperation): + m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe) + module_tensors = m._save_fp8_metas() + fp8_tensors.append(module_tensors) + return fp8_tensors + + +def restore_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_tensors: List[Any], +) -> None: """Restore FP8 tensors.""" for module in modules: for m in module.modules(): + module_tensors = fp8_tensors.pop(0) if isinstance(m, TransformerEngineBaseModule): - m.reset_fp8_meta_tensors(fp8_tensors.pop(0)) - assert len(fp8_tensors) == 0, "TE internal error." + m.reset_fp8_meta_tensors(module_tensors) + elif isinstance(m, BasicOperation): + m._load_fp8_metas(module_tensors) + if len(fp8_tensors) != 0: + raise RuntimeError( + f"Got FP8 state for {len(fp8_tensors)} more modules than expected. " + "There is probably a discrepancy with `save_fp8_tensors`." + ) def make_graphed_callables( @@ -580,7 +599,7 @@ def make_graphed_callables( modules = (modules,) # Store FP8 tensors to reset later. - saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) + saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) # FP8 wrapper. def wrap_autocast(block): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 859b1ba1d7..46a72a08d2 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -308,8 +308,8 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.weight.device.type == "meta": self.reset_parameters() diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 44a97b3b2d..eac1865566 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -111,8 +111,8 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(bias) self.bias = bias - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.bias.device.type == "meta": self.reset_parameters() diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index a7c99c592d..be37ab8976 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -5,12 +5,12 @@ """Manager class for a pipeline of fusible operations.""" from __future__ import annotations +from collections.abc import Callable from typing import Any, Optional import torch from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.ops.op import ( BasicOperation, FusibleOperation, @@ -28,6 +28,24 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: return t[:idx], t[idx:] +# Lazily imported function used in _is_graph_capturing +_is_graph_capturing_function: Optional[Callable[[], bool]] = None + + +def _is_graph_capturing() -> bool: + """Whether function is called within `make_graphed_callables` + + Avoid circular import with lazy import. + + """ + global _is_graph_capturing_function + if _is_graph_capturing_function is None: + from ..graph import is_graph_capturing + + _is_graph_capturing_function = is_graph_capturing + return _is_graph_capturing_function() + + class _OperationFuserAutogradFunction(torch.autograd.Function): """Autograd function for a pipeline of operations @@ -255,7 +273,7 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - if func_ctx.is_first_module and not is_graph_capturing(): + if func_ctx.is_first_module and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 75905ad854..87a7e825bc 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( + DelayedScaling, FP8GlobalStateManager, get_default_fp8_recipe, ) @@ -231,25 +232,37 @@ def _make_meta( } @classmethod - def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: + def _maybe_update_fp8_meta( + cls, + fp8_meta: Optional[dict[str, Any]], + *, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: if fp8_meta is None: return - # Update FP8 recipe and communication group - recipe = FP8GlobalStateManager.get_fp8_recipe() - fp8_meta["recipe"] = recipe + # Update FP8 recipe + if fp8_recipe is None: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + fp8_meta["recipe"] = fp8_recipe + + # Update FP8 communication group fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Adjust amax history length if needed - amax_history_len = recipe.amax_history_len + amax_history_len = fp8_recipe.amax_history_len for is_forward in (True, False): - key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - if key not in fp8_meta: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: continue - meta = fp8_meta[key] + meta = fp8_meta[fp8_meta_key] curr_len = meta.amax_history.size(0) + + # Nothing to be done if amax history is already correct if curr_len == amax_history_len: continue + + # Reallocate amax history with torch.no_grad(): if curr_len > amax_history_len: meta.amax_history = meta.amax_history[:amax_history_len].clone() @@ -259,6 +272,21 @@ def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: pad=(0, 0, 0, amax_history_len - curr_len), ) + # Update global buffers for amax reductions + buffer_info_key = FP8GlobalStateManager.get_buffer_info() + if buffer_info_key in fp8_meta: + fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key] + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ + fp8_meta_key + ].amax_history[0] + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[ + fp8_meta_key + ].amax_history + def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: """FP8 metadata @@ -272,11 +300,67 @@ def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: self._fp8_metas = self._make_fp8_metas() return self._fp8_metas[mode] - def pre_forward(self) -> None: + @torch.no_grad() + def _save_fp8_metas(self) -> Optional[dict[str, Any]]: + """Create copies of tensors in FP8 metadata + + Tensor copies can be loaded with _load_fp8_metas. + + """ + if self._fp8_metas is None: + return None + out = {} + for mode, fp8_meta in self._fp8_metas.items(): + if fp8_meta is None: + continue + out[mode] = {} + for is_forward in (True, False): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: + continue + out[mode][fp8_meta_key] = ( + fp8_meta[fp8_meta_key].scale.clone(), + fp8_meta[fp8_meta_key].scale_inv.clone(), + fp8_meta[fp8_meta_key].amax_history.clone(), + ) + return out + + @torch.no_grad() + def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: + """Update FP8 metadata with saved tensor copies + + Tensor copies should be generated with _save_fp8_metas. + + """ + assert (self._fp8_metas is None) == ( + fp8_metas is None + ), "Saved FP8 metadata does not match operation's FP8 metadata" + if fp8_metas is None: + return + for mode, fp8_meta in fp8_metas.items(): + assert ( + mode in self._fp8_metas + ), f"Found an unexpected key ({mode=}) in saved FP8 metadata" + for fp8_meta_key, tensors in fp8_meta.items(): + assert ( + fp8_meta_key in self._fp8_metas[mode] + ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" + scale, scale_inv, amax_history = tensors + self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) + self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) + self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) + + def pre_forward( + self, + *, + fp8_enabled: Optional[bool] = None, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: """Preprocessing before forward pass""" # Initialize FP8 metadata if needed - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + if fp8_enabled is None: + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() if fp8_enabled: # Construct FP8 metadata if needed @@ -285,7 +369,7 @@ def pre_forward(self) -> None: # Make sure FP8 metadata matches FP8 autocast context for fp8_meta in self._fp8_metas.values(): - self._maybe_update_fp8_meta(fp8_meta) + self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe) # Register FP8 metadata for amax and scale update if not FP8GlobalStateManager.fp8_graph_capturing():