diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index f853e9876b..34260ec0d5 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -2,7 +2,6 @@
 import os
 from itertools import chain
-import collections
 from contextlib import contextmanager
 from contextvars import ContextVar, Token
 import copy
@@ -20,7 +19,7 @@
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
-    import thunder
+    from thunder.core.module import ThunderModule


 __all__ = [
@@ -368,18 +367,17 @@ def f(tensor: TensorProxy) -> str:


 def fsdp_transform_module(
-    thunder_model: thunder.ThunderModule,
+    thunder_model: ThunderModule,
     *,
     device: torch.device | None = None,
     broadcast_from: int | None = None,
     sharding_strategy: FSDPType = FSDPType.ZERO2,
     bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
-) -> thunder.ThunderModule:
-    import thunder
-
-    cd = thunder.compile_data(thunder_model)
-    # TODO: promote use_fsdp and use_ddp to public members of CompileData
-    cd.use_fsdp = True
+) -> ThunderModule:
+    from thunder import compile_data as get_compile_data
+    from thunder.core.transforms import add_transform
+    from thunder.core.module import ThunderModule
+    from thunder.distributed.transforms.fsdp_v2 import FSDPTraceTransform

     process_group = tdist.distributed_c10d._get_default_group()
     utils.check(process_group is not None, lambda: "The default process group is None")
@@ -389,85 +387,18 @@ def fsdp_transform_module(
         local_rank = int(os.environ["LOCAL_RANK"])
         device = torch.device("cuda", local_rank)

-    def prologue_and_compute_transform(prologue_trace, computation_trace, epilogue_trace, **kwargs):
-        import thunder
-
-        prologue_producers, prologue_consumers = thunder.core.utils.producers_and_consumers(prologue_trace)
-        computation_producers, computation_consumers = thunder.core.utils.producers_and_consumers(computation_trace)
-
-        modules_and_thunder_modules = [
-            (bsym.args[0], bsym.output)
-            for bsym in prologue_trace.bound_symbols
-            if bsym.sym is thunder.prims.unpack_thunder_module
-        ]
-
-        if len(modules_and_thunder_modules) != 1:
-            raise NotImplementedError("cannot deal with modules other than the compiled module")
-
-        ((orig_module_proxy, thunder_module_proxy),) = modules_and_thunder_modules
-        if prologue_producers[orig_module_proxy].sym is not thunder.prims.unpack_function_obj:
-            raise NotImplementedError("original module does not match the compiled module")
-
-        computation_trace.push_scope([])
-
-        synchronized_parameters = []
-        # todo: deal with epilogue output
-        for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args):
-            bsym = prologue_producers[pro_out_p]
-            if bsym.sym == thunder.prims.unpack_parameter:
-                param_thunder_module, param_name = bsym.args
-                assert param_thunder_module is thunder_module_proxy
-                if param_name in sharded_params:
-                    old_shape, new_shape, new_torch_device = sharded_params[param_name]
-                    thunder_device = thunder.core.devices.to_device(new_torch_device)
-                    thunder_device_str = str(thunder_device)
-
-                    pro_out_p._ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED
-                    pro_out_p._shape = tuple(new_shape)
-                    pro_out_p._device = thunder_device
-                    if comp_inp_p is not pro_out_p:
-                        comp_inp_p._ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED
-                        comp_inp_p._shape = tuple(new_shape)
-                        comp_inp_p._device = thunder_device
-                    with thunder.core.trace.tracectx(computation_trace):
-                        synchronized_parameters.append(thunder.distributed.prims.synchronize(comp_inp_p, process_group))
-
-                    for c in prologue_consumers[pro_out_p]:
-                        if c.sym is thunder.core.prims.check_tensor_shape_and_metadata:
-                            # TODO have a more principled way to update this?
-                            a0, _, _, *a2pp = c.args
-                            c.args = (a0, tuple(new_shape), thunder_device_str, *a2pp)
-
-        new_scope = computation_trace.pop_scope()
-
-        for bsym in prologue_trace.bound_symbols:
-            if bsym.sym is thunder.core.prims.check_tensor_shape_and_metadata and prologue_producers[
-                bsym.args[0]
-            ].sym in (thunder.core.prims.unpack_parameter, thunder.core.prims.unpack_buffer):
-                param_thunder_module, name = prologue_producers[bsym.args[0]].args
-                assert param_thunder_module is thunder_module_proxy
-                if name not in sharded_params and name in device_adjutments:
-                    a0, shape, _, *a2pp = bsym.args
-                    bsym.args = (a0, shape, thunder_device_str, *a2pp)
-
-        proxies_to_replace = {id(bsym.args[0]): bsym.output for bsym in new_scope}
-
-        new_computation_trace = thunder.core.trace.from_trace(computation_trace)
-        for idx, bsym in enumerate(computation_trace.bound_symbols):
-            if bsym.sym != thunder.core.prims.unpack_trivial:
-                break
-            new_computation_trace.bound_symbols.append(bsym.from_bsym())
-        new_computation_trace.bound_symbols += new_scope
-        for bsym in computation_trace.bound_symbols[idx:]:
-            new_args = tuple(proxies_to_replace.get(id(a), a) for a in bsym.args)
-            new_computation_trace.bound_symbols.append(bsym.from_bsym(args=new_args))
-
-        new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("fsdp pass"))
-
-        return prologue_trace, new_computation_trace, epilogue_trace
-
-    # add prologue + compute transform
-    thunder_model = thunder.core.transforms.add_transform(thunder_model, early_transform=prologue_and_compute_transform)
+    cd = get_compile_data(thunder_model)
+    # TODO: promote use_fsdp and use_ddp to public members of CompileData
+    cd.use_fsdp = True
+    orig_module: torch.nn.Module = cd.fn
+    utils.check(
+        isinstance(orig_module, torch.nn.Module) and not isinstance(orig_module, ThunderModule),
+        lambda: f"CompileData.fn expected to be `nn.Module` but {type(orig_module)}",
+    )
+    orig_module.use_fsdp = True
+    orig_module.process_group_for_ddp = process_group
+    orig_module.bucketing_strategy = bucketing_strategy
+    orig_module.sharding_strategy = sharding_strategy

     # modify module
     sharded_params = {}
