From 2fa97d6b1fead1eb9d0dc17703ebb12bbf752bbc Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Fri, 17 May 2024 01:43:17 +0900
Subject: [PATCH] [fsdp] Coping with params whose `shape[0]` is not divisible
 by `world_size` by padding (#415)

Signed-off-by: Masaki Kozuki
---
 thunder/core/jit_ext.py               |  7 +++-
 thunder/core/proxies.py               | 29 ++++++++++++--
 thunder/distributed/__init__.py       | 57 +++++++++++++++++++--------
 thunder/distributed/prims.py          | 15 +++++--
 thunder/tests/distributed/test_ddp.py | 50 +++++++++++++++++++++++
 5 files changed, 133 insertions(+), 25 deletions(-)

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index ba022f6968..9ed0ea8d1b 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -570,7 +570,12 @@ def proxify(self, value: WrappedValue) -> Any:
             value.provenance.ext_flag |= EXT_FLAG_IS_TENSOR_PROXY
 
             if isinstance(p, TensorProxy) and p.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED):
-                p_new = thunder.distributed.prims.synchronize(p, self._process_group_for_ddp)
+                p_new = thunder.distributed.prims.synchronize(
+                    p,
+                    self._process_group_for_ddp,
+                )
+                if isinstance(p.thunder_fsdp_padding_size, int):
+                    p_new = p_new[: (p_new.shape[0] - p.thunder_fsdp_padding_size)]
                 p_orig = p
                 p = p_new
             else:
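For context, here is a minimal standalone sketch (not part of the patch; plain tensors stand in for the all-gathered parameter, and the sizes are assumed) of the new slicing step in proxify: when a padding size is recorded on the proxy, the gathered tensor is trimmed back to the parameter's original first dimension.

import torch

param = torch.randn(13, 4)                 # original, unsharded parameter
thunder_fsdp_padding_size = 1              # one padded row so 14 rows split evenly across 2 ranks

# Stand-in for the output of prims.synchronize (all_gather + wait) on the padded shards.
gathered = torch.cat([param, param.new_zeros(thunder_fsdp_padding_size, 4)])

# The slice added above in proxify():
p_new = gathered[: gathered.shape[0] - thunder_fsdp_padding_size]
assert torch.equal(p_new, param)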
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index 3be3735b6a..4ca0924aed 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1005,12 +1005,14 @@ def _infer_tensor_properties(
     dtype: dtypes.dtype | None = None,
     requires_grad: bool | None = None,
     ddp_type: DDPType | None = None,
+    thunder_fsdp_padding_size: int | None = None,
 ):
     _shape = None
     _device = None
     _dtype = None
     _requires_grad: None | bool = None
     _ddp_type = DDPType.NONE
+    _thunder_fsdp_padding_size = None
 
     if like is not None:
         baseutils.check_type(like, (TensorProxy, FutureTensorProxy))
@@ -1030,6 +1032,9 @@
     _requires_grad = requires_grad if requires_grad is not None else _requires_grad
     _requires_grad = False if not dtypes.is_inexact_dtype(_dtype) else _requires_grad
     _ddp_type = ddp_type if ddp_type is not None else _ddp_type
+    _thunder_fsdp_padding_size = (
+        thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size
+    )
 
     # Extracts actual values for shape
     # TODO RC1 Enable this
@@ -1051,13 +1056,22 @@
     baseutils.check_type(_dtype, dtypes.dtype)
     baseutils.check_type(_requires_grad, bool)
     baseutils.check_type(_ddp_type, DDPType)
+    if isinstance(_thunder_fsdp_padding_size, int):
+        baseutils.check(
+            _ddp_type == DDPType.FULLY_SHARDED,
+            lambda: f"{_ddp_type = } and {_thunder_fsdp_padding_size = } do not work",
+        )
+        baseutils.check(
+            _thunder_fsdp_padding_size > 0,
+            lambda: f"{_thunder_fsdp_padding_size=} expected to be > 0 or `None`",
+        )
 
     # NOTE for simplicity functions that want to reason about weak dtypes should explicitly request
     # the true_dtype property
     _true_dtype = _dtype
     _dtype = dtypes.to_strong_dtype(_dtype)
 
-    return _shape, _device, _dtype, _true_dtype, _numel, _ndim, _requires_grad, _ddp_type
+    return _shape, _device, _dtype, _true_dtype, _numel, _ndim, _requires_grad, _ddp_type, _thunder_fsdp_padding_size
 
 
 # NOTE A FutureTensorProxy is intentionally NOT a subclass of TensorProxy
@@ -1084,7 +1098,8 @@ def __init__(
             self._numel,
             self._ndim,
             self._requires_grad,
-            _,
+            _,  # ddp_type
+            _,  # thunder_fsdp_padding_size
         ) = _infer_tensor_properties(
             like,
             shape,
@@ -1152,6 +1167,7 @@ def __init__(
         prefix: None | str = None,
         ddp_type: DDPType | None = None,
         history: None | tuple = None,
+        thunder_fsdp_padding_size: int | None = None,
     ):
         super().__init__(name, prefix=prefix, history=history)
 
@@ -1164,7 +1180,8 @@ def __init__(
             self._ndim,
             self._requires_grad,
             self._ddp_type,
-        ) = _infer_tensor_properties(like, shape, device, dtype, requires_grad, ddp_type)
+            self._thunder_fsdp_padding_size,
+        ) = _infer_tensor_properties(like, shape, device, dtype, requires_grad, ddp_type, thunder_fsdp_padding_size)
 
         # NOTE The following properties DO NOT depend on the language context or record
         # themselves into the trace, so they can be used when working with tensor proxies
@@ -1197,6 +1214,10 @@ def requires_grad(self):
     def ddp_type(self):
         return self._ddp_type
 
+    @property
+    def thunder_fsdp_padding_size(self):
+        return self._thunder_fsdp_padding_size
+
     # We need to implement `__len__` as
     # > In addition to bypassing any instance attributes in the
     # > interest of correctness, implicit special method lookup
@@ -1498,6 +1519,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
     dtype = dtypes.to_dtype(t.dtype)
     # See Note [DistributedDataParallel and ddp_type]
     ddp_type = getattr(t, "ddp_type", None)
+    _thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
     # NOTE Without tuple(t.shape) then the shape would be a torch.Size object
     return TensorProxy(
         name,
@@ -1507,6 +1529,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
         requires_grad=t.requires_grad,
         ddp_type=ddp_type,
         history=history,
+        thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
     )
 
 
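The proxies.py changes above only thread the new metadata through: _shard_param (next file) records the padding on the sharded torch.Tensor, and tensorproxy() picks it up with getattr, defaulting to None. A tiny sketch of that hand-off with assumed values:

import torch

shard = torch.nn.Parameter(torch.empty(7, 4))   # one rank's shard of a (14, 4) padded weight
shard._thunder_fsdp_padding_size = 1            # stamped by _shard_param (assumed value)

# Mirrors the lookup added to tensorproxy():
assert getattr(shard, "_thunder_fsdp_padding_size", None) == 1

plain = torch.empty(2, 4)                        # ordinary tensors carry no such attribute
assert getattr(plain, "_thunder_fsdp_padding_size", None) is None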
diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index 34260ec0d5..70929b5070 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -449,7 +449,7 @@ def fsdp_transform_module(
             thunder_model._overrides[pn] = copy.copy(p)
         # we collect shapes and devices because we do not know if other transforms also change it...
         old_shape = thunder_model._overrides[pn].shape
-        _shard_param(thunder_model._overrides[pn], global_rank, world_size, pn)
+        _shard_param(thunder_model._overrides[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True)
         new_shape = thunder_model._overrides[pn].shape
         sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides[pn].device)
 
@@ -529,7 +529,7 @@ def fsdp(
     model.bucketing_strategy = bucketing_strategy
 
     # Shard the parameters
-    _shard_params(model, process_group, device, broadcast_from)
+    _shard_params(model, process_group, device, broadcast_from, allow_padding_for_fsdp=True)
 
     # See Note [DistributedDataParallel and ddp_type]
     # If model was wrapped with thunder.distributed.fsdp it would have a
@@ -545,7 +545,11 @@ def fsdp(
 
 @torch.no_grad()
 def _shard_params(
-    module: torch.nn.Module, process_group: ProcessGroup, device: torch.device | None, broadcast_from: int | None
+    module: torch.nn.Module,
+    process_group: ProcessGroup,
+    device: torch.device | None,
+    broadcast_from: int | None,
+    allow_padding_for_fsdp: bool = False,
 ) -> None:
     """Shards the parameters on the first dimension."""
     global_rank = tdist.get_rank(group=process_group)
@@ -576,22 +580,41 @@ def _shard_params(
         # Note [FSDP Sharding]
         # All internal code will assume that the parameters are sharded on the first dimension
         for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name):
-            _shard_param(param, global_rank, world_size, param_name)
+            _shard_param(param, global_rank, world_size, param_name, allow_padding_for_fsdp=allow_padding_for_fsdp)
 
 
-def _shard_param(param: torch.Tensor, rank: int, world_size: int, name: str) -> None:
-    utils.check(
-        param.shape[0] % world_size == 0,
-        lambda: (
-            f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
-            f" to be divisible by the world size ({world_size})"
-        ),
-    )
-    chunk_size = param.shape[0] // world_size
-    # NOTE This could be a ShardTensor to indicate other parts of the code
-    # that it's sharded and should be treated differently
-    shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
-    param.data = shard
+def _shard_param(
+    param: torch.Tensor,
+    rank: int,
+    world_size: int,
+    name: str,
+    allow_padding_for_fsdp: bool = False,
+) -> None:
+
+    if not allow_padding_for_fsdp or (param.size(0) % world_size == 0):
+        if not allow_padding_for_fsdp:
+            utils.check(
+                param.shape[0] % world_size == 0,
+                lambda: (
+                    f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
+                    f" to be divisible by the world size ({world_size})"
+                ),
+            )
+        chunk_size = param.shape[0] // world_size
+        # NOTE This could be a ShardTensor to indicate other parts of the code
+        # that it's sharded and should be treated differently
+        shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
+        param.data = shard
+    else:
+        padded_param_shape = list(param.shape)
+        orig_0dim_size = param.size(0)
+        chunk_size = (padded_param_shape[0] + world_size - 1) // world_size
+        padded_param_shape[0] = chunk_size * world_size
+        _thunder_fsdp_padding_size = padded_param_shape[0] - param.size(0)
+        padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
+        padded_param[:orig_0dim_size].copy_(param)
+        param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
+        param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size
 
 
 @torch.no_grad()
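For reference, a worked example (assumed sizes: 13 rows, world_size 4; not part of the patch) of the padded branch added to _shard_param: ceil division picks a uniform chunk size, every rank receives a shard of the same shape, and the padding drops out again on unshard.

import torch

world_size = 4
param = torch.arange(13 * 3, dtype=torch.float).view(13, 3)    # first dim not divisible by 4

chunk_size = (param.size(0) + world_size - 1) // world_size    # ceil(13 / 4) == 4
padded_rows = chunk_size * world_size                          # 16
padding = padded_rows - param.size(0)                          # 3, recorded as _thunder_fsdp_padding_size

# The patch allocates with torch.empty and copies only the original rows,
# so the padded tail is uninitialized until it is sliced away again.
padded = torch.empty(padded_rows, param.size(1))
padded[: param.size(0)].copy_(param)

shards = [padded.narrow(0, chunk_size * r, chunk_size).clone() for r in range(world_size)]
assert all(s.shape == (chunk_size, 3) for s in shards)         # uniform (4, 3) shard per rank
assert torch.equal(torch.cat(shards)[: param.size(0)], param)  # padding disappears on unshard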
diff --git a/thunder/distributed/prims.py b/thunder/distributed/prims.py
index 15d5c08f27..c5067a3b12 100644
--- a/thunder/distributed/prims.py
+++ b/thunder/distributed/prims.py
@@ -148,7 +148,11 @@ def wait_meta(a: FutureTensorProxy, /) -> TensorProxy:
     return TensorProxy(like=a)
 
 
-def synchronize_meta(a: TensorProxy, /, group: torch.distributed.ProcessGroup) -> TensorProxy:
+def synchronize_meta(
+    a: TensorProxy,
+    /,
+    group: torch.distributed.ProcessGroup,
+) -> TensorProxy:
     utils.check_type(a, TensorProxy)
     utils.check_type(group, torch.distributed.ProcessGroup)
 
@@ -286,7 +290,8 @@ def stash_grad_for_fsdp_meta(
 
 @register_augmented_forward(PrimIDs.SYNCHRONIZE)
 def synchronize_augmented_forward_rule(
-    a: TensorProxy, group: torch.distributed.ProcessGroup
+    a: TensorProxy,
+    group: torch.distributed.ProcessGroup,
 ) -> tuple[TensorProxy, tuple]:
     match a.ddp_type:
         case DDPType.REPLICATED:
@@ -302,7 +307,7 @@ def synchronize_augmented_forward_rule(
             # immediately called on the result with the hope that the execution
             # passes would reorder the wait operation to be closer to the actual
             # usage of the tensor.
-            return all_gather(a, group, do_async=True).wait(), (
+            return all_gather(a, group, True).wait(), (
                 a.ddp_type,
                 group,
             )
@@ -312,7 +317,9 @@ def synchronize_augmented_forward_rule(
 
 @register_backward(PrimIDs.SYNCHRONIZE)
 def synchronize_backward_rule(
-    ddp_type: DDPType, group: torch.distributed.ProcessGroup, grad: TensorProxy
+    ddp_type: DDPType,
+    group: torch.distributed.ProcessGroup,
+    grad: TensorProxy,
 ) -> tuple[TensorProxy, None]:
     preaverage_grad = grad / group.size()
     match ddp_type:
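A small autograd sketch (plain tensors, assumed sizes; not part of the patch) of the forward/backward pairing that the new test below asserts on: the forward trace contains the slice that trims the padding, and the gradient of that slice is a zero-pad back to the padded, all-gathered shape before synchronize's backward reduces it back into per-rank shard gradients.

import torch

padded = torch.randn(14, 4, requires_grad=True)   # stands in for the all-gathered, padded parameter
trimmed = padded[:13]                             # forward: slice off the single padded row

trimmed.sum().backward()
assert padded.grad.shape == (14, 4)               # backward: gradient is padded back to 14 rows
assert torch.all(padded.grad[13:] == 0)           # the padded row receives zero gradient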
reason="Requires 2 devices") def test_fsdp_shard_unshard(self): from thunder.distributed import _shard_params, _unshard_params @@ -814,6 +863,7 @@ def test_fsdp_shard_unshard(self): model = torch.nn.Linear(3, 5, bias=False, device="meta") with pytest.raises(RuntimeError, match=r"parameter 'weight' \(5\) to be divisible by the world size \(2\)"): _shard_params(model, pg, device, None) + _shard_params(model, pg, device, None, allow_padding_for_fsdp=True) model = torch.nn.Linear(3, 4, bias=False, device="meta") weight = torch.arange(3 * 4, device="cpu", dtype=torch.float).view(4, 3)