diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index f3a6901ada6b..d5afa2ba83ce 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -29,6 +29,7 @@ save_state_dict, sharded_optimizer_loading_epilogue, ) +from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger @@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase): verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. + extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1. """ def __init__( @@ -358,11 +360,16 @@ def __init__( cast_inputs: bool = True, fp8_communication: bool = False, use_fp8: bool = False, + extra_dp_size: int = 1, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" + if extra_dp_size > 1: + assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size" + inner_dp_size = dist.get_world_size() // extra_dp_size + self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size) self.stage = stage self.precision = precision self.zero_optim_kwargs = dict( @@ -383,6 +390,9 @@ def __init__( overlap_allgather=overlap_allgather, fp8_communication=fp8_communication, ) + if extra_dp_size > 1: + self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0) + self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1) self.lora_enabled = False self.verbose = verbose self.logger = get_dist_logger() diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 5ab703f09063..8a641f71719c 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -1,6 +1,7 @@ import math -from typing import Optional +from typing import Optional, Tuple, Union +import numpy as np import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list): # update the tensor data for p, q in zip(tensor_list, updated_params): p.data = q.data + + +def all_gather_into_flat_tensor_nd( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]], + async_op: bool = False, +): + if isinstance(group, dist.ProcessGroup): + group = (group,) + sizes = [dist.get_world_size(pg) for pg in group] + ranks = [dist.get_rank(pg) for pg in group] + for i, pg in list(enumerate(group))[::-1]: + if i == 0: + out = output_tensor + else: + prev_sizes = sizes[:i] + prev_ranks = ranks[:i] + chunks = output_tensor.chunk(np.prod(prev_sizes)) + out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)] + handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op) + input_tensor = out + return handle + + +def get_nd_world_size(group) -> int: + if isinstance(group, tuple): + return int(np.prod([dist.get_world_size(pg) for pg in group])) + else: + return dist.get_world_size(group) + + +def get_nd_rank(group) -> int: + if isinstance(group, tuple): + return np.ravel_multi_index( + tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group] + ) + else: + return dist.get_rank(group) diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py index 7f2f9664b7de..291f7a0135bc 100644 --- a/colossalai/zero/low_level/bookkeeping/base_store.py +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -1,11 +1,20 @@ +from typing import Tuple, Union + +import numpy as np import torch.distributed as dist from torch.distributed import ProcessGroup class BaseStore: - def __init__(self, torch_pg: ProcessGroup): - self._world_size = dist.get_world_size(group=torch_pg) - self._local_rank = dist.get_rank(group=torch_pg) + def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]): + if isinstance(torch_pg, tuple): + self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg] + self._world_size = int(np.prod(self.sizes)) + self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes) + else: + self._world_size = dist.get_world_size(group=torch_pg) + self._local_rank = dist.get_rank(group=torch_pg) + self.sizes = [self._world_size] self.torch_pg = torch_pg @property diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index 3c95aa6babcd..452080a491c7 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -1,10 +1,12 @@ from typing import Optional +import numpy as np import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from colossalai.quantization.fp8 import all_gather_fp8 +from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd class TensorBucket: @@ -65,12 +67,18 @@ def unflatten_and_copy(self, flat_tensor): def all_gather(self, group=None, fp8_communication: bool = False): flat = self.flatten() - buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype) + if isinstance(group, tuple): + world_size = np.prod([dist.get_world_size(pg) for pg in group]) + else: + world_size = dist.get_world_size(group) + buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype) if fp8_communication: + # TODO: fit fp8 all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3") else: - dist.all_gather_into_tensor(buffer, flat, group=group) - unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))] + # dist.all_gather_into_tensor(buffer, flat, group=group) + all_gather_into_flat_tensor_nd(buffer, flat, group=group) + unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)] # transpose the list of list unflat_buffers = list(map(list, zip(*unflat_buffers))) for unflat_shards, tensor in zip(unflat_buffers, self._bucket): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8b0e6475b7e3..26fff75fbfdf 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -2,7 +2,7 @@ import copy from contextlib import contextmanager, nullcontext from functools import partial -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union from weakref import proxy import torch @@ -23,7 +23,15 @@ from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8 from colossalai.tensor.moe_tensor.api import is_moe_tensor -from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor +from ._utils import ( + all_gather_into_flat_tensor_nd, + calculate_global_norm_from_list, + get_nd_rank, + get_nd_world_size, + has_inf_or_nan, + release_param_grad, + sync_tensor, +) from .bookkeeping import BucketStore, GradientStore, TensorBucket from .zero_hook import set_all_gather_handle, wait_all_gather_handle @@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def __init__( self, optimizer: Optimizer, - pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None, + pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -84,6 +92,7 @@ def __init__( partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, + extra_dp_group: Optional[ProcessGroup] = None, forced_dtype: Optional[torch.dtype] = None, master_weights: bool = True, # master weights overlap_allgather: bool = False, @@ -98,9 +107,17 @@ def __init__( if (dp_process_group is not None) and (pg_to_param_list is not None): raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.") + if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None: + raise ValueError("dp_process_group should be provided when extra_dp_group is provided.") + if pg_to_param_list is None and extra_dp_group is not None and fp8_communication: + raise ValueError( + "fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided." + ) if pg_to_param_list is None: unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group + if extra_dp_group is not None: + unique_dp_group = (extra_dp_group, unique_dp_group) pg_to_param_list = {unique_dp_group: []} for group in self.optim.param_groups: pg_to_param_list[unique_dp_group].extend(group["params"]) @@ -336,10 +353,12 @@ def _run_reduction(self): flat_grads = flat_grads.to(self._communication_dtype) if not self._partition_grads: - if self._fp8_communication: - all_reduce_fp8(flat_grads, group=bucket_store.torch_pg) - else: - dist.all_reduce(flat_grads, group=bucket_store.torch_pg) + for i, sz in enumerate(bucket_store.sizes): + grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i] + if self._fp8_communication: + all_reduce_fp8(flat_grads, group=grp) + else: + dist.all_reduce(flat_grads, group=grp) if flat_grads.dtype != grad_dtype: flat_grads = flat_grads.to(grad_dtype) @@ -347,16 +366,20 @@ def _run_reduction(self): grad_in_bucket = bucket_store.get_grad() self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id) else: - flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) - received_grad = torch.zeros_like(flat_grads_list[0]) - if self._fp8_communication: - reduce_scatter_fp8( - received_grad, - flat_grads_list, - group=bucket_store.torch_pg, - ) - else: - dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) + cur_flat_grads = flat_grads + for i, sz in enumerate(bucket_store.sizes): + grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i] + flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz)) + received_grad = torch.zeros_like(flat_grads_list[0]) + if self._fp8_communication: + reduce_scatter_fp8( + received_grad, + flat_grads_list, + group=grp, + ) + else: + dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp) + cur_flat_grads = received_grad if received_grad.dtype != grad_dtype: received_grad = received_grad.to(grad_dtype) @@ -577,11 +600,13 @@ def step(self, closure=None): pg = self.param_to_pg[working_param] padded_working_param = self._working_param_to_padded_working_param[working_param] if self._overlap_allgather: - handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) + # handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) + handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True) set_all_gather_handle(working_param, handle) else: if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: if self._fp8_communication: + # TODO: fit fp8 communication all_gather_fp8( list(padded_working_param.chunk(dist.get_world_size(pg))), param_to_gather, @@ -589,7 +614,8 @@ def step(self, closure=None): fp8_format="e4m3", ) else: - dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + # dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg) continue try: self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) @@ -602,7 +628,9 @@ def step(self, closure=None): if not tensor_bucket.is_empty(): tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication) - def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: + def _compute_grad_norm( + self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2 + ) -> float: r""" Compute and return the gradient norm for gradient clipping. @@ -625,7 +653,11 @@ def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_ device=get_accelerator().get_current_device(), dtype=torch.float, ) - dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg) + if isinstance(dp_pg, tuple): + for grp in dp_pg: + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp) + else: + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg) total_norm = total_norm_cuda.item() else: @@ -640,11 +672,19 @@ def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_ device=get_accelerator().get_current_device(), dtype=torch.float, ) - torch.distributed.all_reduce( - total_norm_exponentiated_cuda, - op=torch.distributed.ReduceOp.SUM, - group=dp_pg, - ) + if isinstance(dp_pg, tuple): + for grp in dp_pg: + dist.all_reduce( + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=grp, + ) + else: + torch.distributed.all_reduce( + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=dp_pg, + ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) return total_norm @@ -744,11 +784,9 @@ def state_dict(self) -> Dict: if isinstance(v, torch.Tensor) and k != "step": working_param = self.master_to_working_param[id(param)] pg = self.param_to_pg[working_param] - gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] - dist.all_gather(gather_tensor, v.to(device), group=pg) - param_state = ( - torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() - ) + gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) + all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg) + param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu() zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -770,15 +808,17 @@ def load_state_dict(self, state_dict: Dict): cnt += 1 for param_idx, state in zero_state_dict["state"].items(): pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]] + world_size = get_nd_world_size(pg) + rank = get_nd_rank(pg) for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": - padding_size = (pg.size() - v.numel() % pg.size()) % pg.size() + padding_size = (world_size - v.numel() % world_size) % world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - v_list = v.split(v.numel() // pg.size()) - zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone() + v_list = v.split(v.numel() // world_size) + zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -814,11 +854,9 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": - state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] - dist.all_gather(state_tensor, v.to(device), group=pg) - state_tensor = ( - torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() - ) + state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) + all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg) + state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu() current_block_size += state_tensor.numel() current_block[k] = state_tensor @@ -842,12 +880,14 @@ def update_master_params(self, model: nn.Module) -> None: p_id = id(p) if p_id in self.working_to_master_param: pg = self.param_to_pg[p] + world_size = get_nd_world_size(pg) + rank = get_nd_rank(pg) master_param = self.working_to_master_param[p_id] padding_size = self.get_param_padding_size(p) working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - master_param.copy_(working_param.chunk(pg.size())[pg.rank()]) + master_param.copy_(working_param.chunk(world_size)[rank]) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: return self.working_to_master_param @@ -905,9 +945,12 @@ def get_param_grad(self, working_param: nn.Parameter) -> Tensor: grad = grad_store.get_working_grad_by_param_id(id(working_param)) if grad is None: return None - grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device) - dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg) - return grad_flat.view(-1)[: working_param.numel()].view_as(working_param) + grad_flat = grad.flatten() + output_grad = torch.empty( + grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype + ) + all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg) + return output_grad.view(-1)[: working_param.numel()].view_as(working_param) def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]: working_grads = [] diff --git a/tests/test_zero/test_low_level/test_coll_nd.py b/tests/test_zero/test_low_level/test_coll_nd.py new file mode 100644 index 000000000000..c9d7e6341c48 --- /dev/null +++ b/tests/test_zero/test_low_level/test_coll_nd.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from colossalai.utils import get_current_device +from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd + + +def check_all_gather_2d(): + seed_all(1024) + tensor = torch.rand(128, device=get_current_device()) + extra_dp_size, inner_dp_size = 2, 2 + pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size) + extra_dp_group = pg_mesh.get_group_along_axis(0) + inner_dp_group = pg_mesh.get_group_along_axis(1) + ranks = [dist.get_rank(extra_dp_group), dist.get_rank(inner_dp_group)] + sizes = [dist.get_world_size(extra_dp_group), dist.get_world_size(inner_dp_group)] + chunk = tensor.chunk(dist.get_world_size())[np.ravel_multi_index(ranks, sizes)].clone() + out = torch.zeros_like(tensor) + all_gather_into_flat_tensor_nd(out, chunk, group=(extra_dp_group, inner_dp_group)) + assert torch.equal(out, tensor) + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + check_all_gather_2d() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm_nd(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_comm_nd() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 368c782fe2c4..103854f869c7 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -2,11 +2,13 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.cluster import ProcessGroupMesh from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer @@ -123,7 +125,8 @@ def exam_zero_1_2(fp8_communication: bool): @parameterize("dtype", [torch.float16, torch.bfloat16]) @parameterize("master_weights", [True, False]) -def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): +@parameterize("extra_dp_size", [1, 2]) +def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -132,6 +135,15 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): We feed these two sets of models with the same input and check if the differences in model output and updated parameters are within tolerance. """ + if extra_dp_size > 1 and dtype != torch.bfloat16: + return + if extra_dp_size > 1: + pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size) + extra_dp_group = pg_mesh.get_group_along_axis(0) + dp_group = pg_mesh.get_group_along_axis(1) + else: + extra_dp_group = None + dp_group = None local_rank = torch.distributed.get_rank() seed_all(1453) @@ -153,6 +165,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): initial_scale=1, reduce_bucket_size=1024 * 1024, master_weights=master_weights, + dp_process_group=dp_group, + extra_dp_group=extra_dp_group, ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) @@ -200,14 +214,14 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): def run_dist(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") - exam_zero_1_torch_ddp(world_size=world_size) + exam_zero_1_torch_ddp() exam_zero_1_2() @pytest.mark.dist @rerun_if_address_is_in_use() def test_zero_1_2(): - spawn(run_dist, 2) + spawn(run_dist, 4) if __name__ == "__main__": diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index 8543dfba0c15..656559718518 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -2,12 +2,14 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.cluster import ProcessGroupMesh +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer @@ -40,11 +42,19 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): assert_close(a, b, rtol=rtol, atol=atol) -def exam_zero_1_torch_ddp_ckpt(): +@parameterize("extra_dp_size", [1, 2]) +def exam_zero_1_torch_ddp_ckpt(extra_dp_size: int): """ We examine the state_dict of zero and DDP. Moreover, we examine the zero's loading checkpoint of a torch ckpt. """ + if extra_dp_size > 1: + pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size) + extra_dp_group = pg_mesh.get_group_along_axis(0) + dp_group = pg_mesh.get_group_along_axis(1) + else: + dp_group = None + extra_dp_group = None local_rank = torch.distributed.get_rank() seed_all(1453) @@ -60,7 +70,12 @@ def exam_zero_1_torch_ddp_ckpt(): # we only test stage 1 here # the state dicts of stage 1 and stage 2 are the same zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144 + zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=262144, + dp_process_group=dp_group, + extra_dp_group=extra_dp_group, ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -111,7 +126,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_zero_ckpt(): - spawn(run_dist, 2) + spawn(run_dist, 4) if __name__ == "__main__":