diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index caeb31c..2ebdeaf 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -52,6 +52,13 @@ def backward(ctx, gradY):
         )
         return fp8_tensor, None
 
+def forward_pre_hook(mod, x):
+    x = cast_to_float8_e4m3fn(x[0], mod.forward_config)
+    return x
+
+def forward_post_hook(mod, x, y):
+    y = cast_to_float8_e5m2_bw(y, mod.backward_config)
+    return y
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -62,14 +69,14 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
-    def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+    def forward(self, x_fp8):
+        # x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
             w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_bw(y, self.backward_config)
+        # y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
     @classmethod
@@ -97,6 +104,8 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         else:
             new_mod.weight = mod.weight
         new_mod.bias = mod.bias
+        new_mod.register_forward_pre_hook(forward_pre_hook)
+        new_mod.register_forward_hook(forward_post_hook)
         return new_mod
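
The patch moves the activation casts out of Float8DynamicLinear.forward and into module hooks: the forward pre-hook casts the input to e4m3fn before forward() runs, and the forward (post) hook applies the e5m2 backward cast to the output. The sketch below is not part of the patch; it only illustrates the standard PyTorch hook mechanics the diff relies on. The bfloat16 casts stand in for cast_to_float8_e4m3fn / cast_to_float8_e5m2_bw, and the hook names are illustrative.

# Minimal, hypothetical sketch of the hook mechanics used in the diff.
# bfloat16 casts stand in for the real float8 cast helpers.
import torch

def pre_hook(mod, args):
    # The return value replaces forward()'s positional args; a non-tuple
    # return is wrapped into a one-element tuple by PyTorch.
    return args[0].to(torch.bfloat16)

def post_hook(mod, args, out):
    # A non-None return value replaces forward()'s output.
    return out.to(torch.float32)

lin = torch.nn.Linear(4, 4, dtype=torch.bfloat16)
lin.register_forward_pre_hook(pre_hook)
lin.register_forward_hook(post_hook)

y = lin(torch.randn(2, 4))  # fp32 in -> bf16 matmul -> fp32 out
print(y.dtype)              # torch.float32

Because the pre-hook's return value replaces the positional args and the post-hook's return value replaces the output, from_float only needs to register the two hooks; callers of the module do not change.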