Commit f0e5850
Add high_precision_init_val
kunlunl committed Aug 22, 2024
1 parent 3040785 commit f0e5850
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -864,7 +864,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
 
             # If primary weights are in fp8, wrap the parameter as Float8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
+            high_precision_init_val = None
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                high_precision_init_val = param.detach().cpu()
                 param = Float8Tensor.to_float8(
                     param,
                     fp8_meta=self.fp8_meta,
@@ -876,7 +878,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
-            setattr(self, name, torch.nn.Parameter(param))
+            param = torch.nn.Parameter(param)
+            if high_precision_init_val is not None:
+                param._high_precision_init_val = high_precision_init_val
+            setattr(self, name, param)
 
     @abstractmethod
     def forward(self):
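With this change, a module whose primary weights are kept in FP8 stashes a CPU copy of each weight's original high-precision initialization on the parameter itself, as param._high_precision_init_val, before the cast to Float8Tensor. A minimal sketch of how a caller might read that value back after construction; the fp8_model_init/Linear setup below is assumed usage and not part of this commit:

    # Sketch only (not part of this commit): read back the stashed
    # high-precision init values from a module whose primary weights
    # are kept in FP8.
    import transformer_engine.pytorch as te

    with te.fp8_model_init(enabled=True):  # primary weights stored as Float8Tensor
        linear = te.Linear(1024, 1024)

    for name, param in linear.named_parameters():
        # Only FP8 weight parameters carry the attribute added by this commit;
        # guard with getattr so non-FP8 params (e.g. bias) are skipped.
        init_val = getattr(param, "_high_precision_init_val", None)
        if init_val is not None:
            # CPU copy of the weight exactly as it was initialized,
            # taken in reset_parameters() before the FP8 cast.
            print(name, init_val.shape, init_val.device)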