@@ -485,7 +416,7 @@ def prologue_and_compute_transform(prologue_trace, computation_trace, epilogue_t
         # Each module only initializes its own parameters and not those of its children (recurse=False)
         if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))):
             # TODO: we could also support calling a "param_init_fn" argument like PyTorch
-            thunder.distributed._materialize(module_copy, device)
+            _materialize(module_copy, device)
             for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
                 thunder_model._overrides[n] = p
                 device_adjustments[n] = device
@@ -518,10 +449,17 @@ def prologue_and_compute_transform(prologue_trace, computation_trace, epilogue_t
             thunder_model._overrides[pn] = copy.copy(p)
         # we collect shapes and devices because we do not know if other transforms also change it...
         old_shape = thunder_model._overrides[pn].shape
-        thunder.distributed._shard_param(thunder_model._overrides[pn], global_rank, world_size, pn)
+        _shard_param(thunder_model._overrides[pn], global_rank, world_size, pn)
         new_shape = thunder_model._overrides[pn].shape
         sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides[pn].device)

+    early_transform_from_trace_to_fsdp_trace = FSDPTraceTransform(
+        sharded_params=sharded_params,
+        process_group=process_group,
+    )
+    # add prologue + compute transform
+    thunder_model = add_transform(thunder_model, early_transform=early_transform_from_trace_to_fsdp_trace)
+
     return thunder_model
diff --git a/thunder/distributed/transforms/fsdp.py b/thunder/distributed/transforms/fsdp.py
index 988068c7e6..6f9b18576d 100644
--- a/thunder/distributed/transforms/fsdp.py
+++ b/thunder/distributed/transforms/fsdp.py
@@ -504,7 +504,7 @@ def _collect_sharded_parameters(self, fwd_trace: TraceCtx) -> list[TensorProxy]:
         fwd_trace_flat_args, _ = tree_flatten((fwd_trace.args, fwd_trace.kwargs))
         return fwd_trace_flat_args

