[Gemini] Use async stream to prefetch and h2d data moving (hpcaitech#5781)

* use async stream to prefetch and h2d data moving

* Remove redundant code
Hz188 authored Jun 12, 2024
1 parent 8554585 commit d9dddf5
Showing 4 changed files with 12 additions and 12 deletions.
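
The commit follows a common double-buffering pattern: host-to-device (h2d) copies for upcoming chunks are issued on a dedicated side stream so they can overlap with compute running on the default stream. Below is a minimal, self-contained sketch of that general pattern in plain PyTorch; it is not the Gemini implementation, and the names run_with_prefetch, chunks_cpu, and compute are hypothetical.

import torch

def run_with_prefetch(chunks_cpu, compute):
    """Compute on chunk i while chunk i+1 is copied h2d on a side stream."""
    if not chunks_cpu:
        return
    if not torch.cuda.is_available():
        for chunk in chunks_cpu:  # CPU-only fallback: plain synchronous loop
            compute(chunk)
        return

    prefetch_stream = torch.cuda.Stream()  # side stream for async h2d copies
    device = torch.device("cuda")
    # Pinned host memory is required for the copies to be truly asynchronous.
    chunks_cpu = [c.pin_memory() for c in chunks_cpu]

    # Kick off the first copy on the side stream.
    with torch.cuda.stream(prefetch_stream):
        next_gpu = chunks_cpu[0].to(device, non_blocking=True)

    for i in range(len(chunks_cpu)):
        # Default stream waits only for the prefetched copy, then reuses it.
        torch.cuda.current_stream().wait_stream(prefetch_stream)
        current_gpu = next_gpu
        # Tell the caching allocator this tensor is now used on another stream.
        current_gpu.record_stream(torch.cuda.current_stream())

        # Start copying the next chunk before computing on the current one.
        if i + 1 < len(chunks_cpu):
            with torch.cuda.stream(prefetch_stream):
                next_gpu = chunks_cpu[i + 1].to(device, non_blocking=True)

        compute(current_gpu)  # overlaps with the copy issued just above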
2 changes: 2 additions & 0 deletions colossalai/zero/gemini/chunk/manager.py
@@ -25,6 +25,7 @@ def __init__(
        chunk_configuration,
        init_device: Optional[torch.device] = None,
        reuse_fp16_chunk: bool = True,
+       max_prefetch: int = 0,
    ) -> None:
        self.device = init_device or get_accelerator().get_current_device()
        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -42,6 +43,7 @@ def __init__(
        # Whether model is accumulating gradients,
        self.accumulating_grads = False
        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
+       self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None

    def register_tensor(
        self,
7 changes: 2 additions & 5 deletions colossalai/zero/gemini/chunk/utils.py
@@ -21,6 +21,7 @@ def init_chunk_manager(
    hidden_dim: Optional[int] = None,
    reuse_fp16_chunk: bool = True,
    verbose: bool = False,
+   max_prefetch: int = 0,
    **kwargs,
) -> ChunkManager:
    if hidden_dim:
@@ -51,9 +52,5 @@
    )
    dist.barrier()

-   chunk_manager = ChunkManager(
-       config_dict,
-       init_device,
-       reuse_fp16_chunk=reuse_fp16_chunk,
-   )
+   chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
    return chunk_manager
5 changes: 2 additions & 3 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -104,9 +104,7 @@ def __init__(
        self.enable_gradient_accumulation = enable_gradient_accumulation
        if chunk_config_dict is not None:
            self.chunk_manager = ChunkManager(
-               chunk_config_dict,
-               chunk_init_device,
-               reuse_fp16_chunk=reuse_fp16_chunk,
+               chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
            )
        else:
            # some ugly hotfix for the compatibility with Lightning
@@ -122,6 +120,7 @@ def __init__(
                process_group=zero_group,
                reuse_fp16_chunk=reuse_fp16_chunk,
                verbose=verbose,
+               max_prefetch=max_prefetch,
            )
        self.gemini_manager = GeminiManager(
            placement_policy,
10 changes: 6 additions & 4 deletions colossalai/zero/gemini/gemini_hook.py
@@ -5,6 +5,7 @@

import torch

+from colossalai.accelerator import get_accelerator
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
@@ -54,10 +55,11 @@ def pre_op(self, params):
        )

        # prefetch
-       for chunk in chunks_fetch_async:
-           maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
-           if maybe_work is not None:
-               self._gemini_manager.add_work(chunk, maybe_work)
+       with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
+           for chunk in chunks_fetch_async:
+               maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
+               if maybe_work is not None:
+                   self._gemini_manager.add_work(chunk, maybe_work)

        # record cuda model data of the current OP, including memory for prefetched chunks
        self._gemini_manager.record_model_data_volume()
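Because the prefetch loop runs inside the chunk manager's _prefetch_stream context, the copies do not block the default stream; the handle returned by access_chunk(chunk, async_access=True) is stored via add_work, presumably to be waited on before the chunk is actually consumed. A minimal sketch of that "issue now, wait later" idea in plain PyTorch events rather than the Gemini API (fetch_async, wait_and_get, and pending are hypothetical names):

import torch

prefetch_stream = torch.cuda.Stream()
pending = {}  # chunk id -> (gpu tensor, completion event)

def fetch_async(chunk_id, cpu_tensor):
    """Issue the h2d copy on the side stream and remember a completion event."""
    with torch.cuda.stream(prefetch_stream):
        # cpu_tensor should live in pinned memory for the copy to be async.
        gpu = cpu_tensor.to("cuda", non_blocking=True)
        done = torch.cuda.Event()
        done.record(prefetch_stream)
    pending[chunk_id] = (gpu, done)

def wait_and_get(chunk_id):
    """Block the consumer only on this chunk's copy, not the whole side stream."""
    gpu, done = pending.pop(chunk_id)
    torch.cuda.current_stream().wait_event(done)
    gpu.record_stream(torch.cuda.current_stream())
    return gpu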
