[CI] Cleanup Dist Optim tests with shared helper funcs #6125

Merged
merged 4 commits on Feb 12, 2025
6 changes: 3 additions & 3 deletions colossalai/shardformer/layer/linear.py
@@ -384,7 +384,7 @@ class Linear1D_Row(ParallelModule):
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
parallel_input (bool): If set to ``True``, the input is assumed to be already split/copied on each rank, defaults to False.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
@@ -544,14 +544,14 @@ def forward(self, input_: Tensor) -> Tensor:
if self.parallel_input:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
input_ = input_
else:
assert (
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
)
input_ = split_forward_gather_backward(
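
For concreteness, a small worked example of the shapes these assertions expect, under an assumed 2-way tensor-parallel split (the numbers are illustrative, not from the diff):

# Assumed setup: a row-parallel linear with in_features=8, out_features=4 on 2 ranks.
# Each rank then holds a local weight whose feature (last) dim is 8 / 2 = 4.
# With parallel_input=True the input arrives already split, so its feature dim must be 4;
# with parallel_input=False the full input has feature dim 8, and divide(8, num_partitions) == 4.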
4 changes: 2 additions & 2 deletions tests/kit/model_zoo/custom/simple_mlp.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@


class Net(nn.Module):
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32):
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=torch.float32):
super().__init__()
if identity:
self.fc0 = nn.Identity()
@@ -30,7 +30,7 @@ def forward(self, x):
class TPNet(nn.Module):
def __init__(
self,
fc0=nn.Linear(_IN_DIM, _IN_DIM),
fc0=nn.Identity(),
fc1=nn.Linear(_IN_DIM, _HID_DIM),
fc2=nn.Linear(_HID_DIM, _IN_DIM),
tp_group=None,
85 changes: 85 additions & 0 deletions tests/test_optimizer/_utils.py
@@ -1,10 +1,13 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
from colossalai.testing import parameterize, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
@@ -15,6 +18,88 @@
)


def force_assign_grad(p, g_dtype, grad=None):
"""Bypass inconsistent grad and param dtype error when assigning grad"""
orig_p = p.data
p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad is None else grad.clone().to(g_dtype)
p.grad = p.data
p.data = orig_p
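
A minimal usage sketch of force_assign_grad (the Linear layer is illustrative, not taken from the tests): directly assigning a grad whose dtype differs from the parameter raises a dtype-mismatch error in PyTorch, so the helper temporarily swaps p.data to the grad dtype before the assignment and restores it afterwards.

# Illustrative only: attach fp16 grads to an fp32 module, relying on the torch/nn imports above.
layer = nn.Linear(4, 4)
force_assign_grad(layer.weight, torch.float16) # random grad in the requested dtype
force_assign_grad(layer.bias, torch.float16)
assert layer.weight.grad.dtype == torch.float16 and layer.weight.dtype == torch.float32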


def setup_param_groups(model: nn.Module) -> list:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.1,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
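
For illustration, a sketch of how these groups plug into a stock optimizer (the Linear model and the AdamW choice are assumptions, not taken from the tests):

# Illustrative only: "bias" lands in the zero-weight-decay group, "weight" in the 0.1 group.
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(setup_param_groups(model), lr=1e-3)
assert len(optimizer.param_groups) == 2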


# setup flatten param groups, sharding spec and shape; (For dist Adafactor and CAME)
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> tuple:
flatten_optimizer_grouped_parameters = []
sharding_spec = {} # {id(flatten param): get_sharding_spec(p)}
param_shape = {} # {id(flatten param): get_layout(p).global_shape}
for n, p in model.named_parameters():
# flatten_p = copy.deepcopy(p).flatten()
flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
flatten_optimizer_grouped_parameters.append(flatten_p)
if is_distributed_tensor(p):
sharding_spec[id(flatten_p)] = get_sharding_spec(p)
param_shape[id(flatten_p)] = get_layout(p).global_shape
else:
sharding_spec[id(flatten_p)] = None
param_shape[id(flatten_p)] = p.shape
return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
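
A brief sketch of what the three return values look like for a plain, non-sharded module; in the distributed Adafactor / CAME tests the module would instead hold TP-sharded parameters:

# Illustrative only: for a non-distributed module every sharding spec is None
# and the recorded shape is simply the parameter's own (unflattened) shape.
flat_params, spec_map, shape_map = setup_flatten_param_groups_sharding_spec_shape(nn.Linear(4, 4))
assert all(spec_map[id(p)] is None for p in flat_params)
assert shape_map[id(flat_params[0])] == torch.Size([4, 4])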


def set_master_param_to_shard_param(master_param_list) -> dict:
master_param_to_shard_param = {id(p): p for p in master_param_list}
return master_param_to_shard_param


def set_dist_grad(
dist_module: nn.Module,
torch_model: nn.Module,
g_dtype: torch.dtype,
group: dist.ProcessGroup,
tp_spec: DimSpec,
) -> None:
"""
Set split grads for Tensor Parallel or ZeRO DP.
We do not need a separate treatment for ZeRO,
as the wrapper takes care of reduce-scattering grads.
"""
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)

for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
if torch_p.grad is None:
torch_p.grad = torch.zeros_like(torch_p)

is_distributed = hasattr(p, "dist_layout")
if is_distributed:
sharding = p.dist_layout.sharding_spec.sharding_sequence
split_dim = sharding.index(tp_spec)
shape = torch_p.split(world_size, dim=split_dim)[rank].shape

indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
# Generate grads only for the correctly split chunk
torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))

else:
shape = torch_p.shape
torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)

force_assign_grad(p, g_dtype, grad=torch_p.grad)
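
Purely to show the intended call shape (tp_model, torch_model and tp_group are assumed to come from a colossalai tensor-parallel setup, so this only makes sense under a distributed launcher such as colossalai.launch or torchrun):

# Illustrative only, kept as a comment because it needs an initialized process group:
# set_dist_grad(tp_model, torch_model, torch.float16, tp_group, DimSpec([0]))
# torch_model keeps full-shape grads; for each TP-sharded parameter the random values are
# generated only in this rank's slice, and a grad in g_dtype is then attached to the
# corresponding dist_module parameter via force_assign_grad.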


def check_optim_states(org_optim, sharded_optim):
for group in org_optim.param_groups:
for p in group["params"]:
21 changes: 2 additions & 19 deletions tests/test_optimizer/test_adam_optim.py
@@ -8,6 +8,7 @@

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import force_assign_grad, setup_param_groups

_ALLOWED_OPTIM_DEVICES = [
(FusedAdam, torch.device("cuda:0")),
@@ -26,29 +27,11 @@
N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.1,
},
{
"params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
torch_p.grad = torch.rand_like(torch_p)
# avoid inconsistent grad and param dtype error
orig_p = p.data
p.data = torch_p.grad.clone().to(g_dtype)
p.grad = p.data
p.data = orig_p
force_assign_grad(p, g_dtype, torch_p.grad)


@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
118 changes: 24 additions & 94 deletions tests/test_optimizer/test_dist_adafactor.py
@@ -1,5 +1,3 @@
import copy

import pytest
import torch
import torch.distributed as dist
@@ -16,7 +14,6 @@
from colossalai.tensor.d_tensor import (
distribute_tensor,
get_device_mesh,
get_layout,
get_sharding_spec,
is_distributed_tensor,
shard_colwise,
@@ -28,7 +25,13 @@
from colossalai.utils import set_seed
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states
from tests.test_optimizer._utils import (
check_dist_optim_state,
check_dist_param,
check_optim_states,
set_master_param_to_shard_param,
setup_param_groups,
)
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
build_model_from_low_level_zero_plugin,
@@ -38,10 +41,13 @@
unwrap_model,
)

HEIGHT = 4
WIDTH = 4
IN_DIM = 4
HID_DIM = 4
_TP_SPEC = DimSpec([0])

Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))


def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
rtol = None
@@ -59,92 +65,11 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)


# setup param groups; (For zero test optim)
def setup_param_groups_zero(model: nn.Module) -> list:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.1,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters


# setup param groups; (For base optim)
def setup_param_groups(model: nn.Module) -> list:
optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
return optimizer_grouped_parameters


# setup flatten param groups, sharding spec and shape; (For dist optim)
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
flatten_optimizer_grouped_parameters = []
sharding_spec = {} # {id(flatten param): get_layout(p).global_shape}
param_shape = {} # {id(flatten param): get_sharding_spec(p)}
for n, p in model.named_parameters():
# flatten_p = copy.deepcopy(p).flatten()
flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
flatten_optimizer_grouped_parameters.append(flatten_p)
if is_distributed_tensor(p):
sharding_spec[id(flatten_p)] = get_sharding_spec(p)
param_shape[id(flatten_p)] = get_layout(p).global_shape
else:
sharding_spec[id(flatten_p)] = None
param_shape[id(flatten_p)] = p.shape
return flatten_optimizer_grouped_parameters, sharding_spec, param_shape


def set_dist_grad(
dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
) -> None:
"""
Set split grads for Tensor Parallel or ZeRO DP.
We do not need a separate treatment for ZeRO,
as the wrapper takes care of reduce-scattering grads.
"""
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)

for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
if torch_p.grad is None:
torch_p.grad = torch.zeros_like(torch_p)

is_distributed = hasattr(p, "dist_layout")
if is_distributed:
sharding = p.dist_layout.sharding_spec.sharding_sequence
split_dim = sharding.index(_TP_SPEC)
shape = torch_p.split(world_size, dim=split_dim)[rank].shape

indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
# Generate grads only for the correctly split chunk
torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))

else:
shape = torch_p.shape
torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)

# avoid inconsistent grad and param dtype error
orig_p = p.data
p.data = torch_p.grad.clone().to(g_dtype)
p.grad = p.data
p.data = orig_p


def set_master_param_to_shard_param(master_param_list) -> dict:
master_param_to_shard_param = {id(p): p for p in master_param_list}
return master_param_to_shard_param


class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(HEIGHT, WIDTH)
self.linear2 = nn.Linear(WIDTH, HEIGHT)
self.linear1 = nn.Linear(IN_DIM, HID_DIM)
self.linear2 = nn.Linear(HID_DIM, IN_DIM)

def forward(self, x):
x = self.linear1(x)
@@ -182,7 +107,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
# ==============================
# Base Case
# ==============================
H, W = HEIGHT, WIDTH
H, W = IN_DIM, HID_DIM
model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight
weight, bias = model_col.weight, model_col.bias

@@ -284,8 +209,11 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
# ==============================
# Model Init
# ==============================
base_model = MlpModel().to(local_rank)
tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
# base_model = MlpModel().to(local_rank)
# tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
base_model = Net(in_dim=IN_DIM, hid_dim=HID_DIM, dtype=dtype).to(local_rank)
# Must specify dtype; TPNet init seems to run outside the set_default_dtype scope
tp_model = TPNet(fc1=base_model.fc1, fc2=base_model.fc2, tp_group=tp_group, dtype=dtype)

base_param_group = setup_param_groups(base_model)
tp_param_group = setup_param_groups(tp_model)
@@ -335,7 +263,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
# ==============================
# Correctness Verify
# ==============================
x = torch.randn(HEIGHT, WIDTH, device=local_rank)
x = torch.randn(IN_DIM, HID_DIM, device=local_rank)

out = base_model(x)
out_tp = tp_model(x)
@@ -353,7 +281,9 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
base_optim.zero_grad()
dist_optim.zero_grad()

for p, tp_p in zip(base_param_group, tp_param_group):
base_params = base_model.parameters()
tp_params = tp_model.parameters()
for p, tp_p in zip(base_params, tp_params):
param_is_distributed = is_distributed_tensor(tp_p)
if param_is_distributed:
shard_spec = get_sharding_spec(tp_p)