Commit

[zero] support extra dp (hpcaitech#6123)
* [zero] support extra dp

* [zero] update checkpoint

* fix bugs

* fix bugs
ver217 authored Nov 12, 2024
1 parent 30a9443 commit a259651
Showing 8 changed files with 238 additions and 57 deletions.
10 changes: 10 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -29,6 +29,7 @@
    save_state_dict,
    sharded_optimizer_loading_epilogue,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
@@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
        verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
        use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
        extra_dp_size (int, optional): The size of the extra data parallel group; the world size must be divisible by it. Defaults to 1.
    """

    def __init__(
@@ -358,11 +360,16 @@ def __init__(
        cast_inputs: bool = True,
        fp8_communication: bool = False,
        use_fp8: bool = False,
        extra_dp_size: int = 1,
    ) -> None:
        super().__init__()
        assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
        assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
        assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
        if extra_dp_size > 1:
            assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
            inner_dp_size = dist.get_world_size() // extra_dp_size
            self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
        self.stage = stage
        self.precision = precision
        self.zero_optim_kwargs = dict(
@@ -383,6 +390,9 @@
            overlap_allgather=overlap_allgather,
            fp8_communication=fp8_communication,
        )
        if extra_dp_size > 1:
            self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
            self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
        self.lora_enabled = False
        self.verbose = verbose
        self.logger = get_dist_logger()
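
As context for the plugin change (not part of the diff): a minimal usage sketch of the new `extra_dp_size` argument, assuming an 8-rank `torchrun` launch. `model`, `optimizer`, and `criterion` are hypothetical placeholders, and the exact `launch_from_torch` signature varies across ColossalAI versions.

```python
# Hypothetical sketch: 8 ranks split into extra_dp_size=2 (outer DP) x inner_dp_size=4 (inner ZeRO DP).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # assumes `torchrun --nproc_per_node=8 train.py`; older releases need a config argument

plugin = LowLevelZeroPlugin(
    stage=2,
    precision="bf16",
    extra_dp_size=2,  # world size (8) must be divisible by this; inner DP size becomes 8 // 2 = 4
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```

With `extra_dp_size > 1`, the plugin builds a `ProcessGroupMesh(extra_dp_size, inner_dp_size)` and hands axis 0 to the ZeRO optimizer as `extra_dp_group` and axis 1 as `dp_process_group`, so the ZeRO sharding runs inside each inner group while the extra axis adds another layer of data parallelism.
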
42 changes: 41 additions & 1 deletion colossalai/zero/low_level/_utils.py
@@ -1,6 +1,7 @@
import math
from typing import Optional
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
    # update the tensor data
    for p, q in zip(tensor_list, updated_params):
        p.data = q.data


def all_gather_into_flat_tensor_nd(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
    async_op: bool = False,
):
    if isinstance(group, dist.ProcessGroup):
        group = (group,)
    sizes = [dist.get_world_size(pg) for pg in group]
    ranks = [dist.get_rank(pg) for pg in group]
    for i, pg in list(enumerate(group))[::-1]:
        if i == 0:
            out = output_tensor
        else:
            prev_sizes = sizes[:i]
            prev_ranks = ranks[:i]
            chunks = output_tensor.chunk(np.prod(prev_sizes))
            out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
        handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
        input_tensor = out
    return handle


def get_nd_world_size(group) -> int:
    if isinstance(group, tuple):
        return int(np.prod([dist.get_world_size(pg) for pg in group]))
    else:
        return dist.get_world_size(group)


def get_nd_rank(group) -> int:
    if isinstance(group, tuple):
        return np.ravel_multi_index(
            tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
        )
    else:
        return dist.get_rank(group)
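
To make the rank flattening above concrete (an editorial illustration, not part of the diff): with an extra DP group of size 2 and an inner DP group of size 4, a process that is rank 1 in the outer group and rank 3 in the inner group gets flat rank 1 * 4 + 3 = 7 out of a combined world size of 8. The same index math can be checked in a single process, with no process groups involved:

```python
# Single-process check of the n-d rank / world-size math used by get_nd_rank and BaseStore.
import numpy as np

sizes = [2, 4]   # (extra_dp_size, inner_dp_size)
ranks = (1, 3)   # this process's rank within each group, outer group first

world_size = int(np.prod(sizes))                      # 8
flat_rank = int(np.ravel_multi_index(ranks, sizes))   # row-major: 1 * 4 + 3 = 7
print(world_size, flat_rank)  # 8 7
```
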
15 changes: 12 additions & 3 deletions colossalai/zero/low_level/bookkeeping/base_store.py
@@ -1,11 +1,20 @@
from typing import Tuple, Union

import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup


class BaseStore:
    def __init__(self, torch_pg: ProcessGroup):
        self._world_size = dist.get_world_size(group=torch_pg)
        self._local_rank = dist.get_rank(group=torch_pg)
    def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
        if isinstance(torch_pg, tuple):
            self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
            self._world_size = int(np.prod(self.sizes))
            self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
        else:
            self._world_size = dist.get_world_size(group=torch_pg)
            self._local_rank = dist.get_rank(group=torch_pg)
            self.sizes = [self._world_size]
        self.torch_pg = torch_pg

    @property
14 changes: 11 additions & 3 deletions colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -1,10 +1,12 @@
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.quantization.fp8 import all_gather_fp8
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd


class TensorBucket:
@@ -65,12 +67,18 @@ def unflatten_and_copy(self, flat_tensor):

    def all_gather(self, group=None, fp8_communication: bool = False):
        flat = self.flatten()
        buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
        if isinstance(group, tuple):
            world_size = np.prod([dist.get_world_size(pg) for pg in group])
        else:
            world_size = dist.get_world_size(group)
        buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
        if fp8_communication:
            # TODO: fit fp8
            all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
        else:
            dist.all_gather_into_tensor(buffer, flat, group=group)
        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
            # dist.all_gather_into_tensor(buffer, flat, group=group)
            all_gather_into_flat_tensor_nd(buffer, flat, group=group)
        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
        # transpose the list of list
        unflat_buffers = list(map(list, zip(*unflat_buffers)))
        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
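
To show why `all_gather` now chunks the gathered buffer by the combined `world_size` before transposing (an editorial mock, not part of the diff): the snippet below fakes the gather by concatenating per-rank flat buffers in rank order, which is the layout `all_gather_into_flat_tensor_nd` produces, then unflattens and regroups the shards per parameter. All tensor values are made up.

```python
# Single-process mock of the bucket layout after the (n-d) all-gather; values are illustrative only.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

world_size = 2
# Each rank holds one shard per bucketed parameter; shard shapes match across ranks.
per_rank_shards = [
    [torch.full((2,), float(r)), torch.full((3,), 10.0 + r)] for r in range(world_size)
]
flats = [_flatten_dense_tensors(shards) for shards in per_rank_shards]

# The gathered buffer: rank 0's flat bucket, then rank 1's, and so on.
buffer = torch.cat(flats)

# Chunk by world_size and unflatten each chunk back into per-parameter pieces ...
unflat_buffers = [_unflatten_dense_tensors(chunk, per_rank_shards[0]) for chunk in buffer.chunk(world_size)]
# ... then transpose so shards of the same parameter (one per rank) sit together.
unflat_buffers = list(map(list, zip(*unflat_buffers)))

full_params = [torch.cat(shards) for shards in unflat_buffers]
print([p.tolist() for p in full_params])
# [[0.0, 0.0, 1.0, 1.0], [10.0, 10.0, 10.0, 11.0, 11.0, 11.0]]
```
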
(Diff for the remaining 4 of the 8 changed files is not shown.)
