[fsdp] Propagate bucketing_strategy and sharding_strategy via original module (#424)
crcrpar authored May 16, 2024
1 parent 1f298f8 commit 599243c
Showing 5 changed files with 153 additions and 99 deletions.
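
For orientation, the gist of the change: fsdp() no longer keeps its configuration only inside a locally defined trace transform; it also records use_fsdp, process_group_for_ddp, bucketing_strategy, and sharding_strategy on the original module, which is what lets fsdp be applied after thunder.jit as well as before it. Below is a minimal usage sketch of the two orderings, mirroring the updated test further down; the ToyModel definition, the torchrun launch, and the NCCL backend are assumptions for illustration, not part of the diff.

import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp, FSDPBucketingStrategy, FSDPType


class ToyModel(torch.nn.Module):
    # hypothetical stand-in for the test suite's ToyModel
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(12, 12)

    def forward(self, x):
        return self.fc(x)


# assumes a `torchrun --nproc-per-node=2 ...` launch so RANK/LOCAL_RANK/WORLD_SIZE are set
dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

# Ordering 1 (already supported): shard first, then compile.
jit_fsdp = thunder.jit(
    fsdp(
        ToyModel(),
        device=device,
        sharding_strategy=FSDPType.ZERO3,
        bucketing_strategy=FSDPBucketingStrategy.BLOCK,
    )
)

# Ordering 2 (exercised by the new `fsdp_jit` test variant): compile first, then
# shard; the strategies are propagated via attributes on the original module.
fsdp_jit = fsdp(
    thunder.jit(ToyModel().to(device)),
    device=device,
    sharding_strategy=FSDPType.ZERO3,
    bucketing_strategy=FSDPBucketingStrategy.BLOCK,
)

x = torch.ones(2, 12, device=device)
fsdp_jit(x).mean().backward()
dist.destroy_process_group()
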
118 changes: 28 additions & 90 deletions thunder/distributed/__init__.py
@@ -2,7 +2,6 @@
import os

from itertools import chain
import collections
from contextlib import contextmanager
from contextvars import ContextVar, Token
import copy
@@ -20,7 +19,7 @@

if TYPE_CHECKING:
from torch.distributed import ProcessGroup
import thunder
from thunder.core.module import ThunderModule


__all__ = [
@@ -368,18 +367,17 @@ def f(tensor: TensorProxy) -> str:


def fsdp_transform_module(
thunder_model: thunder.ThunderModule,
thunder_model: ThunderModule,
*,
device: torch.device | None = None,
broadcast_from: int | None = None,
sharding_strategy: FSDPType = FSDPType.ZERO2,
bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
) -> thunder.ThunderModule:
import thunder

cd = thunder.compile_data(thunder_model)
# TODO: promote use_fsdp and use_ddp to public members of CompileData
cd.use_fsdp = True
) -> ThunderModule:
from thunder import compile_data as get_compile_data
from thunder.core.transforms import add_transform
from thunder.core.module import ThunderModule
from thunder.distributed.transforms.fsdp_v2 import FSDPTraceTransform

process_group = tdist.distributed_c10d._get_default_group()
utils.check(process_group is not None, lambda: "The default process group is None")
@@ -389,85 +387,18 @@ def fsdp_transform_module(
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

def prologue_and_compute_transform(prologue_trace, computation_trace, epilogue_trace, **kwargs):
import thunder

prologue_producers, prologue_consumers = thunder.core.utils.producers_and_consumers(prologue_trace)
computation_producers, computation_consumers = thunder.core.utils.producers_and_consumers(computation_trace)

modules_and_thunder_modules = [
(bsym.args[0], bsym.output)
for bsym in prologue_trace.bound_symbols
if bsym.sym is thunder.prims.unpack_thunder_module
]

if len(modules_and_thunder_modules) != 1:
raise NotImplementedError("cannot deal with modules other than the compiled module")

((orig_module_proxy, thunder_module_proxy),) = modules_and_thunder_modules
if prologue_producers[orig_module_proxy].sym is not thunder.prims.unpack_function_obj:
raise NotImplementedError("original module does not match the compiled module")

computation_trace.push_scope([])

synchronized_parameters = []
# todo: deal with epilogue output
for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args):
bsym = prologue_producers[pro_out_p]
if bsym.sym == thunder.prims.unpack_parameter:
param_thunder_module, param_name = bsym.args
assert param_thunder_module is thunder_module_proxy
if param_name in sharded_params:
old_shape, new_shape, new_torch_device = sharded_params[param_name]
thunder_device = thunder.core.devices.to_device(new_torch_device)
thunder_device_str = str(thunder_device)

pro_out_p._ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED
pro_out_p._shape = tuple(new_shape)
pro_out_p._device = thunder_device
if comp_inp_p is not pro_out_p:
comp_inp_p._ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED
comp_inp_p._shape = tuple(new_shape)
comp_inp_p._device = thunder_device
with thunder.core.trace.tracectx(computation_trace):
synchronized_parameters.append(thunder.distributed.prims.synchronize(comp_inp_p, process_group))

for c in prologue_consumers[pro_out_p]:
if c.sym is thunder.core.prims.check_tensor_shape_and_metadata:
# TODO have a more principled way to update this?
a0, _, _, *a2pp = c.args
c.args = (a0, tuple(new_shape), thunder_device_str, *a2pp)

new_scope = computation_trace.pop_scope()

for bsym in prologue_trace.bound_symbols:
if bsym.sym is thunder.core.prims.check_tensor_shape_and_metadata and prologue_producers[
bsym.args[0]
].sym in (thunder.core.prims.unpack_parameter, thunder.core.prims.unpack_buffer):
param_thunder_module, name = prologue_producers[bsym.args[0]].args
assert param_thunder_module is thunder_module_proxy
if name not in sharded_params and name in device_adjustments:
a0, shape, _, *a2pp = bsym.args
bsym.args = (a0, shape, thunder_device_str, *a2pp)

proxies_to_replace = {id(bsym.args[0]): bsym.output for bsym in new_scope}

new_computation_trace = thunder.core.trace.from_trace(computation_trace)
for idx, bsym in enumerate(computation_trace.bound_symbols):
if bsym.sym != thunder.core.prims.unpack_trivial:
break
new_computation_trace.bound_symbols.append(bsym.from_bsym())
new_computation_trace.bound_symbols += new_scope
for bsym in computation_trace.bound_symbols[idx:]:
new_args = tuple(proxies_to_replace.get(id(a), a) for a in bsym.args)
new_computation_trace.bound_symbols.append(bsym.from_bsym(args=new_args))

new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("fsdp pass"))

return prologue_trace, new_computation_trace, epilogue_trace

# add prologue + compute transform
thunder_model = thunder.core.transforms.add_transform(thunder_model, early_transform=prologue_and_compute_transform)
cd = get_compile_data(thunder_model)
# TODO: promote use_fsdp and use_ddp to public members of CompileData
cd.use_fsdp = True
orig_module: torch.nn.Module = cd.fn
utils.check(
isinstance(orig_module, torch.nn.Module) and not isinstance(orig_module, ThunderModule),
lambda: f"CompileData.fn expected to be `nn.Module` but {type(orig_module)}",
)
orig_module.use_fsdp = True
orig_module.process_group_for_ddp = process_group
orig_module.bucketing_strategy = bucketing_strategy
orig_module.sharding_strategy = sharding_strategy

# modify module
sharded_params = {}
@@ -485,7 +416,7 @@ def prologue_and_compute_transform(prologue_trace, computation_trace, epilogue_t
# Each module only initializes its own parameters and not those of its children (recurse=False)
if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))):
# TODO: we could also support calling a "param_init_fn" argument like PyTorch
thunder.distributed._materialize(module_copy, device)
_materialize(module_copy, device)
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
thunder_model._overrides[n] = p
device_adjustments[n] = device
@@ -518,10 +449,17 @@ def prologue_and_compute_transform(prologue_trace, computation_trace, epilogue_t
thunder_model._overrides[pn] = copy.copy(p)
# we collect shapes and devices because we do not know if other transforms also change it...
old_shape = thunder_model._overrides[pn].shape
thunder.distributed._shard_param(thunder_model._overrides[pn], global_rank, world_size, pn)
_shard_param(thunder_model._overrides[pn], global_rank, world_size, pn)
new_shape = thunder_model._overrides[pn].shape
sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides[pn].device)

early_transform_from_trace_to_fsdp_trace = FSDPTraceTransform(
sharded_params=sharded_params,
process_group=process_group,
)
# add prologue + compute transform
thunder_model = add_transform(thunder_model, early_transform=early_transform_from_trace_to_fsdp_trace)

return thunder_model


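The key change above: fsdp_transform_module now records use_fsdp, process_group_for_ddp, bucketing_strategy, and sharding_strategy directly on the original nn.Module held in CompileData.fn, and hands the trace rewriting to the reusable FSDPTraceTransform. A rough sketch of the consumer side of that contract follows, i.e. how a later pass can recover the settings from the original module; the helper name is hypothetical, and only compile_data.fn being the original module is taken from the diff.

from thunder.distributed import FSDPBucketingStrategy


def wants_fsdp_bucketing(compile_data) -> bool:
    """Return True when the jitted module was sharded by fsdp() and asked for comm bucketing."""
    module = compile_data.fn  # the original nn.Module captured at jit time
    if not getattr(module, "use_fsdp", False):
        return False
    strategy = getattr(module, "bucketing_strategy", FSDPBucketingStrategy.NONE)
    return strategy != FSDPBucketingStrategy.NONE

This mirrors the existing getattr(compile_data.fn, "use_fsdp", False) check in thunder/executors/torch_autograd.py shown further down.
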
2 changes: 1 addition & 1 deletion thunder/distributed/transforms/fsdp.py
@@ -504,7 +504,7 @@ def _collect_sharded_parameters(self, fwd_trace: TraceCtx) -> list[TensorProxy]:
fwd_trace_flat_args, _ = tree_flatten((fwd_trace.args, fwd_trace.kwargs))
return fwd_trace_flat_args

def apply_bucketing_to_forward_trace(self, fwd_trace: TraceCtx, bwd_trace_names: set[str]) -> TraceCtx:
def apply_bucketing_to_forward_trace(self, fwd_trace: TraceCtx) -> TraceCtx:
"""Optimize collective comms in fsdp with bucketing.
This function is a no-op if you pass :obj:`BucketingStrategy.NONE` as the ``bucketing_strategy`` kwarg of :func:`thunder.distributed.fsdp`.
105 changes: 105 additions & 0 deletions thunder/distributed/transforms/fsdp_v2.py
@@ -0,0 +1,105 @@
"""Early transform for `fsdp(jit(model))` to convert a trace into fsdp."""

from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING

from thunder.core import devices
from thunder.core import prims
from thunder.core import utils
from thunder.core.proxies import DDPType
from thunder.core.trace import from_trace
from thunder.core.trace import tracectx
from thunder.core.trace import TraceProvenance

if TYPE_CHECKING:
from typing import Any
from torch.distributed import ProcessGroup


__all__ = [
"FSDPTraceTransform",
]


@dataclass(frozen=True)
class FSDPTraceTransform:
sharded_params: dict[str, Any]
process_group: ProcessGroup

def __call__(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
from thunder.distributed import prims as dist_prims

prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace)

modules_and_thunder_modules = [
(bsym.args[0], bsym.output)
for bsym in prologue_trace.bound_symbols
if bsym.sym is prims.unpack_thunder_module
]

if len(modules_and_thunder_modules) != 1:
raise NotImplementedError("cannot deal with modules other than the compiled module")

((orig_module_proxy, thunder_module_proxy),) = modules_and_thunder_modules
if prologue_producers[orig_module_proxy].sym is not prims.unpack_function_obj:
raise NotImplementedError("original module does not match the compiled module")

computation_trace.push_scope([])

synchronized_parameters = []
# todo: deal with epilogue output
for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args):
bsym = prologue_producers[pro_out_p]
if bsym.sym == prims.unpack_parameter:
param_thunder_module, param_name = bsym.args
assert param_thunder_module is thunder_module_proxy
if param_name in self.sharded_params:
old_shape, new_shape, new_torch_device = self.sharded_params[param_name]
thunder_device = devices.to_device(new_torch_device)
thunder_device_str = str(thunder_device)

pro_out_p._ddp_type = DDPType.FULLY_SHARDED
pro_out_p._shape = tuple(new_shape)
pro_out_p._device = thunder_device
if comp_inp_p is not pro_out_p:
comp_inp_p._ddp_type = DDPType.FULLY_SHARDED
comp_inp_p._shape = tuple(new_shape)
comp_inp_p._device = thunder_device
with tracectx(computation_trace):
synchronized_parameters.append(dist_prims.synchronize(comp_inp_p, self.process_group))

for c in prologue_consumers[pro_out_p]:
if c.sym is prims.check_tensor_shape_and_metadata:
# TODO have a more principled way to update this?
a0, _, _, *a2pp = c.args
c.args = (a0, tuple(new_shape), thunder_device_str, *a2pp)

new_scope = computation_trace.pop_scope()

for bsym in prologue_trace.bound_symbols:
if bsym.sym is prims.check_tensor_shape_and_metadata and prologue_producers[bsym.args[0]].sym in (
prims.unpack_parameter,
prims.unpack_buffer,
):
param_thunder_module, name = prologue_producers[bsym.args[0]].args
assert param_thunder_module is thunder_module_proxy
if name not in self.sharded_params:
a0, shape, _, *a2pp = bsym.args
bsym.args = (a0, shape, thunder_device_str, *a2pp)

proxies_to_replace = {id(bsym.args[0]): bsym.output for bsym in new_scope}

new_computation_trace = from_trace(computation_trace)
for idx, bsym in enumerate(computation_trace.bound_symbols):
if bsym.sym != prims.unpack_trivial:
break
new_computation_trace.bound_symbols.append(bsym.from_bsym())
new_computation_trace.bound_symbols += new_scope
for bsym in computation_trace.bound_symbols[idx:]:
new_args = tuple(proxies_to_replace.get(id(a), a) for a in bsym.args)
new_computation_trace.bound_symbols.append(bsym.from_bsym(args=new_args))

new_computation_trace.set_provenance(TraceProvenance("fsdp pass"))

return prologue_trace, new_computation_trace, epilogue_trace
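
FSDPTraceTransform above follows the early-transform protocol used by add_transform: a callable that receives the prologue, computation, and epilogue traces and returns the (possibly rewritten) triple. Here is a minimal sketch of that protocol with a do-nothing transform; NoOpTransform is a hypothetical name, not part of this commit.

from dataclasses import dataclass

from thunder.core.transforms import add_transform


@dataclass(frozen=True)
class NoOpTransform:
    # Same call signature as FSDPTraceTransform.__call__; a real transform would
    # push a scope on the computation trace, add bound symbols, and rebuild the
    # trace with from_trace(...), as done above.
    def __call__(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
        return prologue_trace, computation_trace, epilogue_trace


# usage (assuming `model` is an nn.Module):
# jitted_model = add_transform(thunder.jit(model), early_transform=NoOpTransform())
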
2 changes: 1 addition & 1 deletion thunder/executors/torch_autograd.py
@@ -158,7 +158,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
_fsdp_comm_bucketing: FSDPCommBucketing | None = None
if getattr(compile_data.fn, "use_fsdp", False):
_fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc)
fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace, bw_trace.names)
fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace)

# Now we can run the optimization passes on the forward trace
# TODO Restore request for no rematerialization
25 changes: 18 additions & 7 deletions thunder/tests/distributed/test_ddp.py
@@ -730,24 +730,26 @@ def test_ddp_grad_parity_with_without_bucketing(self, executor):

# TODO(crcrpar): Add torch compile to executors_list
@common_utils.parametrize(
"executor,bucketing_strategy,fsdptype",
"executor,bucketing_strategy,fsdptype,apply_fsdp_first",
product(
tuple(executors_map.keys()),
(
FSDPBucketingStrategy.LAYER,
FSDPBucketingStrategy.BLOCK,
),
(FSDPType.ZERO2, FSDPType.ZERO3),
(True, False),
),
name_fn=lambda executor, bucketing_strategy, fsdptype: (
f"executor_{executor}_bucketing_{str(bucketing_strategy).split('.')[1].lower()}_{(str(fsdptype).lower().split('.')[1])}"
name_fn=lambda executor, bucketing_strategy, fsdptype, apply_fsdp_first: (
f"executor_{executor}_bucketing_{str(bucketing_strategy).split('.')[1].lower()}_{(str(fsdptype).lower().split('.')[1])}_{'jit_fsdp' if apply_fsdp_first else 'fsdp_jit'}"
),
)
def test_fsdp_grad_parity_with_without_bucketing(
self,
executor,
bucketing_strategy: FSDPBucketingStrategy,
fsdptype: FSDPType,
apply_fsdp_first: bool,
):
from thunder.distributed import fsdp

@@ -757,10 +759,18 @@ def test_fsdp_grad_parity_with_without_bucketing(
for strategy in (FSDPBucketingStrategy.NONE, bucketing_strategy):
m = ToyModel()
m.load_state_dict(initial_model_state)
cm = thunder.jit(
fsdp(m, device=device, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype),
executors=executors_map[executor].executors_list(),
)
if apply_fsdp_first:
cm = thunder.jit(
fsdp(m, device=device, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype),
executors=executors_map[executor].executors_list(),
)
else:
cm = fsdp(
thunder.jit(m.to(device), executors=executors_map[executor].executors_list()),
device=device,
bucketing_strategy=bucketing_strategy,
sharding_strategy=fsdptype,
)
x = torch.ones((2, 12), device=device)
loss = cm(x).mean()
loss.backward()
Expand All @@ -783,6 +793,7 @@ def test_fsdp_grad_parity_with_without_bucketing(
ex_trace.bound_symbols,
)
)
self.assertGreater(len(pack_bsyms), 0)
has_pack_multiple_tensors = False
for bsym in pack_bsyms:
first_arg = bsym.args[0]
