|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +from math import modf |
| 7 | +import torch |
| 8 | +import torch.nn as nn |
| 9 | +from torch.distributed._tensor import DTensor |
| 10 | +from torch.distributed.device_mesh import DeviceMesh |
| 11 | +from torch.distributed.tensor.parallel import ( |
| 12 | + ColwiseParallel, |
| 13 | + PrepareModuleInput, |
| 14 | + RowwiseParallel, |
| 15 | +) |
| 16 | + |
| 17 | +from torchao.float8.config import ScalingType, e4m3_dtype |
| 18 | +from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 |
| 19 | +from torchao.float8.float8_scaling_utils import ( |
| 20 | + hp_tensor_to_float8_dynamic, |
| 21 | +) |
| 22 | +from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig |
| 23 | + |
| 24 | +# subclass the ColwiseParallel and RowwiseParallel classes |
| 25 | +# to add the float8 support |
| 26 | +# The parameter sharding stays the same as the core |
| 27 | +# ColwiseParallel and RowwiseParallel, the only difference |
| 28 | +# here is that in input/output handling we do casting after |
| 29 | +# creating the DTensor. |
| 30 | + |
| 31 | +# NOTE: This only works and tested with the dynamic scaling |
| 32 | + |
| 33 | + |
| 34 | +def _float8_linear_supports_float8_allgather(m): |
| 35 | + # TODO(future PR): also gate this by granularity |
| 36 | + return ( |
| 37 | + m.scaling_type_input == ScalingType.DYNAMIC |
| 38 | + and m.scaling_type_grad_output == ScalingType.DYNAMIC |
| 39 | + ) |
| 40 | + |
| 41 | + |
| 42 | +class Float8ColwiseParallel(ColwiseParallel): |
| 43 | + """ |
| 44 | + Like `ColwiseParallel`, but with all-gather in float8 with rowwise scales. |
| 45 | + """ |
| 46 | + |
| 47 | + @staticmethod |
| 48 | + def _prepare_input_fn( |
| 49 | + input_layouts, desired_input_layouts, mod, inputs, device_mesh, |
| 50 | + ): |
| 51 | + # annotate module input placements/sharding with input_layouts |
| 52 | + input_tensor = inputs[0] |
| 53 | + if not isinstance(input_tensor, DTensor): |
| 54 | + input_tensor = DTensor.from_local( |
| 55 | + input_tensor, device_mesh, input_layouts, run_check=False |
| 56 | + ) |
| 57 | + |
| 58 | + if not tensor_already_casted_to_fp8(input_tensor): |
| 59 | + input_tensor = Float8RowwiseFwdColwiseBwd.apply(input_tensor, mod) |
| 60 | + |
| 61 | + # transform the input layouts to the desired layouts of ColwiseParallel |
| 62 | + if input_layouts != desired_input_layouts: |
| 63 | + input_tensor = input_tensor.redistribute( |
| 64 | + placements=desired_input_layouts, async_op=True |
| 65 | + ) |
| 66 | + return input_tensor |
| 67 | + |
| 68 | + @staticmethod |
| 69 | + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): |
| 70 | + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) |
| 71 | + if outputs.placements != output_layouts: |
| 72 | + outputs = outputs.redistribute( |
| 73 | + placements=output_layouts, async_op=True |
| 74 | + ) # DTensor(torch.Tensor) |
| 75 | + |
| 76 | + outputs = NoopFwToFloat8RowwiseBwDynamic.apply( |
| 77 | + outputs, |
| 78 | + mod.linear_mm_config, |
| 79 | + mod.config.cast_config_grad_output.target_dtype, |
| 80 | + -1, |
| 81 | + ) |
| 82 | + |
| 83 | + # back to local tensor |
| 84 | + return outputs.to_local() if use_local_output else outputs |
| 85 | + |
| 86 | + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 87 | + from torchao.float8.float8_linear import Float8Linear |
| 88 | + |
| 89 | + if not isinstance(module, Float8Linear): |
| 90 | + raise ValueError( |
| 91 | + f"Expecting module to be Float8Linear but found {type(module)}" |
| 92 | + ) |
| 93 | + elif isinstance( |
| 94 | + module, Float8Linear |
| 95 | + ) and not _float8_linear_supports_float8_allgather(module): |
| 96 | + raise AssertionError("unsupported") |
| 97 | + |
| 98 | + return super()._apply(module, device_mesh) |
| 99 | + |
| 100 | + |
| 101 | +class Float8RowwiseParallel(RowwiseParallel): |
| 102 | + """ |
| 103 | + Like `RowwiseParallel`, but with all-gather in float8 with rowwise scales. |
| 104 | + """ |
| 105 | + |
| 106 | + @staticmethod |
| 107 | + def _prepare_input_fn( |
| 108 | + input_layouts, desired_input_layouts, mod, inputs, device_mesh, |
| 109 | + ): |
| 110 | + input_tensor = inputs[0] |
| 111 | + if not isinstance(input_tensor, DTensor): |
| 112 | + input_tensor = DTensor.from_local( |
| 113 | + input_tensor, device_mesh, input_layouts, run_check=False |
| 114 | + ) |
| 115 | + |
| 116 | + if not tensor_already_casted_to_fp8(input_tensor): |
| 117 | + input_tensor = Float8RowwiseFwdColwiseBwd.apply(input_tensor, mod) |
| 118 | + |
| 119 | + if input_layouts != desired_input_layouts: |
| 120 | + input_tensor = input_tensor.redistribute( |
| 121 | + placements=desired_input_layouts, async_op=True |
| 122 | + ) |
| 123 | + return input_tensor |
| 124 | + |
| 125 | + @staticmethod |
| 126 | + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): |
| 127 | + # Rowwise sharding produces partial output, depending on output layouts: |
| 128 | + # 1. to replicate -> allreduce |
| 129 | + # 2. to shard -> reduce_scatter |
| 130 | + if outputs.placements != output_layouts: |
| 131 | + outputs = outputs.redistribute(placements=output_layouts, async_op=True) |
| 132 | + |
| 133 | + outputs = NoopFwToFloat8RowwiseBwDynamic.apply( |
| 134 | + outputs, |
| 135 | + mod.linear_mm_config, |
| 136 | + mod.config.cast_config_grad_output.target_dtype, |
| 137 | + -1, |
| 138 | + ) |
| 139 | + |
| 140 | + # back to local tensor if use_local_output is True |
| 141 | + return outputs.to_local() if use_local_output else outputs |
| 142 | + |
| 143 | + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 144 | + from torchao.float8.float8_linear import Float8Linear |
| 145 | + |
| 146 | + if not isinstance(module, Float8Linear): |
| 147 | + raise ValueError( |
| 148 | + f"Expecting module to be Float8Linear but found {type(module)}" |
| 149 | + ) |
| 150 | + elif isinstance( |
| 151 | + module, Float8Linear |
| 152 | + ) and not _float8_linear_supports_float8_allgather(module): |
| 153 | + raise AssertionError("unsupported") |
| 154 | + |
| 155 | + return super()._apply(module, device_mesh) |
| 156 | + |
| 157 | + |
| 158 | +class PrepareFloat8ModuleInput(PrepareModuleInput): |
| 159 | + """ |
| 160 | + Like `PrepareModuleInput`, but with all-gather in float8 with rowwise scales. |
| 161 | +
|
| 162 | + The only difference from `PrepareModuleInput` is that |
| 163 | + after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) |
| 164 | + This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) |
| 165 | + so that if there are multiple float8 users of the input activation, we perform fp8 allgather |
| 166 | + only once. |
| 167 | + FP8 Args: |
| 168 | + float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, |
| 169 | + we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn |
| 170 | + fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used |
| 171 | + for the float8 cast. If not specified, we will search for the Float8Linear in the submodules |
| 172 | + and use the forward config from that module, in this case all module's forward config must be |
| 173 | + the same. |
| 174 | + """ |
| 175 | + |
| 176 | + def __init__( |
| 177 | + self, |
| 178 | + *, |
| 179 | + input_layouts=None, |
| 180 | + desired_input_layouts=None, |
| 181 | + input_kwarg_layouts=None, |
| 182 | + desired_input_kwarg_layouts=None, |
| 183 | + use_local_output=False, |
| 184 | + float8_dtype=torch.float8_e4m3fn, |
| 185 | + fwd_config_submodule_fqn=None, |
| 186 | + ): |
| 187 | + super().__init__( |
| 188 | + input_layouts=input_layouts, |
| 189 | + desired_input_layouts=desired_input_layouts, |
| 190 | + input_kwarg_layouts=input_kwarg_layouts, |
| 191 | + desired_input_kwarg_layouts=desired_input_kwarg_layouts, |
| 192 | + use_local_output=use_local_output, |
| 193 | + ) |
| 194 | + |
| 195 | + # fp8 specific fields |
| 196 | + self.float8_dtype = float8_dtype |
| 197 | + self.linear_mm_config = None |
| 198 | + self.fwd_config_submodule_fqn = fwd_config_submodule_fqn |
| 199 | + |
| 200 | + if self.float8_dtype != torch.float8_e4m3fn: |
| 201 | + raise NotImplementedError( |
| 202 | + "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now" |
| 203 | + ) |
| 204 | + |
| 205 | + def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): |
| 206 | + if input_layout is not None: |
| 207 | + if isinstance(input, DTensor): |
| 208 | + # TODO: re-enable the check once we fix the compile path |
| 209 | + # assert inp.placements[0] == input_layout |
| 210 | + dt_inp = input |
| 211 | + else: |
| 212 | + assert isinstance(input, torch.Tensor), ( |
| 213 | + "expecting input to be a torch.Tensor!" |
| 214 | + ) |
| 215 | + dt_inp = DTensor.from_local( |
| 216 | + input, mesh, (input_layout,), run_check=False |
| 217 | + ) |
| 218 | + |
| 219 | + dt_inp = Float8RowwiseFwdColwiseBwd.apply(input, mod) |
| 220 | + if desired_layout is not None and input_layout != desired_layout: |
| 221 | + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) |
| 222 | + |
| 223 | + return dt_inp.to_local() if self.use_local_output else dt_inp |
| 224 | + else: |
| 225 | + return input |
| 226 | + |
| 227 | + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 228 | + from torchao.float8.float8_linear import Float8Linear |
| 229 | + |
| 230 | + if self.fwd_config_submodule_fqn is not None: |
| 231 | + fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) |
| 232 | + assert isinstance(fwd_linear, Float8Linear) |
| 233 | + self.linear_mm_config = fwd_linear.linear_mm_config |
| 234 | + else: |
| 235 | + # search for ScaledMM configs for all the submodules and make sure they are the same |
| 236 | + for mod in module.modules(): |
| 237 | + if isinstance(mod, Float8Linear): |
| 238 | + if self.linear_mm_config is None: |
| 239 | + self.linear_mm_config = mod.linear_mm_config |
| 240 | + else: |
| 241 | + assert self.linear_mm_config == mod.linear_mm_config, ( |
| 242 | + "All the Float8Linear modules should have same linear_mm_config!" |
| 243 | + ) |
| 244 | + |
| 245 | + assert self.linear_mm_config is not None |
| 246 | + super()._apply(module, device_mesh) |
| 247 | + return module |
| 248 | + |
| 249 | + |
| 250 | +class Float8RowwiseFwdColwiseBwd(torch.autograd.Function): |
| 251 | + @staticmethod |
| 252 | + def forward(ctx, tensor: torch.Tensor, mod: nn.Module) -> torch.Tensor: |
| 253 | + ctx.mod_config = mod.config |
| 254 | + ctx.linear_mm_config = mod.linear_mm_config |
| 255 | + if not tensor_already_casted_to_fp8(tensor): |
| 256 | + return hp_tensor_to_float8_dynamic( |
| 257 | + tensor, |
| 258 | + mod.config.cast_config_input.target_dtype, |
| 259 | + mod.linear_mm_config, |
| 260 | + gemm_input_role=GemmInputRole.INPUT, |
| 261 | + axiswise_dim=-1, |
| 262 | + ) # DTensor(Float8Tensor) |
| 263 | + return tensor |
| 264 | + |
| 265 | + @staticmethod |
| 266 | + def backward(ctx, tensor: torch.Tensor) -> torch.Tensor: |
| 267 | + if not tensor_already_casted_to_fp8(tensor): |
| 268 | + return hp_tensor_to_float8_dynamic( |
| 269 | + tensor, |
| 270 | + ctx.config.cast_config_input.target_dtype, |
| 271 | + ctx.linear_mm_config, |
| 272 | + gemm_input_role=GemmInputRole.INPUT, |
| 273 | + axiswise_dim=0, |
| 274 | + ) # DTensor(Float8Tensor) |
| 275 | + return tensor |
| 276 | + |
| 277 | + |
| 278 | + |
| 279 | +@torch._dynamo.allow_in_graph |
| 280 | +class NoopFwToFloat8RowwiseBwDynamic(torch.autograd.Function): |
| 281 | + """ |
| 282 | + Forward: no-op |
| 283 | + Backward: convert to target float8 dtype with dynamic scaling |
| 284 | + """ |
| 285 | + |
| 286 | + @staticmethod |
| 287 | + def forward( |
| 288 | + ctx, |
| 289 | + tensor, |
| 290 | + linear_mm_config: LinearMMConfig, |
| 291 | + target_dtype: torch.dtype, |
| 292 | + axiswise_dim: int, |
| 293 | + ): |
| 294 | + ctx.linear_mm_config = linear_mm_config |
| 295 | + ctx.target_dtype = target_dtype |
| 296 | + ctx.axiswise_dim = axiswise_dim |
| 297 | + return tensor |
| 298 | + |
| 299 | + @staticmethod |
| 300 | + def backward(ctx, gradY): |
| 301 | + if tensor_already_casted_to_fp8(gradY): |
| 302 | + return gradY, None, None |
| 303 | + fp8_tensor = hp_tensor_to_float8_dynamic( |
| 304 | + gradY, |
| 305 | + ctx.target_dtype, |
| 306 | + ctx.linear_mm_config, |
| 307 | + gemm_input_role=GemmInputRole.GRAD_OUTPUT, |
| 308 | + axiswise_dim=ctx.axiswise_dim, |
| 309 | + ) |
| 310 | + return fp8_tensor, None, None, None |
0 commit comments