From d771ca545298c954b6d17b39352cc48996c0d935 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Aug 2024 17:22:05 -0700 Subject: [PATCH 1/5] Debug CUDA graph support with operation-based API Signed-off-by: Tim Moon --- tests/pytorch/test_cuda_graphs.py | 27 ++++- transformer_engine/pytorch/graph.py | 47 ++++++-- .../pytorch/ops/basic/basic_linear.py | 4 +- transformer_engine/pytorch/ops/basic/bias.py | 4 +- transformer_engine/pytorch/ops/op.py | 105 ++++++++++++++++-- 5 files changed, 160 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 60a5a1ea99..97f18b037b 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -21,6 +21,7 @@ ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible +import transformer_engine.pytorch.ops as te_ops # Only run FP8 tests on H100. @@ -48,7 +49,15 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} -modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"] +modules = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", + "dpa", + "linear_op", +] all_boolean = [True, False] @@ -171,7 +180,10 @@ def _test_cuda_graphs( """Helper function for CUDA graph test.""" reset_rng_states() FP8GlobalStateManager.reset() + dpa = module == "dpa" + if module == "linear_op": + fp8_weight_caching = False with fp8_model_init(enabled=fp8_params): # Create modules. @@ -209,18 +221,27 @@ def _test_cuda_graphs( ) for _ in range(num_layers) ] - elif dpa: + elif module == "dpa": assert config.hidden_size % config.num_heads == 0, "Err." assert num_layers == 1, "Err." modules = [ DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0) for _ in range(num_layers) ] - else: + elif module == "linear": modules = [ Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype) for _ in range(num_layers) ] + elif module == "linear_op": + modules = [ + te_ops.Sequential( + te_ops.Linear(config.hidden_size, config.hidden_size, dtype=dtype), + ) + for _ in range(num_layers) + ] + else: + raise ValueError(f"Unknown module type ({module})") # Initialize gradient buffers. for module in modules: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e2642bc360..b8b383ad6e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Functions for CUDA Graphs support in FP8""" +from collections.abc import Iterable from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -19,7 +20,6 @@ from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule - __all__ = ["make_graphed_callables"] @@ -483,27 +483,56 @@ def new_fwd(*user_args, **user_kwargs): return tuple(ret) -def save_fp8_tensors(modules, amax_history_len): +def save_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_recipe: DelayedScaling, +) -> Any: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. 
""" - saved_fp8_meta_tensors = [] + from .ops import Sequential, FusibleOperation # Avoid circular import + fp8_tensors = [] for module in modules: for m in module.modules(): + module_tensors = None if isinstance(m, TransformerEngineBaseModule): if m.primary_weights_in_fp8: - m.adjust_amax_history_length(amax_history_len) - saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) - return saved_fp8_meta_tensors + m.adjust_amax_history_length(fp8_recipe.amax_history_len) + module_tensors = m.get_fp8_meta_tensors() + elif isinstance(m, FusibleOperation): + if m.is_fused_op: + module_tensors = save_fp8_tensors(m.basic_ops, fp8_recipe) + else: + m.pre_forward( + fp8_enabled=True, + fp8_recipe=fp8_recipe, + ) + module_tensors = m._save_fp8_metas() + elif isinstance(m, Sequential): + module_tensors = save_fp8_tensors(m, fp8_recipe) + fp8_tensors.append(module_tensors) + return fp8_tensors -def restore_fp8_tensors(modules, fp8_tensors): +def restore_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_tensors: Any, +) -> None: """Restore FP8 tensors.""" + from .ops import Sequential, FusibleOperation # Avoid circular import for module in modules: for m in module.modules(): + module_tensors = fp8_tensors.pop(0) if isinstance(m, TransformerEngineBaseModule): - m.reset_fp8_meta_tensors(fp8_tensors.pop(0)) + m.reset_fp8_meta_tensors(module_tensors) + elif isinstance(m, FusibleOperation): + if m.is_fused_op: + restore_fp8_tensors(m.basic_ops, module_tensors) + else: + m._load_fp8_metas(module_tensors) + elif isinstance(m, Sequential): + restore_fp8_tensors(m, module_tensors) assert len(fp8_tensors) == 0, "TE internal error." @@ -573,7 +602,7 @@ def make_graphed_callables( modules = (modules,) # Store FP8 tensors to reset later. - saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) + saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) # FP8 wrapper. 
def wrap_autocast(block): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 826807d1c0..3c9f3b3bc8 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -300,8 +300,8 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.weight.device.type == "meta": self.reset_parameters() diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index b8e8cc5e56..7688aa2ea1 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -113,8 +113,8 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(bias) self.bias = bias - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.bias.device.type == "meta": self.reset_parameters() diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 47c6567056..d1f5f2c719 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( + DelayedScaling, FP8GlobalStateManager, get_default_fp8_recipe, ) @@ -232,25 +233,39 @@ def _make_meta( ) @classmethod - def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: + def _maybe_update_fp8_meta( + cls, + fp8_meta: Optional[dict[str, Any]], + *, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: if fp8_meta is None: return - # Update FP8 recipe and communication group - recipe = FP8GlobalStateManager.get_fp8_recipe() - fp8_meta["recipe"] = recipe + # Update FP8 recipe + if fp8_recipe is None: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe is None: + fp8_recipe = get_default_fp8_recipe() + fp8_meta["recipe"] = fp8_recipe + + # Update FP8 communication group fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Adjust amax history length if needed - amax_history_len = recipe.amax_history_len + amax_history_len = fp8_recipe.amax_history_len for is_forward in (True, False): - key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - if key not in fp8_meta: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: continue - meta = fp8_meta[key] + meta = fp8_meta[fp8_meta_key] curr_len = meta.amax_history.size(0) + + # Nothing to be done if amax history is already correct if curr_len == amax_history_len: continue + + # Reallocate amax history with torch.no_grad(): if curr_len > amax_history_len: meta.amax_history = meta.amax_history[:amax_history_len].clone() @@ -260,6 +275,21 @@ def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: pad=(0, 0, 0, amax_history_len - curr_len), ) + # Update global buffers for amax reductions + buffer_info_key = FP8GlobalStateManager.get_buffer_info() + if buffer_info_key in fp8_meta: + fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key] + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history 
change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ + fp8_meta_key + ].amax_history[0] + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( + fp8_meta[fp8_meta_key].amax_history + ) + def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: """FP8 metadata @@ -273,11 +303,64 @@ def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: self._fp8_metas = self._make_fp8_metas() return self._fp8_metas[mode] - def pre_forward(self) -> None: + @torch.no_grad() + def _save_fp8_metas(self) -> Optional[dict[str, Any]]: + """Create copies of tensors in FP8 metadata + + Tensor copies can be loaded with _load_fp8_metas. + + """ + if self._fp8_metas is None: + return None + out = {} + for mode, fp8_meta in self._fp8_metas.items(): + if fp8_meta is None: + continue + out[mode] = {} + for is_forward in (True, False): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: + continue + out[mode][fp8_meta_key] = ( + fp8_meta[fp8_meta_key].scale.clone(), + fp8_meta[fp8_meta_key].scale_inv.clone(), + fp8_meta[fp8_meta_key].amax_history.clone(), + ) + return out + + @torch.no_grad() + def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: + """Update FP8 metadata with saved tensor copies + + Tensor copies should be generated with _save_fp8_metas. + + """ + assert (self._fp8_metas is None) == (fp8_metas is None), \ + "Saved FP8 metadata does not match operation's FP8 metadata" + if fp8_metas is None: + return + for mode, fp8_meta in fp8_metas.items(): + assert mode in self._fp8_metas, \ + f"Found an unexpected key ({mode=}) in saved FP8 metadata" + for fp8_meta_key, tensors in fp8_meta.items(): + assert fp8_meta_key in self._fp8_metas[mode], \ + f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" + scale, scale_inv, amax_history = tensors + self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) + self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) + self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) + + def pre_forward( + self, + *, + fp8_enabled: Optional[bool] = None, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: """Preprocessing before forward pass""" # Initialize FP8 metadata if needed - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + if fp8_enabled is None: + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() if fp8_enabled: # Construct FP8 metadata if needed @@ -286,7 +369,7 @@ def pre_forward(self) -> None: # Make sure FP8 metadata matches FP8 autocast context for fp8_meta in self._fp8_metas.values(): - self._maybe_update_fp8_meta(fp8_meta) + self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe) # Register FP8 metadata for amax and scale update if not FP8GlobalStateManager.fp8_graph_capturing(): From ade0c029bc13ad652fcd8ba379043ae3719271d3 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Aug 2024 18:49:09 -0700 Subject: [PATCH 2/5] Refactoring CUDA graph tests Signed-off-by: Tim Moon --- tests/pytorch/test_cuda_graphs.py | 371 +++++++++++++++++++----------- 1 file changed, 231 insertions(+), 140 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 97f18b037b..1af004f1ad 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -13,25 +13,25 @@ LayerNormLinear, LayerNormMLP, Linear, - make_graphed_callables, MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init, + make_graphed_callables, ) 
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops -# Only run FP8 tests on H100. +# Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# Record initial RNG state. seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -# Record initial RNG state from script run. _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() @@ -49,25 +49,14 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} -modules = [ - "transformer", - "layernorm_mlp", - "layernorm_linear", - "linear", - "mha", - "dpa", - "linear_op", -] - -all_boolean = [True, False] - -dtypes = [torch.float32, torch.float16] +# Supported data types +dtypes: List[torch.dtype] = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher dtypes.append(torch.bfloat16) def reset_rng_states() -> None: - """revert back to initial RNG state.""" + """Revert to initial RNG state.""" torch.set_rng_state(_cpu_rng_state) torch.cuda.set_rng_state(_cuda_rng_state) @@ -79,64 +68,40 @@ def reset_global_fp8_state(): def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: - """Ensures two lists are equal.""" + """Check that two lists of tensors match exactly.""" assert len(l1) == len(l2), "Unequal number of outputs." - failed = False - failed_tensors = "" + failure_message = "Output mismatches in:" + failed_tensors = [] for i, (t1, t2) in enumerate(zip(l1, l2)): if not torch.equal(t1, t2): - failed = True - failed_tensors += ( - f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" - ) - assert not failed, "Output mismatches in:\n" + failed_tensors + failure_message += "\n " + if names is None: + failure_message += f"tensor at idx={i}" + else: + failure_message += names[i] + failed_tensors.append((t1, t2)) + if failed_tensors: + print(failure_message) + t1, t2 = failed_tensors[0] + torch.testing.assert_close(t1, t2, rtol=0, atol=0) def generate_data( - config: ModelConfig, + model_config: ModelConfig, dtype: torch.dtype, - dpa: bool = False, warmup: bool = False, - return_grad_output: bool = False, -) -> Tuple[List[torch.Tensor], torch.Tensor]: + requires_grad: bool = True, +) -> torch.Tensor: """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn - if dpa: - inputs = [ - gen_func( - config.sequence_length, - config.batch_size, - config.num_heads, - config.kv_channels, - device="cuda", - requires_grad=True, - dtype=dtype, - ) - for _ in range(3) - ] - else: - inputs = [ - gen_func( - config.sequence_length, - config.batch_size, - config.hidden_size, - device="cuda", - requires_grad=True, - dtype=dtype, - ) - ] - - if not return_grad_output: - return inputs - - grad_output = torch.randn( - config.sequence_length, - config.batch_size, - config.hidden_size, + return gen_func( + model_config.sequence_length, + model_config.batch_size, + model_config.hidden_size, device="cuda", + requires_grad=requires_grad, dtype=dtype, ) - return inputs, grad_output def get_outputs( @@ -166,33 +131,43 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: return x +# Supported modules +_test_cuda_graphs_modules: List[str] = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", + "linear_op", +] + def _test_cuda_graphs( *, - config: ModelConfig, + graph_mode: str, + module: str, + model_config: 
ModelConfig, num_layers: int, dtype: torch.dtype, fp8: bool, fp8_params: bool, fp8_weight_caching: bool, - module: str, - graph_mode: str, ) -> List[torch.Tensor]: """Helper function for CUDA graph test.""" reset_rng_states() FP8GlobalStateManager.reset() - dpa = module == "dpa" + # Operation-based API does not support FP8 weight caching. if module == "linear_op": fp8_weight_caching = False + # Create modules. with fp8_model_init(enabled=fp8_params): - # Create modules. if module == "transformer": modules = [ TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.hidden_size, + model_config.num_heads, hidden_dropout=0.0, attention_dropout=0.0, fuse_qkv_params=True, @@ -202,41 +177,51 @@ def _test_cuda_graphs( ] elif module == "layernorm_mlp": modules = [ - LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormMLP( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "layernorm_linear": modules = [ - LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormLinear( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "mha": modules = [ MultiheadAttention( - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.num_heads, attention_dropout=0.0, params_dtype=dtype, fuse_qkv_params=True, ) for _ in range(num_layers) ] - elif module == "dpa": - assert config.hidden_size % config.num_heads == 0, "Err." - assert num_layers == 1, "Err." - modules = [ - DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0) - for _ in range(num_layers) - ] elif module == "linear": modules = [ - Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype) + Linear( + model_config.hidden_size, + model_config.hidden_size, + device="cuda", + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "linear_op": modules = [ te_ops.Sequential( - te_ops.Linear(config.hidden_size, config.hidden_size, dtype=dtype), + te_ops.Linear( + model_config.hidden_size, + model_config.hidden_size, + dtype=dtype, + ), ) for _ in range(num_layers) ] @@ -251,111 +236,207 @@ def _test_cuda_graphs( # Generate model and wrap API to return graphed version. if graph_mode == "full": # Graph entire model at once. - model = modules[0] if dpa else torch.nn.Sequential(*modules) + model = torch.nn.Sequential(*modules) model = make_graphed_callables( model, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) elif graph_mode == "individual": - # Graph individual modules + # Graph individual modules. modules = [ make_graphed_callables( module, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) for module in modules ] - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) else: - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) # Optimizer. - if not dpa: - optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) - # Launch. + # Training steps. 
for _ in range(3): - if not dpa: - optimizer.zero_grad(set_to_none=False) + optimizer.zero_grad(set_to_none=False) for grad_accumulation_step in range(2): - inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True) + input_ = generate_data(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) with fp8_autocast(enabled=fp8): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 - output = model(*inputs, **kwargs) + output = model(input_, **kwargs) output.backward(grad_output) - if not dpa: - optimizer.step() + optimizer.step() return get_outputs(model, output) +@pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("model", model_configs.keys()) -@pytest.mark.parametrize("num_layers", [1, 3]) -@pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("fp8_params", all_boolean) -@pytest.mark.parametrize("fp8_weight_caching", all_boolean) -@pytest.mark.parametrize("module", modules) -def test_gpt_make_graphed_callables( +@pytest.mark.parametrize("fp8", (False, True)) +@pytest.mark.parametrize("fp8_params", (False, True)) +def test_make_graphed_callables( + *, + module: str, + model_config: str = "small", + num_layers: int = 3, dtype: torch.dtype, - model: str, - num_layers: int, fp8: bool, fp8_params: bool, - fp8_weight_caching: bool, - module: str, + fp8_weight_caching: bool = False, ) -> None: + + # Skip invalid configurations. if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if module == "dpa" and num_layers > 1: - pytest.skip("Max 1 layer for DPA.") - - config = model_configs[model] + # Run model with different CUDA graph settings. + model_config = model_configs[model_config] kwargs = dict( - config=config, + module=module, + model_config=model_config, num_layers=num_layers, dtype=dtype, fp8=fp8, fp8_params=fp8_params, fp8_weight_caching=fp8_weight_caching, - module=module, ) outputs = _test_cuda_graphs(graph_mode="none", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs) - # Check that results match + # Check that results match. 
assert_all_equal(outputs, graph_outputs_mode1) assert_all_equal(outputs, graph_outputs_mode2) -def _test_cuda_graphs_with_kwargs( +_test_make_graphed_callables_with_fp8_weight_caching_modules = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", +] + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize( + "module", + _test_make_graphed_callables_with_fp8_weight_caching_modules, +) +@pytest.mark.parametrize("fp8_params", (False, True)) +def test_make_graphed_callables_with_fp8_weight_caching( *, - config: ModelConfig, + module: str, + fp8_params: bool, +) -> None: + test_make_graphed_callables( + module=module, + dtype=torch.float32, + fp8=True, + fp8_params=fp8_params, + fp8_weight_caching=True, + ) + + +def generate_data_for_dot_product_attention( + model_config: ModelConfig, dtype: torch.dtype, + warmup: bool = False, +) -> List[torch.Tensor]: + """Generate synthetic data for dot product attention.""" + gen_func = torch.ones if warmup else torch.randn + return [ + gen_func( + model_config.sequence_length, + model_config.batch_size, + model_config.num_heads, + model_config.kv_channels, + device="cuda", + requires_grad=True, + dtype=dtype, + ) + for _ in range(3) + ] + + +def _test_cuda_graphs_with_dot_product_attention( + *, with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, ) -> List[torch.Tensor]: - """Simulate Megatron-LM interleaved pipeline parallelism.""" + """Helper function for CUDA graph test.""" + reset_rng_states() + FP8GlobalStateManager.reset() + + # Create dot product attention module. + assert model_config.hidden_size % model_config.num_heads == 0 + model = DotProductAttention( + model_config.num_heads, + model_config.kv_channels, + attention_dropout=0.0, + ) + + # Graph model if needed. + if with_graph: + model = make_graphed_callables( + model, + generate_data_for_dot_product_attention(model_config, dtype, warmup=True), + num_warmup_iters=10, + fp8_enabled=False, + ) + + # Forward and backward passes. + for _ in range(3): + inputs = generate_data_for_dot_product_attention(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) + output = model(*inputs) + output.backward(grad_output) + + return get_outputs(model, output) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_make_graphed_callables_with_dot_product_attention( + *, + model_config: str = "small", + dtype: torch.dtype, +) -> None: + """Test CUDA graphs with dot product attention.""" + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) + outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs) + graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs) + assert_all_equal(outputs, graph_outputs) + + +def _test_cuda_graphs_with_kwargs( + *, + with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, +) -> List[torch.Tensor]: + """Helper function for CUDA graph test with keyword arguments.""" reset_rng_states() # Initialize model. model = TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.hidden_size, + model_config.num_heads, hidden_dropout=0.0, attention_dropout=0.0, self_attn_mask_type="arbitrary", @@ -370,13 +451,18 @@ def _test_cuda_graphs_with_kwargs( # Make graphed version of model if needed. 
if with_graph: attn_mask = torch.zeros( - (config.batch_size, 1, config.sequence_length, config.sequence_length), + ( + model_config.batch_size, + 1, + model_config.sequence_length, + model_config.sequence_length, + ), dtype=torch.bool, device="cuda", ) model = make_graphed_callables( model, - generate_data(config, dtype, warmup=True), + (generate_data(model_config, dtype, warmup=True),), sample_kwargs=dict(attention_mask=attn_mask), allow_unused_input=True, ) @@ -388,14 +474,15 @@ def _test_cuda_graphs_with_kwargs( for _ in range(3): optimizer.zero_grad(set_to_none=False) for grad_accumulation_step in range(2): - inputs, grad_output = generate_data(config, dtype, return_grad_output=True) + input_ = generate_data(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) attn_mask = torch.randint( 2, - (config.batch_size, 1, config.sequence_length, config.sequence_length), + (model_config.batch_size, 1, model_config.sequence_length, model_config.sequence_length), dtype=torch.bool, device="cuda", ) - output = model(*inputs, attention_mask=attn_mask) + output = model(input_, attention_mask=attn_mask) output.backward(grad_output) optimizer.step() @@ -403,12 +490,13 @@ def _test_cuda_graphs_with_kwargs( def test_make_graphed_callables_with_kwargs( + *, + model_config: str = "small", dtype: torch.dtype = torch.float32, - model: str = "small", ) -> None: """Test CUDA graphs with keyword arguments.""" - config = model_configs[model] - kwargs = dict(config=config, dtype=dtype) + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs) graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs) assert_all_equal(outputs, graph_outputs) @@ -416,9 +504,9 @@ def test_make_graphed_callables_with_kwargs( def _test_cuda_graphs_with_interleaved_pipeline_parallelism( *, - config: ModelConfig, - dtype: torch.dtype, with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, ) -> List[torch.Tensor]: """Simulate Megatron-LM interleaved pipeline parallelism.""" reset_rng_states() @@ -432,8 +520,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( model = torch.nn.ModuleList( [ Linear( - config.hidden_size, - config.hidden_size, + model_config.hidden_size, + model_config.hidden_size, params_dtype=dtype, ) for _ in range(num_layers) @@ -451,7 +539,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( } if with_graph: sample_args = tuple( - generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches) + (generate_data(model_config, dtype, warmup=True),) + for _ in range(num_layers * num_microbatches) ) layer_forwards = make_graphed_callables( tuple(model), @@ -476,9 +565,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( grad_outputs = {} for layer_idx in range(num_layers): for microbatch_idx in range(num_microbatches): - x, dy = generate_data(config, dtype, return_grad_output=True) + x = generate_data(model_config, dtype) + dy = generate_data(model_config, dtype, requires_grad=False) idxs = (layer_idx, microbatch_idx) - inputs[idxs] = x[0] + inputs[idxs] = x grad_outputs[idxs] = dy # Cache for layer outputs. 
@@ -515,12 +605,13 @@ def backward(layer_idx: int, microbatch_idx: int): def test_make_graphed_callables_with_interleaved_pipeline_parallelism( + *, + model_config: str = "small", dtype: torch.dtype = torch.float16, - model: str = "small", ) -> None: """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.""" - config = model_configs[model] - kwargs = dict(config=config, dtype=dtype) + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( with_graph=False, **kwargs, From e5d40a69517dac4474b3ab86edb92fe5ef1bf27f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Aug 2024 01:53:46 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_cuda_graphs.py | 9 ++++++++- transformer_engine/pytorch/graph.py | 2 ++ transformer_engine/pytorch/ops/op.py | 21 ++++++++++++--------- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 1af004f1ad..010050baea 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -141,6 +141,7 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: "linear_op", ] + def _test_cuda_graphs( *, graph_mode: str, @@ -331,6 +332,7 @@ def test_make_graphed_callables( "mha", ] + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize( "module", @@ -478,7 +480,12 @@ def _test_cuda_graphs_with_kwargs( grad_output = generate_data(model_config, dtype, requires_grad=False) attn_mask = torch.randint( 2, - (model_config.batch_size, 1, model_config.sequence_length, model_config.sequence_length), + ( + model_config.batch_size, + 1, + model_config.sequence_length, + model_config.sequence_length, + ), dtype=torch.bool, device="cuda", ) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index b8b383ad6e..7193d33476 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -492,6 +492,7 @@ def save_fp8_tensors( with adjusted amax history sizes. """ from .ops import Sequential, FusibleOperation # Avoid circular import + fp8_tensors = [] for module in modules: for m in module.modules(): @@ -521,6 +522,7 @@ def restore_fp8_tensors( ) -> None: """Restore FP8 tensors.""" from .ops import Sequential, FusibleOperation # Avoid circular import + for module in modules: for m in module.modules(): module_tensors = fp8_tensors.pop(0) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index d1f5f2c719..dd3307c0fb 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -286,9 +286,9 @@ def _maybe_update_fp8_meta( FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ fp8_meta_key ].amax_history[0] - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( - fp8_meta[fp8_meta_key].amax_history - ) + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[ + fp8_meta_key + ].amax_history def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: """FP8 metadata @@ -335,16 +335,19 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: Tensor copies should be generated with _save_fp8_metas. 
""" - assert (self._fp8_metas is None) == (fp8_metas is None), \ - "Saved FP8 metadata does not match operation's FP8 metadata" + assert (self._fp8_metas is None) == ( + fp8_metas is None + ), "Saved FP8 metadata does not match operation's FP8 metadata" if fp8_metas is None: return for mode, fp8_meta in fp8_metas.items(): - assert mode in self._fp8_metas, \ - f"Found an unexpected key ({mode=}) in saved FP8 metadata" + assert ( + mode in self._fp8_metas + ), f"Found an unexpected key ({mode=}) in saved FP8 metadata" for fp8_meta_key, tensors in fp8_meta.items(): - assert fp8_meta_key in self._fp8_metas[mode], \ - f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" + assert ( + fp8_meta_key in self._fp8_metas[mode] + ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" scale, scale_inv, amax_history = tensors self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) From 7d04de59bc5a6cd5b030cbc5889949012a285b8a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 19 Sep 2024 20:06:39 -0700 Subject: [PATCH 4/5] Review suggestions from @ptrendx Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph. Signed-off-by: Tim Moon --- transformer_engine/pytorch/fp8.py | 4 +++- transformer_engine/pytorch/graph.py | 6 +++++- transformer_engine/pytorch/ops/op.py | 2 -- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 76679eb064..f95ba515cb 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -277,7 +277,9 @@ def is_first_fp8_module(cls): @classmethod def get_fp8_recipe(cls) -> DelayedScaling: """Return the fp8 recipe""" - return cls.FP8_RECIPE + if cls.FP8_RECIPE is not None: + return cls.FP8_RECIPE + return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 7193d33476..e040a34529 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -535,7 +535,11 @@ def restore_fp8_tensors( m._load_fp8_metas(module_tensors) elif isinstance(m, Sequential): restore_fp8_tensors(m, module_tensors) - assert len(fp8_tensors) == 0, "TE internal error." + if len(fp8_tensors) != 0: + raise RuntimeError( + f"Got FP8 state for {len(fp8_tensors)} more modules than expected. " + "There is probably a discrepancy with `save_fp8_tensors`." 
+ ) def make_graphed_callables( diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index dd3307c0fb..5c146c360e 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -245,8 +245,6 @@ def _maybe_update_fp8_meta( # Update FP8 recipe if fp8_recipe is None: fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe is None: - fp8_recipe = get_default_fp8_recipe() fp8_meta["recipe"] = fp8_recipe # Update FP8 communication group From 805abc191cb99cf1772647b12c7ad477c5a82a33 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 19 Sep 2024 20:15:25 -0700 Subject: [PATCH 5/5] Avoid unnecessary recursion when saving/loading FP8 state Signed-off-by: Tim Moon --- transformer_engine/pytorch/graph.py | 32 ++++++++--------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e040a34529..3a0c68381e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -19,6 +19,7 @@ ) from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule +from .ops.op import BasicOperation __all__ = ["make_graphed_callables"] @@ -486,13 +487,11 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], fp8_recipe: DelayedScaling, -) -> Any: +) -> List[Any]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. """ - from .ops import Sequential, FusibleOperation # Avoid circular import - fp8_tensors = [] for module in modules: for m in module.modules(): @@ -501,40 +500,25 @@ def save_fp8_tensors( if m.primary_weights_in_fp8: m.adjust_amax_history_length(fp8_recipe.amax_history_len) module_tensors = m.get_fp8_meta_tensors() - elif isinstance(m, FusibleOperation): - if m.is_fused_op: - module_tensors = save_fp8_tensors(m.basic_ops, fp8_recipe) - else: - m.pre_forward( - fp8_enabled=True, - fp8_recipe=fp8_recipe, - ) - module_tensors = m._save_fp8_metas() - elif isinstance(m, Sequential): - module_tensors = save_fp8_tensors(m, fp8_recipe) + elif isinstance(m, BasicOperation): + m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe) + module_tensors = m._save_fp8_metas() fp8_tensors.append(module_tensors) return fp8_tensors def restore_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_tensors: Any, + fp8_tensors: List[Any], ) -> None: """Restore FP8 tensors.""" - from .ops import Sequential, FusibleOperation # Avoid circular import - for module in modules: for m in module.modules(): module_tensors = fp8_tensors.pop(0) if isinstance(m, TransformerEngineBaseModule): m.reset_fp8_meta_tensors(module_tensors) - elif isinstance(m, FusibleOperation): - if m.is_fused_op: - restore_fp8_tensors(m.basic_ops, module_tensors) - else: - m._load_fp8_metas(module_tensors) - elif isinstance(m, Sequential): - restore_fp8_tensors(m, module_tensors) + elif isinstance(m, BasicOperation): + m._load_fp8_metas(module_tensors) if len(fp8_tensors) != 0: raise RuntimeError( f"Got FP8 state for {len(fp8_tensors)} more modules than expected. "
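
Usage sketch (not part of the patches above): the new "linear_op" test case exercises make_graphed_callables with the operation-based API roughly as follows. This is a minimal, illustrative example rather than the test itself; the tensor shapes, learning rate, and iteration counts are arbitrary, and fp8_enabled=True assumes FP8-capable hardware. As in the test, FP8 weight caching is not used with the operation-based API, so no is_first_microbatch argument is passed.

    import torch
    import transformer_engine.pytorch.ops as te_ops
    from transformer_engine.pytorch import fp8_autocast, make_graphed_callables

    seq_len, batch_size, hidden_size = 32, 2, 64
    dtype = torch.float16

    # Operation-based model, as in the "linear_op" test case.
    model = te_ops.Sequential(
        te_ops.Linear(hidden_size, hidden_size, dtype=dtype),
    )

    # Capture the model into a CUDA graph. Sample args follow the
    # (sequence, batch, hidden) layout used by the tests.
    sample_input = torch.ones(
        seq_len, batch_size, hidden_size,
        device="cuda", dtype=dtype, requires_grad=True,
    )
    model = make_graphed_callables(
        model,
        (sample_input,),
        num_warmup_iters=10,
        fp8_enabled=True,  # assumes an FP8-capable GPU; set False otherwise
    )

    # Replay the graph on fresh data.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        x = torch.randn(seq_len, batch_size, hidden_size,
                        device="cuda", dtype=dtype, requires_grad=True)
        dy = torch.randn(seq_len, batch_size, hidden_size,
                         device="cuda", dtype=dtype)
        with fp8_autocast(enabled=True):
            y = model(x)
        y.backward(dy)
        optimizer.step()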
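
The refactored kwargs test also shows how keyword arguments are baked into the capture via sample_kwargs. Below is an illustrative sketch along the same lines; the TransformerLayer settings are trimmed relative to the test, and the mask follows the (batch, 1, seq, seq) boolean layout the test uses.

    import torch
    from transformer_engine.pytorch import TransformerLayer, make_graphed_callables

    seq_len, batch_size, hidden_size, num_heads = 32, 2, 64, 2
    model = TransformerLayer(
        hidden_size,
        hidden_size,
        num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
    )

    sample_input = torch.ones(
        seq_len, batch_size, hidden_size, device="cuda", requires_grad=True,
    )
    attn_mask = torch.zeros(
        (batch_size, 1, seq_len, seq_len), dtype=torch.bool, device="cuda",
    )
    model = make_graphed_callables(
        model,
        (sample_input,),
        sample_kwargs=dict(attention_mask=attn_mask),
        allow_unused_input=True,
    )

    # At replay time the mask contents may change, but its shape must match
    # the sample used during capture.
    out = model(
        torch.randn(seq_len, batch_size, hidden_size, device="cuda", requires_grad=True),
        attention_mask=torch.randint(
            2, (batch_size, 1, seq_len, seq_len), dtype=torch.bool, device="cuda",
        ),
    )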