From f0e585025e8bfebe14b469ba2277da12d133f125 Mon Sep 17 00:00:00 2001
From: kunlunl
Date: Thu, 22 Aug 2024 03:28:14 -0700
Subject: [PATCH] Add high_precision_init_val

---
 transformer_engine/pytorch/module/base.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 3613e1fa5e..9422d8f161 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -864,7 +864,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
 
             # If primary weights are in fp8, wrap the parameter as Float8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
+            high_precision_init_val = None
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                high_precision_init_val = param.detach().cpu()
                 param = Float8Tensor.to_float8(
                     param,
                     fp8_meta=self.fp8_meta,
@@ -876,7 +878,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
-            setattr(self, name, torch.nn.Parameter(param))
+            param = torch.nn.Parameter(param)
+            if high_precision_init_val is not None:
+                param._high_precision_init_val = high_precision_init_val
+            setattr(self, name, param)
 
     @abstractmethod
     def forward(self):
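
The patch stashes a CPU copy of the freshly initialized high-precision weights on the parameter (as `_high_precision_init_val`) before the values are cast to FP8, presumably so downstream code can recover the true initialization rather than a dequantized FP8 approximation. The snippet below is a minimal, self-contained sketch of that pattern; it does not use the Transformer Engine API, and `fake_fp8_cast` and `wrap_param` are hypothetical stand-ins for `Float8Tensor.to_float8` and the `reset_parameters` logic in the patch:

import torch

def fake_fp8_cast(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real FP8 cast: round-trip through float16 so the
    # stored values differ slightly from the stashed high-precision copy.
    return t.to(torch.float16).to(torch.float32)

def wrap_param(param: torch.Tensor, primary_weights_in_fp8: bool) -> torch.nn.Parameter:
    high_precision_init_val = None
    if primary_weights_in_fp8:
        # Keep a CPU copy of the high-precision initial values before casting.
        high_precision_init_val = param.detach().cpu()
        param = fake_fp8_cast(param)

    param = torch.nn.Parameter(param)
    if high_precision_init_val is not None:
        # Attach the copy to the parameter, mirroring the patch's
        # `_high_precision_init_val` attribute.
        param._high_precision_init_val = high_precision_init_val
    return param

p = wrap_param(torch.randn(4, 4), primary_weights_in_fp8=True)
print(hasattr(p, "_high_precision_init_val"))  # True
print(p._high_precision_init_val.dtype)        # torch.float32

Keeping the copy on the CPU avoids holding a second full-precision replica of the weights on the GPU; a likely use case is letting an optimizer seed its FP32 master weights from the original initialization, though the patch itself does not show a consumer of the attribute.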