Add preserve_high_precision_init_val to fp8_model_init
Signed-off-by: kunlunl <[email protected]>
kunlunl committed Aug 28, 2024
1 parent 3040785 commit db3e139
Showing 4 changed files with 119 additions and 2 deletions.
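For context, the flag added here is used roughly as in the docstring example included in this commit; a minimal sketch (the 768x768 te.Linear is only an illustrative module):

import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import fp8_model_init

# Parameters are created directly in FP8, and the high-precision tensors used
# to initialize them are additionally kept in CPU memory.
with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = te.Linear(768, 768)

# Retrieve the preserved CPU copy (e.g. to seed FP32 master weights) ...
master_weight = model.weight.get_high_precision_init_val()

# ... and release the CPU memory once it is no longer needed.
model.weight.clear_high_precision_init_val()

The CPU copy exists only until clear_high_precision_init_val() is called, so the extra memory cost is transient.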
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
pytest -v -s $TE_PATH/tests/pytorch/test_fp8_model_init.py
68 changes: 68 additions & 0 deletions tests/pytorch/test_fp8_model_init.py
@@ -0,0 +1,68 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_model_init

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8ModelInit:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_default(self) -> None:
        """Test default parameters of fp8_model_init"""
        with fp8_model_init():
            model = te.Linear(768, 768)

        assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
        assert not hasattr(model.weight, "_high_precision_init_val"), \
            "_high_precision_init_val should not exist"
        assert not hasattr(model.weight, "get_high_precision_init_val"), \
            "get_high_precision_init_val() should not exist"
        assert not hasattr(model.weight, "clear_high_precision_init_val"), \
            "clear_high_precision_init_val() should not exist"

    def test_preserve_high_precision_init_val(self) -> None:
        """Test fp8_model_init with preserve_high_precision_init_val=True"""
        with fp8_model_init(preserve_high_precision_init_val=True):
            model = te.Linear(768, 768)

        assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
        assert hasattr(model.weight, "_high_precision_init_val"), \
            "_high_precision_init_val not found"
        assert hasattr(model.weight, "get_high_precision_init_val"), \
            "get_high_precision_init_val() not found"
        assert hasattr(model.weight, "clear_high_precision_init_val"), \
            "clear_high_precision_init_val() not found"

        high_precision = model.weight.get_high_precision_init_val()
        assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

        new_fp8 = Float8Tensor.to_float8(
            high_precision.to(model.weight.device),
            fp8_meta=model.weight._fp8_meta,
            fp8_meta_index=model.weight._fp8_meta_index,
            amax=torch.empty(1, device="cuda"),  # Dummy amax to avoid overwriting history.
        )
        assert torch.all(new_fp8._data == model.weight._data), \
            "high_precision_init_val and model.weight are not equal"

        model.weight.clear_high_precision_init_val()
        assert model.weight.get_high_precision_init_val() is None, \
            "clear_high_precision_init_val() did not work"
        assert not hasattr(model.weight, "_high_precision_init_val"), \
            "clear_high_precision_init_val() did not work"
29 changes: 28 additions & 1 deletion transformer_engine/pytorch/fp8.py
@@ -66,6 +66,7 @@ class FP8GlobalStateManager:
    FP8_RECIPE = None
    FP8_DISTRIBUTED_GROUP = None
    FP8_PARAMETERS = False
    HIGH_PRECISION_INIT_VAL = False
    IS_FIRST_FP8_MODULE = False
    FP8_GRAPH_CAPTURING = False
    FP8_AUTOCAST_DEPTH = 0
@@ -89,6 +90,7 @@ def reset(cls) -> None:
        cls.FP8_RECIPE = None
        cls.FP8_DISTRIBUTED_GROUP = None
        cls.FP8_PARAMETERS = False
        cls.HIGH_PRECISION_INIT_VAL = False
        cls.IS_FIRST_FP8_MODULE = False
        cls.FP8_GRAPH_CAPTURING = False
        cls.FP8_AUTOCAST_DEPTH = 0
@@ -251,6 +253,11 @@ def with_fp8_parameters(cls) -> bool:
        """Should the parameters be stored as FP8"""
        return cls.FP8_PARAMETERS

    @classmethod
    def with_high_precision_init_val(cls) -> bool:
        """Should the high precision initial values be stored with FP8 parameters"""
        return cls.HIGH_PRECISION_INIT_VAL

    @classmethod
    def fp8_graph_capturing(cls) -> bool:
        """Is CUDA graph capture under way?"""
@@ -477,7 +484,10 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:


@contextmanager
def fp8_model_init(enabled: bool = True) -> None:
def fp8_model_init(
    enabled: bool = True,
    preserve_high_precision_init_val: bool = False,
) -> None:
    """
    Context manager for FP8 initialization of parameters.
@@ -488,6 +498,12 @@ def fp8_model_init(enabled: bool = True) -> None:
        with fp8_model_init(enabled=True):
            model = transformer_engine.pytorch.Linear(768, 768)

        # Preserve the high-precision initial values to initialize master weights
        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
            model = transformer_engine.pytorch.Linear(768, 768)
            master_weight = model.weight.get_high_precision_init_val()
            model.weight.clear_high_precision_init_val()

    Parameters
    ----------
    enabled: bool, default = `True`
@@ -501,15 +517,26 @@
          precision copies of weights are already present in the optimizer.
        * inference, where only the FP8 copies of the parameters are used.
        * LoRA-like fine-tuning, where the main parameters of the model do not change.
    preserve_high_precision_init_val: bool, default = `False`
        when enabled, keep the high-precision tensor used to initialize each FP8
        parameter in CPU memory, and attach two methods, `get_high_precision_init_val()`
        and `clear_high_precision_init_val()`, to the FP8 parameters for retrieving
        and releasing that tensor. This lets users initialize master weights from the
        high-precision copy rather than from the FP8 parameters, avoiding the precision
        loss of an FP8 round trip. Once the master weights have been initialized, users
        should call `clear_high_precision_init_val()` to free the CPU memory.
        This functionality is *EXPERIMENTAL*.
    """
    _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
    FP8GlobalStateManager.FP8_PARAMETERS = enabled
    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
    try:
        yield
    finally:
        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val


@contextmanager
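The docstring above motivates the flag with master-weight initialization; the following is a minimal, hypothetical sketch of that use (the master_weights dict and the parameter loop are illustrative, not part of this commit):

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import fp8_model_init

with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = te.Linear(768, 768)

# Seed FP32 master weights from the preserved CPU copies, then free them.
master_weights = {}
for name, param in model.named_parameters():
    getter = getattr(param, "get_high_precision_init_val", None)
    if getter is None:
        continue  # parameter was not created as FP8, so nothing was preserved
    init_val = getter()
    if init_val is not None:
        # Upcast the CPU copy and move it next to the FP8 parameter.
        master_weights[name] = init_val.to(device=param.device, dtype=torch.float32)
        param.clear_high_precision_init_val()  # release the CPU memory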
23 changes: 22 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -13,6 +13,7 @@
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
from types import MethodType

import torch
import torch.nn.functional as F
@@ -389,6 +390,7 @@ def __init__(self) -> None:
        self.sequence_parallel = False
        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
        self.fsdp_wrapped = False
        self.fsdp_group = None
        self._fp8_workspaces: Dict[str, Float8Tensor] = {}
@@ -864,7 +866,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:

            # If primary weights are in fp8, wrap the parameter as Float8Tensor
            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()
                param = Float8Tensor.to_float8(
                    param,
                    fp8_meta=self.fp8_meta,
@@ -876,7 +881,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            # re-applying the nn.Parameter() wrap is a no-op when the input is already
            # a parameter so we always re-apply it just for extra safety.
            setattr(self, name, torch.nn.Parameter(param))
            param = torch.nn.Parameter(param)
            if high_precision_init_val is not None:
                def get(self):
                    if hasattr(self, "_high_precision_init_val"):
                        return self._high_precision_init_val
                    else:
                        return None

                def clear(self):
                    if hasattr(self, "_high_precision_init_val"):
                        del self._high_precision_init_val

                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)

            setattr(self, name, param)

    @abstractmethod
    def forward(self):
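For reference, the accessors added above are bound to each parameter instance with types.MethodType; a small standalone sketch of that binding pattern (toy tensors, not TE code):

from types import MethodType

import torch

param = torch.nn.Parameter(torch.zeros(4))
param._high_precision_init_val = torch.ones(4)  # stands in for the preserved CPU copy

def get(self):
    # Return the preserved copy if it is still attached, otherwise None.
    return getattr(self, "_high_precision_init_val", None)

def clear(self):
    # Drop the preserved copy so its memory can be reclaimed.
    if hasattr(self, "_high_precision_init_val"):
        del self._high_precision_init_val

# Bind the functions to this specific parameter instance.
param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param)

assert torch.equal(param.get_high_precision_init_val(), torch.ones(4))
param.clear_high_precision_init_val()
assert param.get_high_precision_init_val() is None

Binding with MethodType keeps the helpers scoped to parameters that actually carry a preserved copy, so other parameters are unaffected.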
