Add high_precision_init_val to model params when using fp8_model_init #1121

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
pytest -v -s $TE_PATH/tests/pytorch/test_fp8_model_init.py
77 changes: 77 additions & 0 deletions tests/pytorch/test_fp8_model_init.py
@@ -0,0 +1,77 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_model_init

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8ModelInit:

@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

def test_default(self) -> None:
"""Test default parameters of fp8_model_init"""
with fp8_model_init():
model = te.Linear(768, 768)

assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
assert not hasattr(
model.weight, "._high_precision_init_val"
), "_high_precision_init_val should not exist"
assert not hasattr(
model.weight, "get_high_precision_init_val"
), "get_high_precision_init_val() should not exist"
assert not hasattr(
model.weight, "clear_high_precision_init_val"
), "clear_high_precision_init_val() should not exist"

def test_preserve_high_precision_init_val(self) -> None:
"""Test fp8_model_init with preserve_high_precision_init_val=True"""
with fp8_model_init(preserve_high_precision_init_val=True):
model = te.Linear(768, 768)

assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
assert hasattr(
model.weight, "_high_precision_init_val"
), "_high_precision_init_val not found"
assert hasattr(
model.weight, "get_high_precision_init_val"
), "get_high_precision_init_val() not found"
assert hasattr(
model.weight, "clear_high_precision_init_val"
), "clear_high_precision_init_val() not found"

high_precision = model.weight.get_high_precision_init_val()
assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

new_fp8 = Float8Tensor.to_float8(
high_precision.to(model.weight.device),
fp8_meta=model.weight._fp8_meta,
fp8_meta_index=model.weight._fp8_meta_index,
amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history.
)
assert torch.all(
new_fp8._data == model.weight._data
), "high_precision_init_val and model.weight are not equal"

model.weight.clear_high_precision_init_val()
assert (
model.weight.get_high_precision_init_val() is None
), "clear_high_precision_init_val() not work"
assert not hasattr(
model.weight, "_high_precision_init_val"
), "clear_high_precision_init_val() did not remove the attribute"
29 changes: 28 additions & 1 deletion transformer_engine/pytorch/fp8.py
@@ -66,6 +66,7 @@ class FP8GlobalStateManager:
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
FP8_PARAMETERS = False
HIGH_PRECISION_INIT_VAL = False
IS_FIRST_FP8_MODULE = False
FP8_GRAPH_CAPTURING = False
FP8_AUTOCAST_DEPTH = 0
@@ -89,6 +90,7 @@ def reset(cls) -> None:
cls.FP8_RECIPE = None
cls.FP8_DISTRIBUTED_GROUP = None
cls.FP8_PARAMETERS = False
cls.HIGH_PRECISION_INIT_VAL = False
cls.IS_FIRST_FP8_MODULE = False
cls.FP8_GRAPH_CAPTURING = False
cls.FP8_AUTOCAST_DEPTH = 0
@@ -251,6 +253,11 @@ def with_fp8_parameters(cls) -> bool:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS

@classmethod
def with_high_precision_init_val(cls) -> bool:
"""Should the high precision initial values be stored with FP8 parameters"""
return cls.HIGH_PRECISION_INIT_VAL

@classmethod
def fp8_graph_capturing(cls) -> bool:
"""Is CUDA graph capture under way?"""
@@ -477,7 +484,10 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:


@contextmanager
def fp8_model_init(enabled: bool = True) -> None:
def fp8_model_init(
enabled: bool = True,
preserve_high_precision_init_val: bool = False,
) -> None:
"""
Context manager for FP8 initialization of parameters.

@@ -488,6 +498,12 @@ def fp8_model_init(enabled: bool = True) -> None:
with fp8_model_init(enabled=True):
model = transformer_engine.pytorch.Linear(768, 768)

# Preserve the high-precision initial values to initialize master weights
with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
model = transformer_engine.pytorch.Linear(768, 768)
master_weight = model.weight.get_high_precision_init_val()
model.weight.clear_high_precision_init_val()

Parameters
----------
enabled: bool, default = `True`
@@ -501,15 +517,26 @@
precision copies of weights are already present in the optimizer.
* inference, where only the FP8 copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
preserve_high_precision_init_val: bool, default = `False`
when enabled, keep the high-precision tensor used to initialize each FP8 parameter
in CPU memory, and attach two methods, `get_high_precision_init_val()` and
`clear_high_precision_init_val()`, to the FP8 parameters for retrieving and
releasing this tensor. This lets users initialize master weights from the
high-precision copy, avoiding the loss of precision that can occur when casting
from the FP8 parameters directly. After the master weights have been initialized,
users should call `clear_high_precision_init_val()` to free this CPU memory.

This functionality is *EXPERIMENTAL*.
"""
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
FP8GlobalStateManager.FP8_PARAMETERS = enabled
_high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
try:
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val


@contextmanager
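
As a sanity check on the save/restore logic in `fp8_model_init` above, here is a minimal sketch using only the `FP8GlobalStateManager` accessors defined in this file: the new flag is visible inside the context and restored afterwards.

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_model_init

assert not FP8GlobalStateManager.with_fp8_parameters()
assert not FP8GlobalStateManager.with_high_precision_init_val()

with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    # Modules constructed here see both flags and store FP8 weights plus the CPU copy.
    assert FP8GlobalStateManager.with_fp8_parameters()
    assert FP8GlobalStateManager.with_high_precision_init_val()

# The try/finally restores the previous values on exit, even if an exception
# was raised while constructing modules inside the context.
assert not FP8GlobalStateManager.with_fp8_parameters()
assert not FP8GlobalStateManager.with_high_precision_init_val()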
24 changes: 23 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -13,6 +13,7 @@
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
from types import MethodType

import torch
import torch.nn.functional as F
@@ -389,6 +390,7 @@ def __init__(self) -> None:
self.sequence_parallel = False
self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
self.fsdp_wrapped = False
self.fsdp_group = None
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
@@ -864,7 +866,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:

# If primary weights are in fp8, wrap the parameter as Float8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None
if self.primary_weights_in_fp8 and fp8_meta_index is not None:
if self.preserve_high_precision_init_val:
high_precision_init_val = param.detach().cpu()
param = Float8Tensor.to_float8(
param,
fp8_meta=self.fp8_meta,
@@ -876,7 +881,24 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
# NOTE: Currently this can only be broken when primary weights are in Fp8 but
# re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety.
setattr(self, name, torch.nn.Parameter(param))
param = torch.nn.Parameter(param)
if high_precision_init_val is not None:

def get(self):
if hasattr(self, "_high_precision_init_val"):
return self._high_precision_init_val
else:
return None

def clear(self):
if hasattr(self, "_high_precision_init_val"):
del self._high_precision_init_val

param._high_precision_init_val = high_precision_init_val
param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param)

setattr(self, name, param)

@abstractmethod
def forward(self):
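
The get/clear helpers above are attached with `types.MethodType` so they travel with each parameter instance rather than living on a class. A self-contained sketch of the same binding pattern, applied to a plain `torch.nn.Parameter` outside Transformer Engine:

from types import MethodType

import torch

param = torch.nn.Parameter(torch.empty(4, 4))
param._high_precision_init_val = torch.randn(4, 4)  # CPU copy kept alongside the param

def get(self):
    if hasattr(self, "_high_precision_init_val"):
        return self._high_precision_init_val
    return None

def clear(self):
    if hasattr(self, "_high_precision_init_val"):
        del self._high_precision_init_val

# MethodType binds the plain functions to this specific tensor instance, so
# calling them as attributes passes the parameter itself as `self`.
param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param)

assert param.get_high_precision_init_val() is not None
param.clear_high_precision_init_val()
assert param.get_high_precision_init_val() is None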