diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index 9fe9d17..bfacd65 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -42,7 +42,7 @@ def backward(ctx, gradY):
             gradY_scale,
             e5m2_dtype,
             linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.DL_DY,
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
         )
         return fp8_tensor, None
 
@@ -51,7 +51,7 @@ def cast_to_float8_e4m3_dynamic(
     inpt_tensor: torch.Tensor,
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
-    gemm_input_role: GemmInputRole = GemmInputRole.X,
+    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index c598a93..42eeb86 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -125,7 +125,7 @@ def backward(ctx, go):
             fp8_scale_grad_output,
             e5m2_dtype,
             linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.DL_DY,
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -273,8 +273,8 @@ def convert_amax_buffer_to_float32(self):
             if self._buffers[key] is not None:
                 self._buffers[key] = self._buffers[key].to(torch.float32)
 
-    def cast_x_to_float8(
-        self, x: torch.Tensor, is_amax_initialized: bool
+    def cast_input_to_float8(
+        self, input: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
@@ -282,12 +282,12 @@ def cast_x_to_float8(
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
             autocast_dtype = torch.get_autocast_gpu_dtype()
-            x = x.to(autocast_dtype)
+            input = input.to(autocast_dtype)
 
         if self.scaling_type_input is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
-                x,
+                input,
                 self.fp8_amax_input,
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
@@ -296,29 +296,29 @@ def cast_x_to_float8(
                 is_amax_initialized,
                 reduce_amax=True,
             )
-            x_fp8 = Float8Tensor.to_float8(
-                x,
+            input_fp8 = Float8Tensor.to_float8(
+                input,
                 self.fp8_scale_input,
                 e4m3_dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
+                gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
-        return x_fp8
+            input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
+        return input_fp8
 
-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
+    def cast_weight_to_float8(
+        self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         if self.scaling_type_weight is TensorScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
                 scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
                 _maybe_initialize_amaxes_scales_for_float8_cast(
-                    w,
+                    weight,
                     self.fp8_amax_weight,
                     self.fp8_amax_history_weight,
                     self.fp8_scale_weight,
@@ -328,29 +328,31 @@ def cast_w_to_float8(
                     reduce_amax=False,
                 )
-                w_fp8 = Float8Tensor.to_float8(
-                    w,
+                weight_fp8 = Float8Tensor.to_float8(
+                    weight,
                     self.fp8_scale_weight,
                     e4m3_dtype,
                     self.fp8_amax_weight,
                     linear_mm_config=self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.W,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
         else:
             assert self.scaling_type_weight is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(
-                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                weight_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight,
+                    self.linear_mm_config,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
-        return w_fp8
+        return weight_fp8
 
-    def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
+    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            y = NoopFwToFloat8E5M2Bw.apply(
-                y,
+            output = NoopFwToFloat8E5M2Bw.apply(
+                output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
@@ -360,10 +362,10 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
             )
         else:
             assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
-        return y
+            output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
+        return output
 
-    def float8_pre_forward(self, x):
+    def float8_pre_forward(self, input):
         if not self.enable_pre_and_post_forward:
             return
         if (
@@ -374,7 +376,7 @@ def float8_pre_forward(self, x):
             raise AssertionError(
                 "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
             )
-        self.last_seen_input_dtype = x.dtype
+        self.last_seen_input_dtype = input.dtype
 
     def float8_post_forward(self):
         if not self.enable_pre_and_post_forward:
             return
@@ -388,25 +390,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_pre_forward(input)
 
-        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
+        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        output = torch.matmul(input_fp8, weight_fp8.t())
 
-        # Cast gradY to float8_e5m2 during backward
-        y = self.cast_y_to_float8_in_bw(y)
+        # Cast grad_output to float8_e5m2 during backward
+        output = self.cast_output_to_float8_in_bw(output)
 
         if self.bias is not None:
-            y = y + self.bias.to(y.dtype)
+            output = output + self.bias.to(output.dtype)
 
         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        return y
+        return output
 
     def scaling_repr(self):
         # add scaling settings without using too many characters
-        # example: "x:del,w:del,dldy:dyn"
-        return f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}"
+        # example: "i:del,w:del,go:dyn"
+        return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}"
 
     def extra_repr(self):
         s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 3b108be..a46e7ce 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -27,21 +27,21 @@
 #
 # There are three gemms in a forward + backward of a Linear layer:
 #
-# 1. x @ w_t = y (forward pass)
-# 2. dL_dY @ w = dL_dX (backward pass)
-# 3. x_t @ dL_dY = dL_dW (backward pass)
+# 1. input @ weight_t = output (forward pass)
+# 2. grad_output @ weight = grad_input (backward pass)
+# 3. input_t @ grad_output = grad_weight (backward pass)
 #
 # In the formulas above, there are:
-# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t).
-#    - Note that dL_dY_t is implied because of memory format requirements
+# A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t).
+#    - Note that grad_output_t is implied because of memory format requirements
 #      of float8 gemms
-# B. three output tensors (y, dL_dX, dL_dW)
+# B. three output tensors (output, grad_input, grad_weight)
 #
 # We want each input tensor, gemm, and output tensor to be configurable.
 # The state of this configuration today is:
 #
 # i. pairs of input tensors (non-t and t variants) have their scaling
-#    configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear
+#    configurable via the scaling_type_* arguments to Float8Linear
 # ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing
 # iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed
 #      to configure all three gemms, also not user facing
@@ -60,11 +60,12 @@
 
 # The object below is not user facing and exists for convenience,
 # to allow Float8Tensor to use
-# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is
+# the right config based on which gemm from gemms with outputs
+# `output`, `grad_input`, `grad_weight` is
 # being called.
 LinearMMConfig = namedtuple(
     "LinearMMConfig",
-    ["y", "dL_dX", "dL_dW"],
+    ["output", "grad_input", "grad_weight"],
     defaults=[
         ScaledMMConfig(False, True, False, False),
         ScaledMMConfig(False, False, False, False),
@@ -81,9 +82,9 @@ class GemmInputRole(enum.Enum):
     gemm is performed.
     """
 
-    X = "x"
-    W = "w"
-    DL_DY = "dL_dY"
+    INPUT = "input"
+    WEIGHT = "weight"
+    GRAD_OUTPUT = "grad_output"
 
 
 # choose which scaled_mm_config to use based on gemm inputs
@@ -93,21 +94,21 @@ def choose_scaled_mm_config(
     b_role: GemmInputRole,
     b_linear_mm_config: LinearMMConfig,
 ):
-    if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
+    if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT:
         assert (
-            a_linear_mm_config.y == b_linear_mm_config.y
-        ), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}"
-        return a_linear_mm_config.y
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
+            a_linear_mm_config.output == b_linear_mm_config.output
+        ), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}"
+        return a_linear_mm_config.output
+    elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT:
         assert (
-            a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
-        ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
-        return a_linear_mm_config.dL_dX
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
+            a_linear_mm_config.grad_input == b_linear_mm_config.grad_input
+        ), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}"
+        return a_linear_mm_config.grad_input
+    elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT:
         assert (
-            a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
-        ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
-        return a_linear_mm_config.dL_dW
+            a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight
+        ), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}"
+        return a_linear_mm_config.grad_weight
     else:
         raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")
@@ -207,7 +208,7 @@ def forward(
         float8_dtype=e4m3_dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         linear_mm_config: Optional[LinearMMConfig] = None,
-        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
         """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
         Args
@@ -287,7 +288,7 @@ def __new__(
         scale: torch.Tensor,
         orig_dtype: torch.dtype,
         linear_mm_config: Optional[LinearMMConfig],
-        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
         assert (
             scale.numel() == 1
@@ -348,7 +349,7 @@ def to_float8(
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         linear_mm_config: Optional[LinearMMConfig] = None,
-        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
         """Converts a higher precision tensor to float8 in a differentiable way.
 
diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py
index 1841553..99850ad 100644
--- a/float8_experimental/float8_tensor_parallel.py
+++ b/float8_experimental/float8_tensor_parallel.py
@@ -48,7 +48,7 @@ def _prepare_input_fn(
         input_tensor = cast_to_float8_e4m3_dynamic(
             input_tensor,
             mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.X,
+            gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
 
         # transform the input layouts to the desired layouts of ColwiseParallel
@@ -101,7 +101,7 @@ def _prepare_input_fn(
         input_tensor = cast_to_float8_e4m3_dynamic(
             input_tensor,
             mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.X,
+            gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
 
         if input_layouts != desired_input_layouts:
@@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
             dt_inp = cast_to_float8_e4m3_dynamic(
                 dt_inp,
                 self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
+                gemm_input_role=GemmInputRole.INPUT,
             )  # DTensor(Float8Tensor)
             if desired_layout is not None and input_layout != desired_layout:
                 dt_inp = dt_inp.redistribute(placements=(desired_layout,))
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index d9fd200..5de51e3 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -171,14 +171,14 @@ def fsdp_pre_all_gather(self, mesh):
                 self._precomputed_scale,
                 torch.float8_e4m3fn,
                 linear_mm_config=self._linear_mm_config,
-                gemm_input_role=GemmInputRole.W,
+                gemm_input_role=GemmInputRole.WEIGHT,
             )
         else:
             float8_tensor = cast_to_float8_e4m3_dynamic(
                 self._tensor,
                 self._linear_mm_config,
                 reduce_amax=True,
-                gemm_input_role=GemmInputRole.W,
+                gemm_input_role=GemmInputRole.WEIGHT,
             )
         return (float8_tensor._data,), (float8_tensor._scale,)
@@ -201,7 +201,7 @@ def fsdp_post_all_gather(
             scale,
             param_dtype,
             self._linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         ), (data,)
@@ -364,7 +364,7 @@ def fsdp_pre_all_gather(self, mesh):
             e4m3_dtype,
             self._amax_buffer,
             self._linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         )
         return (float8_tensor._data,), (float8_tensor._scale,)
@@ -387,5 +387,5 @@ def fsdp_post_all_gather(
             scale,
             param_dtype,
             self._linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         ), (data,)
diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py
index d24fedb..0c10589 100644
--- a/float8_experimental/inference.py
+++ b/float8_experimental/inference.py
@@ -132,7 +132,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
             scale,
             dtype,
             self.linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         )
         self.weight = nn.Parameter(quantized_weight)
         self.weight.requires_grad = False
@@ -205,7 +205,7 @@ def cast_to_float8_e4m3_inference(
         scale,
         e4m3_dtype,
         linear_mm_config=linear_mm_config,
-        gemm_input_role=GemmInputRole.X,
+        gemm_input_role=GemmInputRole.INPUT,
     )
diff --git a/test/test_base.py b/test/test_base.py
index 4d36ad1..ffc8d0c 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -395,7 +395,7 @@ def test_repr(self):
             config=config,
         )
         s = m.__repr__()
-        assert "x:dyn,w:del,dldy:dyn" in s
+        assert "i:dyn,w:del,go:dyn" in s
 
 
 class TestScaledMM:
@@ -464,18 +464,18 @@ def test_different_configs_error(self):
             x_scale,
             fp8_dtype,
             linear_mm_config=linear_config_a,
-            gemm_input_role=GemmInputRole.X,
+            gemm_input_role=GemmInputRole.INPUT,
         )
         b = Float8Tensor.to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
             linear_mm_config=linear_config_b,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         )
         with pytest.raises(
             AssertionError,
-            match="linear_mm_config.y mismatch",
+            match="linear_mm_config.output mismatch",
         ):
             a @ b
@@ -499,10 +499,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         b_scale = tensor_to_scale(b, input_dtype).float()
 
         a_fp8 = Float8Tensor.to_float8(
-            a, a_scale, input_dtype, gemm_input_role=GemmInputRole.X
+            a, a_scale, input_dtype, gemm_input_role=GemmInputRole.INPUT
         )
         b_fp8 = Float8Tensor.to_float8(
-            b, b_scale, input_dtype, gemm_input_role=GemmInputRole.W
+            b, b_scale, input_dtype, gemm_input_role=GemmInputRole.WEIGHT
         )
 
         with pytest.raises(
@@ -523,14 +523,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             a_scale,
             input_dtype,
             linear_mm_config=pad_config,
-            gemm_input_role=GemmInputRole.X,
+            gemm_input_role=GemmInputRole.INPUT,
         )
         b_fp8 = Float8Tensor.to_float8(
             b,
             b_scale,
             input_dtype,
             linear_mm_config=pad_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         )
         out_padded = a_fp8 @ b_fp8
         out_padded.to(compare_type)
@@ -546,14 +546,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             a_scale,
             input_dtype,
             linear_mm_config=emulated_config,
-            gemm_input_role=GemmInputRole.X,
+            gemm_input_role=GemmInputRole.INPUT,
         )
         b_fp8 = Float8Tensor.to_float8(
             b,
             b_scale,
             input_dtype,
             linear_mm_config=emulated_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         )
         out_emualted = a_fp8 @ b_fp8
         out_emualted.to(compare_type)
@@ -606,8 +606,8 @@ def test_swap_root_linear(self):
             config = Float8LinearConfig(emulate=emulate)
             module = convert_to_float8_training(module, config=config)
             self.assertIsInstance(module, Float8Linear)
-            self.assertEqual(module.linear_mm_config.y.emulate, emulate)
-            self.assertEqual(module.linear_mm_config.y.emulate, emulate)
+            self.assertEqual(module.linear_mm_config.output.emulate, emulate)
+            self.assertEqual(module.linear_mm_config.output.emulate, emulate)
 
     def test_swap_root_linear_with_children_raises(self):
         for emulate in [True, False]:
diff --git a/test/test_compile.py b/test/test_compile.py
index a457dd8..e7b5285 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -256,9 +256,9 @@ def test_float8_graph_output(self):
             type(y_compiled._orig_dtype)
         )
         assert isinstance(
-            y_compiled._linear_mm_config.y.emulate, bool
+            y_compiled._linear_mm_config.output.emulate, bool
         ), "Float8Tensor._emulate should be a bool but got {}".format(
-            type(y_compiled._linear_mm_config.y.emulate)
+            type(y_compiled._linear_mm_config.output.emulate)
         )
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 0b294d8..eeca6df 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -87,10 +87,10 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
     y_scale = tensor_to_scale(y_fp32, fp8_dtype).float()
 
     x_fp8 = Float8Tensor.to_float8(
-        x_fp32, x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X
+        x_fp32, x_scale, fp8_dtype, gemm_input_role=GemmInputRole.INPUT
     )
     y_fp8 = Float8Tensor.to_float8(
-        y_fp32, y_scale, fp8_dtype, gemm_input_role=GemmInputRole.W
+        y_fp32, y_scale, fp8_dtype, gemm_input_role=GemmInputRole.WEIGHT
     )
 
     dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False)
@@ -164,10 +164,13 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     dist_target = distribute_tensor(target, mesh, [Shard(0)])
 
     dist_x_fp8 = Float8Tensor.to_float8(
-        dist_x_fp32, dist_x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X
+        dist_x_fp32, dist_x_scale, fp8_dtype, gemm_input_role=GemmInputRole.INPUT
     )
     dist_weight_fp8 = Float8Tensor.to_float8(
-        dist_wight_fp32, dist_weight_scale, fp8_dtype, gemm_input_role=GemmInputRole.W
+        dist_wight_fp32,
+        dist_weight_scale,
+        fp8_dtype,
+        gemm_input_role=GemmInputRole.WEIGHT,
     )
 
     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
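
Example of the renamed roles in use (a minimal sketch, not part of the patch: the import paths, tensor shapes, and scales below are illustrative, and running the non-emulated gemm requires a CUDA device with float8 support):

    import torch

    from float8_experimental.float8_tensor import (
        Float8Tensor,
        GemmInputRole,
        LinearMMConfig,
    )
    from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

    x = torch.randn(16, 32, device="cuda")
    w = torch.randn(64, 32, device="cuda")
    mm_config = LinearMMConfig()  # namedtuple defaults for the three per-gemm ScaledMMConfigs

    x_fp8 = Float8Tensor.to_float8(
        x,
        tensor_to_scale(x, e4m3_dtype).float(),
        e4m3_dtype,
        linear_mm_config=mm_config,
        gemm_input_role=GemmInputRole.INPUT,  # was GemmInputRole.X
    )
    w_fp8 = Float8Tensor.to_float8(
        w,
        tensor_to_scale(w, e4m3_dtype).float(),
        e4m3_dtype,
        linear_mm_config=mm_config,
        gemm_input_role=GemmInputRole.WEIGHT,  # was GemmInputRole.W
    )

    # INPUT @ WEIGHT makes choose_scaled_mm_config pick linear_mm_config.output
    # (was .y), mirroring the forward gemm of Float8Linear.
    y = x_fp8 @ w_fp8.t()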