@@ -131,23 +131,6 @@ def backward(ctx, go):
         return res, *empty_grads


-@dataclasses.dataclass
-class DelayedScalingRecipe:
-    # Controls the history length of amax buffers
-    history_len: int
-
-    # Controls the way to calculate current scale from amax history
-    # TODO(future): add other functions as needed, hardcoded or user defined
-    scale_fn_name: str
-
-    def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
-        self.history_len = history_len
-        self.scale_fn_name = scale_fn_name
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
-
-
 class Float8Linear(torch.nn.Linear):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -161,13 +144,9 @@ class Float8Linear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
-        * `delayed_scaling_recipe`: configuration for delayed scaling
         * `config`: Float8LinearConfig
         """

-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         config = kwargs.pop("config")
@@ -187,11 +166,6 @@ def __init__(self, *args, **kwargs):

         self.config = config

-        # TODO(future): have a unique recipe per buffer instead of one per
-        # module, saving implementing that until we need it.
-        # TODO(future): serialization for recipes
-        self.recipe = delayed_scaling_recipe
-
         self.create_buffers()

         # TODO(future): user level configuration of gemms
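
For callers, the hunks above mean the per-module `delayed_scaling_recipe` kwarg (default `DelayedScalingRecipe(history_len=16, scale_fn_name="max")`) is gone, and the same two knobs are expected to travel on the `config` object, since the module now reads `self.config.delayed_scaling_config.history_len` and `.scale_fn_name`. A minimal construction sketch, assuming a `DelayedScalingConfig` dataclass with those fields hangs off `Float8LinearConfig` (the import path and constructor names below are assumptions, not shown in this diff):

# Sketch only: `DelayedScalingConfig` / `Float8LinearConfig` names and the
# import path are assumed from the `config.delayed_scaling_config.*` accesses
# in this diff, not from code shown here.
from float8_experimental.config import DelayedScalingConfig, Float8LinearConfig

config = Float8LinearConfig(
    delayed_scaling_config=DelayedScalingConfig(
        history_len=16,       # length of the amax history buffers
        scale_fn_name="max",  # the only value the removed recipe's assert allowed
    ),
)

# Before this change: Float8Linear(..., delayed_scaling_recipe=DelayedScalingRecipe(...))
# After: the knobs ride on the single `config` kwarg popped in __init__ above.
m = Float8Linear(32, 64, bias=False, config=config)
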
@@ -237,7 +211,7 @@ def __init__(self, *args, **kwargs):

     def create_buffers(self):
         # Default values for history buffers, see above TODO
-        history_len = self.recipe.history_len
+        history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
         # TODO(future PR): dtype values below don't have the other float8
         # flavors, fix it
@@ -307,7 +281,7 @@ def cast_x_to_float8(
             x = x.to(autocast_dtype)

         if self.scaling_type_input is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 x,
                 self.fp8_amax_input,
@@ -338,7 +312,7 @@ def cast_w_to_float8(
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                scale_fn_name = self.recipe.scale_fn_name
+                scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
                 _maybe_initialize_amaxes_scales_for_float8_cast(
                     w,
                     self.fp8_amax_weight,
@@ -370,7 +344,7 @@ def cast_w_to_float8(

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             y = NoopFwToFloat8E5M2Bw.apply(
                 y,
                 self.fp8_amax_grad_output,
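
All three cast paths (input, weight, grad_output) now read `scale_fn_name` from `config.delayed_scaling_config`. As a reminder of what the "max" policy means for delayed scaling, here is a rough sketch; the real computation lives inside `_maybe_initialize_amaxes_scales_for_float8_cast` and its helpers, so the function below is illustrative only, not the library's API:

import torch

def amax_history_to_scale_sketch(
    amax_history: torch.Tensor, float8_dtype: torch.dtype
) -> torch.Tensor:
    # "max" scale function: take the largest amax observed over the last
    # `history_len` steps and map it onto the representable range of the
    # target float8 dtype (e4m3 for input/weight, e5m2 for grad_output,
    # per NoopFwToFloat8E5M2Bw above).
    amax = torch.clamp(amax_history.max(), min=1e-12)  # avoid division by zero
    return torch.finfo(float8_dtype).max / amax.to(torch.float32)

# e.g. scale = amax_history_to_scale_sketch(torch.rand(16), torch.float8_e4m3fn)
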