1414
1515import torch
1616
17- from float8_experimental .config import Float8LinearConfig , TensorScalingType
17+ from float8_experimental .config import Float8LinearConfig , ScalingType
1818
1919from float8_experimental .float8_dynamic_utils import (
2020 cast_to_float8_e4m3_dynamic ,
@@ -159,9 +159,9 @@ def __init__(self, *args, **kwargs):
159159 self .scaling_type_grad_output = config .cast_config_grad_output .scaling_type
160160 # Convenience flag to skip code related to delayed scaling
161161 self .has_any_delayed_scaling = (
162- self .scaling_type_input is TensorScalingType .DELAYED
163- or self .scaling_type_weight is TensorScalingType .DELAYED
164- or self .scaling_type_grad_output is TensorScalingType .DELAYED
162+ self .scaling_type_input is ScalingType .DELAYED
163+ or self .scaling_type_weight is ScalingType .DELAYED
164+ or self .scaling_type_grad_output is ScalingType .DELAYED
165165 )
166166
167167 self .config = config
@@ -284,7 +284,7 @@ def cast_input_to_float8(
284284 autocast_dtype = torch .get_autocast_gpu_dtype ()
285285 input = input .to (autocast_dtype )
286286
287- if self .scaling_type_input is TensorScalingType .DELAYED :
287+ if self .scaling_type_input is ScalingType .DELAYED :
288288 scale_fn_name = self .config .delayed_scaling_config .scale_fn_name
289289 _maybe_initialize_amaxes_scales_for_float8_cast (
290290 input ,
@@ -305,14 +305,14 @@ def cast_input_to_float8(
305305 gemm_input_role = GemmInputRole .INPUT ,
306306 )
307307 else :
308- assert self .scaling_type_input is TensorScalingType .DYNAMIC
308+ assert self .scaling_type_input is ScalingType .DYNAMIC
309309 input_fp8 = cast_to_float8_e4m3_dynamic (input , self .linear_mm_config )
310310 return input_fp8
311311
312312 def cast_weight_to_float8 (
313313 self , weight : torch .Tensor , is_amax_initialized : bool
314314 ) -> torch .Tensor :
315- if self .scaling_type_weight is TensorScalingType .DELAYED :
315+ if self .scaling_type_weight is ScalingType .DELAYED :
316316 if isinstance (self .weight , Float8Tensor ): # cast by FSDP
317317 weight_fp8 = self .weight
318318 else :
@@ -337,7 +337,7 @@ def cast_weight_to_float8(
337337 gemm_input_role = GemmInputRole .WEIGHT ,
338338 )
339339 else :
340- assert self .scaling_type_weight is TensorScalingType .DYNAMIC
340+ assert self .scaling_type_weight is ScalingType .DYNAMIC
341341 if isinstance (self .weight , Float8Tensor ): # cast by FSDP
342342 weight_fp8 = self .weight
343343 else :
@@ -349,7 +349,7 @@ def cast_weight_to_float8(
349349 return weight_fp8
350350
351351 def cast_output_to_float8_in_bw (self , output : torch .Tensor ) -> torch .Tensor :
352- if self .scaling_type_grad_output is TensorScalingType .DELAYED :
352+ if self .scaling_type_grad_output is ScalingType .DELAYED :
353353 scale_fn_name = self .config .delayed_scaling_config .scale_fn_name
354354 output = NoopFwToFloat8E5M2Bw .apply (
355355 output ,
@@ -361,7 +361,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
361361 self .linear_mm_config ,
362362 )
363363 else :
364- assert self .scaling_type_grad_output is TensorScalingType .DYNAMIC
364+ assert self .scaling_type_grad_output is ScalingType .DYNAMIC
365365 output = cast_to_float8_e5m2_dynamic_bw (output , self .linear_mm_config )
366366 return output
367367
@@ -448,17 +448,15 @@ def from_float(
448448 # 2. buffers need to be already created for the delayed scaling version
449449 # of the weight wrapper to be initialized
450450 if config .enable_fsdp_float8_all_gather :
451- if config .cast_config_weight .scaling_type is TensorScalingType .DYNAMIC :
451+ if config .cast_config_weight .scaling_type is ScalingType .DYNAMIC :
452452 new_mod .weight = torch .nn .Parameter (
453453 WeightWithDynamicFloat8CastTensor (
454454 new_mod .weight ,
455455 new_mod .linear_mm_config ,
456456 )
457457 )
458458 else :
459- assert (
460- config .cast_config_weight .scaling_type is TensorScalingType .DELAYED
461- )
459+ assert config .cast_config_weight .scaling_type is ScalingType .DELAYED
462460 new_mod .weight = torch .nn .Parameter (
463461 WeightWithDelayedFloat8CastTensor (
464462 new_mod .weight ,
0 commit comments