
Commit 853467d

fp8 rowwise all gather for tp
1 parent cdced21 commit 853467d

2 files changed: 314 additions, 1 deletion

torchao/float8/float8_scaling_utils.py

Lines changed: 4 additions & 1 deletion
@@ -91,7 +91,7 @@ def get_maybe_axiswise_dim(
 class NoopFwToFloat8BwDynamic(torch.autograd.Function):
     """
     Forward: no-op
-    Backward: convert to float8_e5m2 with dynamic scaling
+    Backward: convert to target float8 dtype with dynamic scaling
     """

     @staticmethod
@@ -100,9 +100,11 @@ def forward(
         tensor,
         linear_mm_config: LinearMMConfig,
         target_dtype: torch.dtype,
+        axiswise_dim: int,
     ):
         ctx.linear_mm_config = linear_mm_config
         ctx.target_dtype = target_dtype
+        ctx.axiswise_dim = axiswise_dim
         return tensor

     @staticmethod
@@ -116,5 +118,6 @@ def backward(ctx, gradY):
             ctx.target_dtype,
             ctx.linear_mm_config,
             GemmInputRole.GRAD_OUTPUT,
+            axiswise_dim=ctx.axiswise_dim,
         )
         return fp8_tensor, None, None
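For context on what the new axiswise_dim argument buys: with dynamic rowwise (axiswise) scaling, the backward cast computes one scale per slice along the chosen dimension instead of a single tensorwise scale. Below is a minimal illustrative sketch of that cast in plain torch ops; it is not from the commit, the helper name is made up, and the clamp/scale details are simplified relative to hp_tensor_to_float8_dynamic.

import torch

def rowwise_cast_to_float8(t: torch.Tensor, float8_dtype=torch.float8_e5m2):
    # illustrative helper, not part of torchao: one scale per row, i.e. an
    # abs-max over the last dimension, kept as a column for broadcasting
    amax = t.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    fp8_max = torch.finfo(float8_dtype).max
    scale = fp8_max / amax.clamp(min=1e-12)
    t_scaled = (t.to(torch.float32) * scale).clamp(-fp8_max, fp8_max)
    return t_scaled.to(float8_dtype), scale  # fp8 data plus per-row scales

grad = torch.randn(8, 64, dtype=torch.bfloat16)
grad_fp8, scales = rowwise_cast_to_float8(grad)
print(grad_fp8.dtype, scales.shape)  # torch.float8_e5m2 torch.Size([8, 1])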
Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+import torch.nn as nn
+from torch.distributed._tensor import DTensor
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+)
+
+from torchao.float8.config import ScalingType, e4m3_dtype
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
+from torchao.float8.float8_scaling_utils import (
+    hp_tensor_to_float8_dynamic,
+)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
+
+# Subclass the ColwiseParallel and RowwiseParallel classes to add float8
+# support. The parameter sharding stays the same as in the core
+# ColwiseParallel and RowwiseParallel; the only difference is that in
+# input/output handling we do the float8 cast after creating the DTensor.
+
+# NOTE: this only works with, and is only tested with, dynamic scaling.
+
+
+def _float8_linear_supports_float8_allgather(m):
+    # TODO(future PR): also gate this by granularity
+    return (
+        m.scaling_type_input == ScalingType.DYNAMIC
+        and m.scaling_type_grad_output == ScalingType.DYNAMIC
+    )
+
+
+class Float8ColwiseParallel(ColwiseParallel):
+    """
+    Like `ColwiseParallel`, but with all-gather in float8 with rowwise scales.
+    """
+
+    @staticmethod
+    def _prepare_input_fn(
+        input_layouts, desired_input_layouts, mod, inputs, device_mesh,
+    ):
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, input_layouts, run_check=False
+            )
+
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = Float8RowwiseFwdColwiseBwd.apply(input_tensor, mod)
+
+        # transform the input layouts to the desired layouts of ColwiseParallel
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(
+                placements=desired_input_layouts, async_op=True
+            )
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(
+                placements=output_layouts, async_op=True
+            )  # DTensor(torch.Tensor)
+
+        outputs = NoopFwToFloat8RowwiseBwDynamic.apply(
+            outputs,
+            mod.linear_mm_config,
+            mod.config.cast_config_grad_output.target_dtype,
+            -1,
+        )
+
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from torchao.float8.float8_linear import Float8Linear
+
+        if not isinstance(module, Float8Linear):
+            raise ValueError(
+                f"Expecting module to be Float8Linear but found {type(module)}"
+            )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("unsupported")
+
+        return super()._apply(module, device_mesh)
+
+
+class Float8RowwiseParallel(RowwiseParallel):
+    """
+    Like `RowwiseParallel`, but with all-gather in float8 with rowwise scales.
+    """
+
+    @staticmethod
+    def _prepare_input_fn(
+        input_layouts, desired_input_layouts, mod, inputs, device_mesh,
+    ):
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, input_layouts, run_check=False
+            )
+
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = Float8RowwiseFwdColwiseBwd.apply(input_tensor, mod)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(
+                placements=desired_input_layouts, async_op=True
+            )
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        # Rowwise sharding produces partial output, depending on output layouts:
+        # 1. to replicate -> allreduce
+        # 2. to shard -> reduce_scatter
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+
+        outputs = NoopFwToFloat8RowwiseBwDynamic.apply(
+            outputs,
+            mod.linear_mm_config,
+            mod.config.cast_config_grad_output.target_dtype,
+            -1,
+        )
+
+        # back to local tensor if use_local_output is True
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from torchao.float8.float8_linear import Float8Linear
+
+        if not isinstance(module, Float8Linear):
+            raise ValueError(
+                f"Expecting module to be Float8Linear but found {type(module)}"
+            )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("unsupported")
+
+        return super()._apply(module, device_mesh)
+
+
+class PrepareFloat8ModuleInput(PrepareModuleInput):
+    """
+    Like `PrepareModuleInput`, but with all-gather in float8 with rowwise scales.
+
+    The only difference from `PrepareModuleInput` is that after we prepare the
+    input DTensor, we cast the input to DTensor(Float8Tensor). This ensures the
+    float8 cast happens before the all-gather (i.e. Shard -> Replicate), so that
+    if there are multiple float8 users of the input activation, we perform the
+    fp8 all-gather only once.
+
+    FP8 Args:
+        float8_dtype (torch.dtype, optional): the float8 dtype to cast the module
+            input to; we currently only support torch.float8_e4m3fn.
+            default: torch.float8_e4m3fn
+        fwd_config_submodule_fqn (str, optional): the FQN of the submodule that
+            contains the forward config used for the float8 cast. If not specified,
+            we search the submodules for Float8Linear modules and use the forward
+            config from there; in that case all Float8Linear forward configs must
+            be the same.
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layouts=None,
+        desired_input_layouts=None,
+        input_kwarg_layouts=None,
+        desired_input_kwarg_layouts=None,
+        use_local_output=False,
+        float8_dtype=torch.float8_e4m3fn,
+        fwd_config_submodule_fqn=None,
+    ):
+        super().__init__(
+            input_layouts=input_layouts,
+            desired_input_layouts=desired_input_layouts,
+            input_kwarg_layouts=input_kwarg_layouts,
+            desired_input_kwarg_layouts=desired_input_kwarg_layouts,
+            use_local_output=use_local_output,
+        )
+
+        # fp8 specific fields
+        self.float8_dtype = float8_dtype
+        self.linear_mm_config = None
+        self.fwd_config_submodule_fqn = fwd_config_submodule_fqn
+
+        if self.float8_dtype != torch.float8_e4m3fn:
+            raise NotImplementedError(
+                "PrepareFloat8ModuleInput only supports casting to float8_e4m3fn for now"
+            )
+
+    def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
+        if input_layout is not None:
+            if isinstance(input, DTensor):
+                # TODO: re-enable the check once we fix the compile path
+                # assert inp.placements[0] == input_layout
+                dt_inp = input
+            else:
+                assert isinstance(input, torch.Tensor), (
+                    "expecting input to be a torch.Tensor!"
+                )
+                dt_inp = DTensor.from_local(
+                    input, mesh, (input_layout,), run_check=False
+                )
+
+            dt_inp = Float8RowwiseFwdColwiseBwd.apply(dt_inp, mod)
+            if desired_layout is not None and input_layout != desired_layout:
+                dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+            return dt_inp.to_local() if self.use_local_output else dt_inp
+        else:
+            return input
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from torchao.float8.float8_linear import Float8Linear
+
+        if self.fwd_config_submodule_fqn is not None:
+            fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
+            assert isinstance(fwd_linear, Float8Linear)
+            self.linear_mm_config = fwd_linear.linear_mm_config
+        else:
+            # search for ScaledMM configs for all the submodules and make sure they are the same
+            for mod in module.modules():
+                if isinstance(mod, Float8Linear):
+                    if self.linear_mm_config is None:
+                        self.linear_mm_config = mod.linear_mm_config
+                    else:
+                        assert self.linear_mm_config == mod.linear_mm_config, (
+                            "All the Float8Linear modules should have the same linear_mm_config!"
+                        )
+
+        assert self.linear_mm_config is not None
+        super()._apply(module, device_mesh)
+        return module
+
+
+class Float8RowwiseFwdColwiseBwd(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor: torch.Tensor, mod: nn.Module) -> torch.Tensor:
+        ctx.mod_config = mod.config
+        ctx.linear_mm_config = mod.linear_mm_config
+        if not tensor_already_casted_to_fp8(tensor):
+            return hp_tensor_to_float8_dynamic(
+                tensor,
+                mod.config.cast_config_input.target_dtype,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+                axiswise_dim=-1,
+            )  # DTensor(Float8Tensor)
+        return tensor
+
+    @staticmethod
+    def backward(ctx, tensor: torch.Tensor):
+        if not tensor_already_casted_to_fp8(tensor):
+            tensor = hp_tensor_to_float8_dynamic(
+                tensor,
+                ctx.mod_config.cast_config_input.target_dtype,
+                ctx.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+                axiswise_dim=0,
+            )  # DTensor(Float8Tensor)
+        # one gradient per forward input: (tensor, mod) -> (grad, None)
+        return tensor, None
+
+
+@torch._dynamo.allow_in_graph
+class NoopFwToFloat8RowwiseBwDynamic(torch.autograd.Function):
+    """
+    Forward: no-op
+    Backward: convert to target float8 dtype with dynamic scaling
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor,
+        linear_mm_config: LinearMMConfig,
+        target_dtype: torch.dtype,
+        axiswise_dim: int,
+    ):
+        ctx.linear_mm_config = linear_mm_config
+        ctx.target_dtype = target_dtype
+        ctx.axiswise_dim = axiswise_dim
+        return tensor
+
+    @staticmethod
+    def backward(ctx, gradY):
+        if tensor_already_casted_to_fp8(gradY):
+            return gradY, None, None, None
+        fp8_tensor = hp_tensor_to_float8_dynamic(
+            gradY,
+            ctx.target_dtype,
+            ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            axiswise_dim=ctx.axiswise_dim,
+        )
+        return fp8_tensor, None, None, None
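The new parallel styles are drop-in replacements for their torch.distributed counterparts in a tensor-parallel plan. A minimal usage sketch follows; it assumes the model has already been converted to Float8Linear modules (e.g. via torchao.float8.convert_to_float8_training), that the classes are importable from torchao.float8.float8_tensor_parallel (their current home in torchao, since the diff above does not show the new file's name), and that the "ffn.w1"/"ffn.w2" module names and mesh setup are placeholders rather than anything prescribed by this commit.

import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# assumed import path; the diff above does not show the new file's name
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)

def apply_fp8_rowwise_tp(model: nn.Module, tp_degree: int) -> nn.Module:
    # model is assumed to already contain Float8Linear modules with dynamic scaling
    tp_mesh = init_device_mesh("cuda", (tp_degree,))
    tp_plan = {
        # cast the shared input to fp8 (rowwise scales) once, before the all-gather
        "ffn": PrepareFloat8ModuleInput(
            input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),)
        ),
        # column-sharded projection; its input arrives already cast to fp8
        "ffn.w1": Float8ColwiseParallel(),
        # row-sharded projection; grad_output is cast to fp8 in backward
        "ffn.w2": Float8RowwiseParallel(output_layouts=Shard(1)),
    }
    return parallelize_module(model, tp_mesh, tp_plan)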
