From db3e13965a18b34c5be9be9cd182a507b48f0726 Mon Sep 17 00:00:00 2001
From: kunlunl
Date: Wed, 28 Aug 2024 01:24:15 -0700
Subject: [PATCH 1/2] Add preserve_high_precision_init_val to fp8_model_init

Signed-off-by: kunlunl
---
 qa/L0_pytorch_unittest/test.sh            |  1 +
 tests/pytorch/test_fp8_model_init.py      | 68 ++++++++++++++++++++++
 transformer_engine/pytorch/fp8.py         | 29 +++++++++-
 transformer_engine/pytorch/module/base.py | 23 +++++++-
 4 files changed, 119 insertions(+), 2 deletions(-)
 create mode 100644 tests/pytorch/test_fp8_model_init.py

diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index e6ccf3b82f..e9196a0010 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
+pytest -v -s $TE_PATH/tests/pytorch/test_fp8_model_init.py
diff --git a/tests/pytorch/test_fp8_model_init.py b/tests/pytorch/test_fp8_model_init.py
new file mode 100644
index 0000000000..4cce7ef816
--- /dev/null
+++ b/tests/pytorch/test_fp8_model_init.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+import torch
+
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_model_init
+
+# Check if FP8 is supported
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+class TestFP8ModelInit:
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    def test_default(self) -> None:
+        """Test default parameters of fp8_model_init"""
+        with fp8_model_init():
+            model = te.Linear(768, 768)
+
+        assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
+        assert not hasattr(model.weight, "_high_precision_init_val"), \
+            "_high_precision_init_val should not exist"
+        assert not hasattr(model.weight, "get_high_precision_init_val"), \
+            "get_high_precision_init_val() should not exist"
+        assert not hasattr(model.weight, "clear_high_precision_init_val"), \
+            "clear_high_precision_init_val() should not exist"
+
+    def test_preserve_high_precision_init_val(self) -> None:
+        """Test fp8_model_init with preserve_high_precision_init_val=True"""
+        with fp8_model_init(preserve_high_precision_init_val=True):
+            model = te.Linear(768, 768)
+
+        assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
+        assert hasattr(model.weight, "_high_precision_init_val"), \
+            "_high_precision_init_val not found"
+        assert hasattr(model.weight, "get_high_precision_init_val"), \
+            "get_high_precision_init_val() not found"
+        assert hasattr(model.weight, "clear_high_precision_init_val"), \
+            "clear_high_precision_init_val() not found"
+
+        high_precision = model.weight.get_high_precision_init_val()
+        assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
+
+        new_fp8 = Float8Tensor.to_float8(
+            high_precision.to(model.weight.device),
+            fp8_meta=model.weight._fp8_meta,
+            fp8_meta_index=model.weight._fp8_meta_index,
+            amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history.
+        )
+        assert torch.all(new_fp8._data == model.weight._data), \
+            "high_precision_init_val and model.weight are not equal"
+
+        model.weight.clear_high_precision_init_val()
+        assert model.weight.get_high_precision_init_val() is None, \
+            "clear_high_precision_init_val() did not work"
+        assert not hasattr(model.weight, "_high_precision_init_val"), \
+            "clear_high_precision_init_val() did not work"
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index e15268b998..f54cda6429 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -66,6 +66,7 @@ class FP8GlobalStateManager:
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
     FP8_PARAMETERS = False
+    HIGH_PRECISION_INIT_VAL = False
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0
@@ -89,6 +90,7 @@ def reset(cls) -> None:
         cls.FP8_RECIPE = None
         cls.FP8_DISTRIBUTED_GROUP = None
         cls.FP8_PARAMETERS = False
+        cls.HIGH_PRECISION_INIT_VAL = False
         cls.IS_FIRST_FP8_MODULE = False
         cls.FP8_GRAPH_CAPTURING = False
         cls.FP8_AUTOCAST_DEPTH = 0
@@ -251,6 +253,11 @@ def with_fp8_parameters(cls) -> bool:
         """Should the parameters be stored as FP8"""
         return cls.FP8_PARAMETERS
 
+    @classmethod
+    def with_high_precision_init_val(cls) -> bool:
+        """Should the high precision initial values be stored with FP8 parameters"""
+        return cls.HIGH_PRECISION_INIT_VAL
+
     @classmethod
     def fp8_graph_capturing(cls) -> bool:
         """Is CUDA graph capture under way?"""
@@ -477,7 +484,10 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 
 @contextmanager
-def fp8_model_init(enabled: bool = True) -> None:
+def fp8_model_init(
+        enabled: bool = True,
+        preserve_high_precision_init_val: bool = False,
+    ) -> None:
     """
     Context manager for FP8 initialization of parameters.
 
@@ -488,6 +498,12 @@ def fp8_model_init(enabled: bool = True) -> None:
         with fp8_model_init(enabled=True):
             model = transformer_engine.pytorch.Linear(768, 768)
 
+        # Preserve high-precision initial values to initialize master weights
+        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+            master_weight = model.weight.get_high_precision_init_val()
+            model.weight.clear_high_precision_init_val()
+
     Parameters
     ----------
     enabled: bool, default = `True`
@@ -501,15 +517,26 @@ def fp8_model_init(enabled: bool = True) -> None:
               precision copies of weights are already present in the optimizer.
             * inference, where only the FP8 copies of the parameters are used.
             * LoRA-like fine-tuning, where the main parameters of the model do not change.
+    preserve_high_precision_init_val: bool, default = `False`
+             when enabled, keep the high-precision tensor used to initialize each FP8 parameter
+             in CPU memory, and attach two methods, `get_high_precision_init_val()` and
+             `clear_high_precision_init_val()`, to the FP8 parameter for retrieving or clearing
+             this tensor. Users can then initialize master weights from this high-precision
+             copy, avoiding the loss of precision that occurs when master weights are built
+             directly from the FP8 values. After the master weights have been initialized,
+             call `clear_high_precision_init_val()` to release the CPU memory.
 
              This functionality is *EXPERIMENTAL*.
     """
     _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
     FP8GlobalStateManager.FP8_PARAMETERS = enabled
+    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
+    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
     try:
         yield
     finally:
         FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
+        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val
 
 
 @contextmanager
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 3613e1fa5e..ceb4a286e4 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -13,6 +13,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Generator, List, Optional, Tuple, Union
 from contextlib import contextmanager
+from types import MethodType
 
 import torch
 import torch.nn.functional as F
@@ -389,6 +390,7 @@ def __init__(self) -> None:
         self.sequence_parallel = False
         self.param_init_meta = {}
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
+        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
         self.fsdp_wrapped = False
         self.fsdp_group = None
         self._fp8_workspaces: Dict[str, Float8Tensor] = {}
@@ -864,7 +866,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
 
             # If primary weights are in fp8, wrap the parameter as Float8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
+            high_precision_init_val = None
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                if self.preserve_high_precision_init_val:
+                    high_precision_init_val = param.detach().cpu()
                 param = Float8Tensor.to_float8(
                     param,
                     fp8_meta=self.fp8_meta,
@@ -876,7 +881,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
-            setattr(self, name, torch.nn.Parameter(param))
+            param = torch.nn.Parameter(param)
+            if high_precision_init_val is not None:
+                def get(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        return self._high_precision_init_val
+                    else:
+                        return None
+
+                def clear(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        del self._high_precision_init_val
+
+                param._high_precision_init_val = high_precision_init_val
+                param.get_high_precision_init_val = MethodType(get, param)
+                param.clear_high_precision_init_val = MethodType(clear, param)
+
+            setattr(self, name, param)
 
     @abstractmethod
     def forward(self):

From 1d2c3cab8907665da6d251b49ce76f6ce9f58068 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 28 Aug 2024 08:26:11 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/test_fp8_model_init.py      | 45 ++++++++++++++---------
 transformer_engine/pytorch/fp8.py         |  6 +--
 transformer_engine/pytorch/module/base.py |  1 +
 3 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/tests/pytorch/test_fp8_model_init.py b/tests/pytorch/test_fp8_model_init.py
index 4cce7ef816..1dfde214e3 100644
--- a/tests/pytorch/test_fp8_model_init.py
+++ b/tests/pytorch/test_fp8_model_init.py
@@ -29,12 +29,15 @@ def test_default(self) -> None:
             model = te.Linear(768, 768)
 
         assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
-        assert not hasattr(model.weight, "_high_precision_init_val"), \
-            "_high_precision_init_val should not exist"
-        assert not hasattr(model.weight, "get_high_precision_init_val"), \
-            "get_high_precision_init_val() should not exist"
-        assert not hasattr(model.weight, "clear_high_precision_init_val"), \
-            "clear_high_precision_init_val() should not exist"
+        assert not hasattr(
+            model.weight, "_high_precision_init_val"
+        ), "_high_precision_init_val should not exist"
+        assert not hasattr(
+            model.weight, "get_high_precision_init_val"
+        ), "get_high_precision_init_val() should not exist"
+        assert not hasattr(
+            model.weight, "clear_high_precision_init_val"
+        ), "clear_high_precision_init_val() should not exist"
 
     def test_preserve_high_precision_init_val(self) -> None:
         """Test fp8_model_init with preserve_high_precision_init_val=True"""
@@ -42,12 +45,15 @@ def test_preserve_high_precision_init_val(self) -> None:
             model = te.Linear(768, 768)
 
         assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
-        assert hasattr(model.weight, "_high_precision_init_val"), \
-            "_high_precision_init_val not found"
-        assert hasattr(model.weight, "get_high_precision_init_val"), \
-            "get_high_precision_init_val() not found"
-        assert hasattr(model.weight, "clear_high_precision_init_val"), \
-            "clear_high_precision_init_val() not found"
+        assert hasattr(
+            model.weight, "_high_precision_init_val"
+        ), "_high_precision_init_val not found"
+        assert hasattr(
+            model.weight, "get_high_precision_init_val"
+        ), "get_high_precision_init_val() not found"
+        assert hasattr(
+            model.weight, "clear_high_precision_init_val"
+        ), "clear_high_precision_init_val() not found"
 
         high_precision = model.weight.get_high_precision_init_val()
         assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
@@ -58,11 +64,14 @@ def test_preserve_high_precision_init_val(self) -> None:
             fp8_meta_index=model.weight._fp8_meta_index,
             amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history.
         )
-        assert torch.all(new_fp8._data == model.weight._data), \
-            "high_precision_init_val and model.weight are not equal"
+        assert torch.all(
+            new_fp8._data == model.weight._data
+        ), "high_precision_init_val and model.weight are not equal"
 
         model.weight.clear_high_precision_init_val()
-        assert model.weight.get_high_precision_init_val() is None, \
-            "clear_high_precision_init_val() did not work"
-        assert not hasattr(model.weight, "_high_precision_init_val"), \
-            "clear_high_precision_init_val() did not work"
+        assert (
+            model.weight.get_high_precision_init_val() is None
+        ), "clear_high_precision_init_val() did not work"
+        assert not hasattr(
+            model.weight, "_high_precision_init_val"
+        ), "clear_high_precision_init_val() did not work"
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index f54cda6429..bb799ef8e5 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -485,9 +485,9 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 @contextmanager
 def fp8_model_init(
-        enabled: bool = True,
-        preserve_high_precision_init_val: bool = False,
-    ) -> None:
+    enabled: bool = True,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index ceb4a286e4..3aebc1729b 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -883,6 +883,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # a parameter so we always re-apply it just for extra safety.
             param = torch.nn.Parameter(param)
             if high_precision_init_val is not None:
+
                 def get(self):
                     if hasattr(self, "_high_precision_init_val"):
                         return self._high_precision_init_val
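
Usage sketch (illustrative only, not part of the patches above): the snippet below shows one way the preserved high-precision copies could feed FP32 master weights, which is the stated motivation for `preserve_high_precision_init_val`. It assumes an FP8-capable GPU; the layer size and the `master_weights` dict are made up for illustration, and only `fp8_model_init`, `get_high_precision_init_val()`, and `clear_high_precision_init_val()` come from this change.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.fp8 import fp8_model_init

    # Create the module with FP8-only parameters while keeping the
    # high-precision initialization values on the CPU.
    with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
        model = te.Linear(768, 768)

    # Build FP32 master weights from the preserved copies instead of
    # upcasting the already-quantized FP8 parameters.
    master_weights = {}
    for name, param in model.named_parameters():
        hp_init = None
        if hasattr(param, "get_high_precision_init_val"):
            hp_init = param.get_high_precision_init_val()
        if hp_init is not None:
            master_weights[name] = hp_init.to(device=param.device, dtype=torch.float32)
            param.clear_high_precision_init_val()  # release the CPU copy once consumed
        else:
            # Parameters not stored in FP8 (e.g. the bias) have no preserved copy.
            master_weights[name] = param.detach().clone().float()

    # `master_weights` can now be handed to an optimizer that keeps FP32 master copies.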