@@ -169,25 +169,25 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         return

     # Loop over all fp8 layers and grab the needed tensors
-    fp8_amax_x_tensor_list = []
-    fp8_amax_w_tensor_list = []
-    fp8_amax_dL_dY_tensor_list = []
+    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)

-    fp8_x_amax_history_stack = []
-    fp8_w_amax_history_stack = []
-    fp8_dL_dY_amax_history_stack = []
+    fp8_x_amax_history_stack = [None] * len(fp8_layers)
+    fp8_w_amax_history_stack = [None] * len(fp8_layers)
+    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)

     x_dtypes = set()
     scale_fn_recipes = set()

-    for child in fp8_layers:
-        fp8_amax_x_tensor_list.append(child.fp8_amax_x)
-        fp8_amax_w_tensor_list.append(child.fp8_amax_w)
-        fp8_amax_dL_dY_tensor_list.append(child.fp8_amax_dL_dY)
+    for idx, child in enumerate(fp8_layers):
+        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY

-        fp8_x_amax_history_stack.append(child.fp8_amax_history_x)
-        fp8_w_amax_history_stack.append(child.fp8_amax_history_w)
-        fp8_dL_dY_amax_history_stack.append(child.fp8_amax_history_dL_dY)
+        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY

         x_dtypes.add(child.last_seen_input_dtype)
         scale_fn_recipes.add(child.recipe.scale_fn_name)
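Note: the change above replaces repeated list appends with lists sized once up front and filled by index, presumably to keep the per-layer loop body allocation-free. A minimal standalone sketch of the same preallocate-then-index pattern follows; the Layer class, its fp8_amax_x attribute values, and the layer count are hypothetical stand-ins for the real Float8Linear modules, and only the list-filling pattern mirrors the diff.

    # Sketch of the preallocate-then-index pattern from the diff above.
    # Layer and its tensor are made up for illustration only.
    import torch

    class Layer:
        def __init__(self):
            self.fp8_amax_x = torch.rand(1)

    layers = [Layer() for _ in range(4)]

    # Before: grow the list one element at a time.
    amax_list = []
    for child in layers:
        amax_list.append(child.fp8_amax_x)

    # After: size the list once, then write each slot by index.
    amax_list = [None] * len(layers)
    for idx, child in enumerate(layers):
        amax_list[idx] = child.fp8_amax_x

    assert all(t is not None for t in amax_list)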