[fsdp] Coping with params whose shape[0] is not divisible by `world_size` by padding (#415)

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar authored May 16, 2024
1 parent 599243c commit 2fa97d6
Showing 5 changed files with 133 additions and 25 deletions.
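
The gist of the change: when a parameter's first dimension is not divisible by the world size, the parameter is padded up to the next multiple before sharding, and the padding size is recorded on the proxy so the forward pass can slice it off again after the all-gather. A rough sketch of the arithmetic only (the helper below is illustrative, not part of the diff):

def padded_chunk_size(dim0: int, world_size: int) -> tuple[int, int]:
    # ceil-divide dim0 by world_size, then report how many rows of padding that implies
    chunk_size = (dim0 + world_size - 1) // world_size
    padding = chunk_size * world_size - dim0
    return chunk_size, padding

# e.g. a (13, 4) weight sharded across 2 ranks: each rank holds a (7, 4) shard, 1 row is padding
assert padded_chunk_size(13, 2) == (7, 1)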
7 changes: 6 additions & 1 deletion thunder/core/jit_ext.py
@@ -570,7 +570,12 @@ def proxify(self, value: WrappedValue) -> Any:
value.provenance.ext_flag |= EXT_FLAG_IS_TENSOR_PROXY

if isinstance(p, TensorProxy) and p.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED):
p_new = thunder.distributed.prims.synchronize(p, self._process_group_for_ddp)
p_new = thunder.distributed.prims.synchronize(
p,
self._process_group_for_ddp,
)
if isinstance(p.thunder_fsdp_padding_size, int):
p_new = p_new[: (p_new.shape[0] - p.thunder_fsdp_padding_size)]
p_orig = p
p = p_new
else:
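In isolation, the unpadding added here is a slice back to the original length once the parameter has been gathered; a minimal sketch with plain tensors (the shapes are illustrative):

import torch

chunk_size, world_size, padding = 7, 2, 1
gathered = torch.randn(chunk_size * world_size, 4)   # (14, 4), as produced by the all-gather
unpadded = gathered[: gathered.shape[0] - padding]    # (13, 4), the original parameter again
assert unpadded.shape[0] == 13
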
29 changes: 26 additions & 3 deletions thunder/core/proxies.py
@@ -1005,12 +1005,14 @@ def _infer_tensor_properties(
dtype: dtypes.dtype | None = None,
requires_grad: bool | None = None,
ddp_type: DDPType | None = None,
thunder_fsdp_padding_size: int | None = None,
):
_shape = None
_device = None
_dtype = None
_requires_grad: None | bool = None
_ddp_type = DDPType.NONE
_thunder_fsdp_padding_size = None

if like is not None:
baseutils.check_type(like, (TensorProxy, FutureTensorProxy))
@@ -1030,6 +1032,9 @@
_requires_grad = requires_grad if requires_grad is not None else _requires_grad
_requires_grad = False if not dtypes.is_inexact_dtype(_dtype) else _requires_grad
_ddp_type = ddp_type if ddp_type is not None else _ddp_type
_thunder_fsdp_padding_size = (
thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size
)

# Extracts actual values for shape
# TODO RC1 Enable this
@@ -1051,13 +1056,22 @@
baseutils.check_type(_dtype, dtypes.dtype)
baseutils.check_type(_requires_grad, bool)
baseutils.check_type(_ddp_type, DDPType)
if isinstance(_thunder_fsdp_padding_size, int):
baseutils.check(
_ddp_type == DDPType.FULLY_SHARDED,
lambda: f"{_ddp_type = } and {_thunder_fsdp_padding_size = } do not work",
)
baseutils.check(
_thunder_fsdp_padding_size > 0,
lambda: f"{_thunder_fsdp_padding_size=} expected to be > 0 or `None`",
)

# NOTE for simplicity functions that want to reason about weak dtypes should explicitly request
# the true_dtype property
_true_dtype = _dtype
_dtype = dtypes.to_strong_dtype(_dtype)

return _shape, _device, _dtype, _true_dtype, _numel, _ndim, _requires_grad, _ddp_type
return _shape, _device, _dtype, _true_dtype, _numel, _ndim, _requires_grad, _ddp_type, _thunder_fsdp_padding_size


# NOTE A FutureTensorProxy is intentionally NOT a subclass of TensorProxy
@@ -1084,7 +1098,8 @@ def __init__(
self._numel,
self._ndim,
self._requires_grad,
_,
_, # ddp_type
_, # thunder_fsdp_padding_size
) = _infer_tensor_properties(
like,
shape,
@@ -1152,6 +1167,7 @@ def __init__(
prefix: None | str = None,
ddp_type: DDPType | None = None,
history: None | tuple = None,
thunder_fsdp_padding_size: int | None = None,
):
super().__init__(name, prefix=prefix, history=history)

@@ -1164,7 +1180,8 @@ def __init__(
self._ndim,
self._requires_grad,
self._ddp_type,
) = _infer_tensor_properties(like, shape, device, dtype, requires_grad, ddp_type)
self._thunder_fsdp_padding_size,
) = _infer_tensor_properties(like, shape, device, dtype, requires_grad, ddp_type, thunder_fsdp_padding_size)

# NOTE The following properties DO NOT depend on the language context or record
# themselves into the trace, so they can be used when working with tensor proxies
@@ -1197,6 +1214,10 @@ def requires_grad(self):
def ddp_type(self):
return self._ddp_type

@property
def thunder_fsdp_padding_size(self):
return self._thunder_fsdp_padding_size

# We need to implement `__len__` as
# > In addition to bypassing any instance attributes in the
# > interest of correctness, implicit special method lookup
@@ -1498,6 +1519,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
dtype = dtypes.to_dtype(t.dtype)
# See Note [DistributedDataParallel and ddp_type]
ddp_type = getattr(t, "ddp_type", None)
_thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
return TensorProxy(
name,
@@ -1507,6 +1529,7 @@
requires_grad=t.requires_grad,
ddp_type=ddp_type,
history=history,
thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
)


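The contract the proxy changes enforce is small: `thunder_fsdp_padding_size` is either `None` or a positive `int`, and an `int` is only meaningful for fully sharded parameters. A standalone sketch of that check (the `DDPType` stand-in and the function are hypothetical, mirroring the checks in `_infer_tensor_properties`):

from enum import Enum, auto

class DDPType(Enum):  # stand-in for thunder.core.proxies.DDPType
    NONE = auto()
    REPLICATED = auto()
    FULLY_SHARDED = auto()

def validate_fsdp_padding(ddp_type: DDPType, padding: int | None) -> None:
    if padding is None:
        return
    if ddp_type is not DDPType.FULLY_SHARDED:
        raise RuntimeError(f"{ddp_type=} and {padding=} do not work together")
    if padding <= 0:
        raise RuntimeError(f"{padding=} expected to be > 0 or None")

validate_fsdp_padding(DDPType.FULLY_SHARDED, 1)  # ok
validate_fsdp_padding(DDPType.REPLICATED, None)  # ok, padding is simply absent
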
57 changes: 40 additions & 17 deletions thunder/distributed/__init__.py
@@ -449,7 +449,7 @@ def fsdp_transform_module(
thunder_model._overrides[pn] = copy.copy(p)
# we collect shapes and devices because we do not know if other transforms also change it...
old_shape = thunder_model._overrides[pn].shape
_shard_param(thunder_model._overrides[pn], global_rank, world_size, pn)
_shard_param(thunder_model._overrides[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True)
new_shape = thunder_model._overrides[pn].shape
sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides[pn].device)

Expand Down Expand Up @@ -529,7 +529,7 @@ def fsdp(
model.bucketing_strategy = bucketing_strategy

# Shard the parameters
_shard_params(model, process_group, device, broadcast_from)
_shard_params(model, process_group, device, broadcast_from, allow_padding_for_fsdp=True)

# See Note [DistributedDataParallel and ddp_type]
# If model was wrapped with thunder.distributed.fsdp it would have a
@@ -545,7 +545,11 @@

@torch.no_grad()
def _shard_params(
module: torch.nn.Module, process_group: ProcessGroup, device: torch.device | None, broadcast_from: int | None
module: torch.nn.Module,
process_group: ProcessGroup,
device: torch.device | None,
broadcast_from: int | None,
allow_padding_for_fsdp: bool = False,
) -> None:
"""Shards the parameters on the first dimension."""
global_rank = tdist.get_rank(group=process_group)
@@ -576,22 +580,41 @@ def _shard_params(
# Note [FSDP Sharding]
# All internal code will assume that the parameters are sharded on the first dimension
for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name):
_shard_param(param, global_rank, world_size, param_name)
_shard_param(param, global_rank, world_size, param_name, allow_padding_for_fsdp=allow_padding_for_fsdp)


def _shard_param(param: torch.Tensor, rank: int, world_size: int, name: str) -> None:
utils.check(
param.shape[0] % world_size == 0,
lambda: (
f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
f" to be divisible by the world size ({world_size})"
),
)
chunk_size = param.shape[0] // world_size
# NOTE This could be a ShardTensor to indicate other parts of the code
# that it's sharded and should be treated differently
shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param.data = shard
def _shard_param(
param: torch.Tensor,
rank: int,
world_size: int,
name: str,
allow_padding_for_fsdp: bool = False,
) -> None:

if not allow_padding_for_fsdp or (param.size(0) % world_size == 0):
if not allow_padding_for_fsdp:
utils.check(
param.shape[0] % world_size == 0,
lambda: (
f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
f" to be divisible by the world size ({world_size})"
),
)
chunk_size = param.shape[0] // world_size
# NOTE This could be a ShardTensor to indicate other parts of the code
# that it's sharded and should be treated differently
shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param.data = shard
else:
padded_param_shape = list(param.shape)
orig_0dim_size = param.size(0)
chunk_size = (padded_param_shape[0] + world_size - 1) // world_size
padded_param_shape[0] = chunk_size * world_size
_thunder_fsdp_padding_size = padded_param_shape[0] - param.size(0)
padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
padded_param[:orig_0dim_size].copy_(param)
param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size


@torch.no_grad()
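The padded branch of `_shard_param` can be reproduced with plain PyTorch outside of Thunder; a sketch under the assumption that rank and world size are known (the function name is illustrative):

import torch

def shard_with_padding(param: torch.Tensor, rank: int, world_size: int) -> tuple[torch.Tensor, int]:
    # pad dim 0 up to the next multiple of world_size, then take this rank's chunk
    chunk_size = (param.size(0) + world_size - 1) // world_size
    padding = chunk_size * world_size - param.size(0)
    padded = torch.empty((chunk_size * world_size, *param.shape[1:]), device=param.device, dtype=param.dtype)
    padded[: param.size(0)].copy_(param)
    shard = padded.narrow(0, chunk_size * rank, chunk_size).clone()
    return shard, padding

weight = torch.arange(13 * 4, dtype=torch.float).view(13, 4)
shard, padding = shard_with_padding(weight, rank=1, world_size=2)
assert shard.shape == (7, 4) and padding == 1

Rank 1's last row is the uninitialized padding row, which is why the padding size is stored on the parameter so it can be dropped again after gathering.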
15 changes: 11 additions & 4 deletions thunder/distributed/prims.py
@@ -148,7 +148,11 @@ def wait_meta(a: FutureTensorProxy, /) -> TensorProxy:
return TensorProxy(like=a)


def synchronize_meta(a: TensorProxy, /, group: torch.distributed.ProcessGroup) -> TensorProxy:
def synchronize_meta(
a: TensorProxy,
/,
group: torch.distributed.ProcessGroup,
) -> TensorProxy:
utils.check_type(a, TensorProxy)
utils.check_type(group, torch.distributed.ProcessGroup)

@@ -286,7 +290,8 @@ def stash_grad_for_fsdp_meta(

@register_augmented_forward(PrimIDs.SYNCHRONIZE)
def synchronize_augmented_forward_rule(
a: TensorProxy, group: torch.distributed.ProcessGroup
a: TensorProxy,
group: torch.distributed.ProcessGroup,
) -> tuple[TensorProxy, tuple]:
match a.ddp_type:
case DDPType.REPLICATED:
@@ -302,7 +307,7 @@ def synchronize_augmented_forward_rule(
# immediately called on the result with the hope that the execution
# passes would reorder the wait operation to be closer to the actual
# usage of the tensor.
return all_gather(a, group, do_async=True).wait(), (
return all_gather(a, group, True).wait(), (
a.ddp_type,
group,
)
@@ -312,7 +317,9 @@

@register_backward(PrimIDs.SYNCHRONIZE)
def synchronize_backward_rule(
ddp_type: DDPType, group: torch.distributed.ProcessGroup, grad: TensorProxy
ddp_type: DDPType,
group: torch.distributed.ProcessGroup,
grad: TensorProxy,
) -> tuple[TensorProxy, None]:
preaverage_grad = grad / group.size()
match ddp_type:
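The forward-side slice has an autograd counterpart: the backward of that slice zero-pads the gradient back to the padded shape, which is the PAD the new test asserts on. A tiny sketch of that symmetry with plain autograd (illustrative only):

import torch

padding = 1
gathered = torch.randn(14, 4, requires_grad=True)    # padded, gathered parameter
used = gathered[: gathered.shape[0] - padding]        # forward: drop the padding row
used.sum().backward()
# the gradient comes back at the padded shape, with zeros in the padding rows
assert gathered.grad.shape == (14, 4)
assert torch.all(gathered.grad[-padding:] == 0)
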
50 changes: 50 additions & 0 deletions thunder/tests/distributed/test_ddp.py
@@ -804,6 +804,55 @@ def test_fsdp_grad_parity_with_without_bucketing(
if bucketing_strategy == FSDPBucketingStrategy.LAYER:
self.assertTrue(has_pack_multiple_tensors, msg=f"{[bsym.args[0] for bsym in pack_bsyms]=}")

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
@common_utils.parametrize(
"bucketing_strategy,fsdptype",
product(
(
FSDPBucketingStrategy.NONE,
FSDPBucketingStrategy.BLOCK,
),
(FSDPType.ZERO2, FSDPType.ZERO3),
),
name_fn=lambda bucketing_strategy, fsdptype: (
f"bucketing_{str(bucketing_strategy).split('.')[1].lower()}_{(str(fsdptype).lower().split('.')[1])}"
),
)
def test_fsdp_with_padding(
self,
bucketing_strategy: FSDPBucketingStrategy,
fsdptype: FSDPType,
):

from thunder.core.prims import PrimIDs
from thunder.executors.torchex import pad_prim_impl
from thunder.executors.torchex import slice_prim_impl

class M(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(4, 13)
self.l2 = nn.Linear(13, 1)

def forward(self, x):
return self.l2(new_gelu(self.l1(x)))

device = torch.device(f"cuda:{self.rank}")
m = M().to(device)
jitted = thunder.jit(fsdp(m, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype))

x = torch.randn(4, 4, device=device)
y = jitted(x)
y.mean().backward()

fw_extrace = thunder.last_traces(jitted)[-1]
fw_symids = [bsym.sym.id for bsym in fw_extrace.bound_symbols]
self.assertTrue(any(sym_id in {PrimIDs.SLICE, slice_prim_impl.id} for sym_id in fw_symids))

bw_trace = thunder.last_backward_traces(jitted)[0]
bw_symids = [bsym.sym.id for bsym in bw_trace.bound_symbols]
self.assertTrue(any(sym_id in {PrimIDs.PAD, pad_prim_impl.id} for sym_id in bw_symids))

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
def test_fsdp_shard_unshard(self):
from thunder.distributed import _shard_params, _unshard_params
@@ -814,6 +863,7 @@ def test_fsdp_shard_unshard(self):
model = torch.nn.Linear(3, 5, bias=False, device="meta")
with pytest.raises(RuntimeError, match=r"parameter 'weight' \(5\) to be divisible by the world size \(2\)"):
_shard_params(model, pg, device, None)
_shard_params(model, pg, device, None, allow_padding_for_fsdp=True)

model = torch.nn.Linear(3, 4, bias=False, device="meta")
weight = torch.arange(3 * 4, device="cpu", dtype=torch.float).view(4, 3)
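As a usage-level sketch of what the new test covers: a linear layer with 13 output features sharded across 2 ranks now works without manual padding. The model and sizes below are illustrative, and `torch.distributed` is assumed to already be initialized with a world size of 2:

import torch
import torch.nn as nn
import thunder
from thunder.distributed import fsdp

model = nn.Sequential(nn.Linear(4, 13), nn.GELU(), nn.Linear(13, 1)).to("cuda")
jitted = thunder.jit(fsdp(model))              # shape[0] == 13 is padded to 14 internally
out = jitted(torch.randn(4, 4, device="cuda"))
out.mean().backward()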