-    def apply_bucketing_to_forward_trace(self, fwd_trace: TraceCtx, bwd_trace_names: set[str]) -> TraceCtx:
+    def apply_bucketing_to_forward_trace(self, fwd_trace: TraceCtx) -> TraceCtx:
         """Optimize collective comms in fsdp with bucketing.

         This function is no-op if you pass :obj:`BucketingStrategy.NONE` as kwarg of ``sharding_strategy``
         to :func:`thunder.distributed.fsdp`.
diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py
new file mode 100644
index 0000000000..40d7b5280c
--- /dev/null
+++ b/thunder/distributed/transforms/fsdp_v2.py
@@ -0,0 +1,105 @@
+"""Early transform for `fsdp(jit(model))` to convert a trace into fsdp."""
+
+from __future__ import annotations
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from thunder.core import devices
+from thunder.core import prims
+from thunder.core import utils
+from thunder.core.proxies import DDPType
+from thunder.core.trace import from_trace
+from thunder.core.trace import tracectx
+from thunder.core.trace import TraceProvenance
+
+if TYPE_CHECKING:
+    from typing import Any
+    from torch.distributed import ProcessGroup
+
+
+__all__ = [
+    "FSDPTraceTransform",
+]
+
+
+@dataclass(frozen=True)
+class FSDPTraceTransform:
+    sharded_params: dict[str, Any]
+    process_group: ProcessGroup
+
+    def __call__(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
+        from thunder.distributed import prims as dist_prims
+
+        prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace)
+
+        modules_and_thunder_modules = [
+            (bsym.args[0], bsym.output)
+            for bsym in prologue_trace.bound_symbols
+            if bsym.sym is prims.unpack_thunder_module
+        ]
+
+        if len(modules_and_thunder_modules) != 1:
+            raise NotImplementedError("cannot deal with modules other than the compiled module")
+
+        ((orig_module_proxy, thunder_module_proxy),) = modules_and_thunder_modules
+        if prologue_producers[orig_module_proxy].sym is not prims.unpack_function_obj:
+            raise NotImplementedError("original module does not match the compiled module")
+
+        computation_trace.push_scope([])
+
+        synchronized_parameters = []
+        # todo: deal with epilogue output
+        for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args):
+            bsym = prologue_producers[pro_out_p]
+            if bsym.sym == prims.unpack_parameter:
+                param_thunder_module, param_name = bsym.args
+                assert param_thunder_module is thunder_module_proxy
+                if param_name in self.sharded_params:
+                    old_shape, new_shape, new_torch_device = self.sharded_params[param_name]
+                    thunder_device = devices.to_device(new_torch_device)
+                    thunder_device_str = str(thunder_device)
+
+                    pro_out_p._ddp_type = DDPType.FULLY_SHARDED
+                    pro_out_p._shape = tuple(new_shape)
+                    pro_out_p._device = thunder_device
+                    if comp_inp_p is not pro_out_p:
+                        comp_inp_p._ddp_type = DDPType.FULLY_SHARDED
+                        comp_inp_p._shape = tuple(new_shape)
+                        comp_inp_p._device = thunder_device
+                    with tracectx(computation_trace):
+                        synchronized_parameters.append(dist_prims.synchronize(comp_inp_p, self.process_group))
+
+                    for c in prologue_consumers[pro_out_p]:
+                        if c.sym is prims.check_tensor_shape_and_metadata:
+                            # TODO have a more principled way to update this?
+                            a0, _, _, *a2pp = c.args
+                            c.args = (a0, tuple(new_shape), thunder_device_str, *a2pp)
+
+        new_scope = computation_trace.pop_scope()
+
+        for bsym in prologue_trace.bound_symbols:
+            if bsym.sym is prims.check_tensor_shape_and_metadata and prologue_producers[bsym.args[0]].sym in (
+                prims.unpack_parameter,
+                prims.unpack_buffer,
+            ):
+                param_thunder_module, name = prologue_producers[bsym.args[0]].args
+                assert param_thunder_module is thunder_module_proxy
+                if name not in self.sharded_params:
+                    a0, shape, _, *a2pp = bsym.args
+                    bsym.args = (a0, shape, thunder_device_str, *a2pp)
+
+        proxies_to_replace = {id(bsym.args[0]): bsym.output for bsym in new_scope}
+
+        new_computation_trace = from_trace(computation_trace)
+        for idx, bsym in enumerate(computation_trace.bound_symbols):
+            if bsym.sym != prims.unpack_trivial:
+                break
+            new_computation_trace.bound_symbols.append(bsym.from_bsym())
+        new_computation_trace.bound_symbols += new_scope
+        for bsym in computation_trace.bound_symbols[idx:]:
+            new_args = tuple(proxies_to_replace.get(id(a), a) for a in bsym.args)
+            new_computation_trace.bound_symbols.append(bsym.from_bsym(args=new_args))
+
+        new_computation_trace.set_provenance(TraceProvenance("fsdp pass"))
+
+        return prologue_trace, new_computation_trace, epilogue_trace
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 9734e9e946..29be69a66d 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -158,7 +158,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     _fsdp_comm_bucketing: FSDPCommBucketing | None = None
     if getattr(compile_data.fn, "use_fsdp", False):
         _fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc)
-        fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace, bw_trace.names)
+        fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace)

     # Now we can run the optimization passes on the forward trace
     # TODO Restore request for no rematerialization
diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index f61edcba4a..e68e2c24ee 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -730,7 +730,7 @@ def test_ddp_grad_parity_with_without_bucketing(self, executor):

     # TODO(crcrpar): Add torch compile to executors_list
     @common_utils.parametrize(
-        "executor,bucketing_strategy,fsdptype",
+        "executor,bucketing_strategy,fsdptype,apply_fsdp_first",
         product(
             tuple(executors_map.keys()),
             (
@@ -738,9 +738,10 @@ def test_ddp_grad_parity_with_without_bucketing(self, executor):
                 FSDPBucketingStrategy.BLOCK,
             ),
             (FSDPType.ZERO2, FSDPType.ZERO3),
+            (True, False),
         ),
-        name_fn=lambda executor, bucketing_strategy, fsdptype: (
-            f"executor_{executor}_bucketing_{str(bucketing_strategy).split('.')[1].lower()}_{(str(fsdptype).lower().split('.')[1])}"
+        name_fn=lambda executor, bucketing_strategy, fsdptype, apply_fsdp_first: (
+            f"executor_{executor}_bucketing_{str(bucketing_strategy).split('.')[1].lower()}_{(str(fsdptype).lower().split('.')[1])}_{'jit_fsdp' if apply_fsdp_first else 'fsdp_jit'}"
         ),
     )
     def test_fsdp_grad_parity_with_without_bucketing(
@@ -748,6 +749,7 @@ def test_fsdp_grad_parity_with_without_bucketing(
        self,
         executor,
         bucketing_strategy: FSDPBucketingStrategy,
         fsdptype: FSDPType,
+        apply_fsdp_first: bool,
     ):
         from thunder.distributed import fsdp

@@ -757,10 +759,18 @@ def test_fsdp_grad_parity_with_without_bucketing(
         for strategy in (FSDPBucketingStrategy.NONE, bucketing_strategy):
             m = ToyModel()
             m.load_state_dict(initial_model_state)
-            cm = thunder.jit(
-                fsdp(m, device=device, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype),
-                executors=executors_map[executor].executors_list(),
-            )
+            if apply_fsdp_first:
+                cm = thunder.jit(
+                    fsdp(m, device=device, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype),
+                    executors=executors_map[executor].executors_list(),
+                )
+            else:
+                cm = fsdp(
+                    thunder.jit(m.to(device), executors=executors_map[executor].executors_list()),
+                    device=device,
+                    bucketing_strategy=bucketing_strategy,
+                    sharding_strategy=fsdptype,
+                )
             x = torch.ones((2, 12), device=device)
             loss = cm(x).mean()
             loss.backward()
@@ -783,6 +793,7 @@ def test_fsdp_grad_parity_with_without_bucketing(
                 ex_trace.bound_symbols,
             )
         )
+        self.assertGreater(len(pack_bsyms), 0)
         has_pack_multiple_tensors = False
         for bsym in pack_bsyms:
             first_arg = bsym.args[0]