diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 45066ca898ef..3a5f0a5aaf32 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -25,6 +25,7 @@ def __init__(
         chunk_configuration,
         init_device: Optional[torch.device] = None,
         reuse_fp16_chunk: bool = True,
+        max_prefetch: int = 0,
     ) -> None:
         self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -42,6 +43,7 @@ def __init__(
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
         self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
+        self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None

     def register_tensor(
         self,
diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py
index 049c5c10255b..884d1306ef77 100644
--- a/colossalai/zero/gemini/chunk/utils.py
+++ b/colossalai/zero/gemini/chunk/utils.py
@@ -21,6 +21,7 @@ def init_chunk_manager(
     hidden_dim: Optional[int] = None,
     reuse_fp16_chunk: bool = True,
     verbose: bool = False,
+    max_prefetch: int = 0,
     **kwargs,
 ) -> ChunkManager:
     if hidden_dim:
@@ -51,9 +52,5 @@ def init_chunk_manager(
         )
     dist.barrier()

-    chunk_manager = ChunkManager(
-        config_dict,
-        init_device,
-        reuse_fp16_chunk=reuse_fp16_chunk,
-    )
+    chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
     return chunk_manager
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 6f6064000626..9d6849daadc1 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -104,9 +104,7 @@ def __init__(
         self.enable_gradient_accumulation = enable_gradient_accumulation
         if chunk_config_dict is not None:
             self.chunk_manager = ChunkManager(
-                chunk_config_dict,
-                chunk_init_device,
-                reuse_fp16_chunk=reuse_fp16_chunk,
+                chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
             )
         else:
             # some ugly hotfix for the compatibility with Lightning
@@ -122,6 +120,7 @@ def __init__(
                 process_group=zero_group,
                 reuse_fp16_chunk=reuse_fp16_chunk,
                 verbose=verbose,
+                max_prefetch=max_prefetch,
             )
         self.gemini_manager = GeminiManager(
             placement_policy,
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 736238a0992d..9e297c2a8a19 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -5,6 +5,7 @@

 import torch

+from colossalai.accelerator import get_accelerator
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
@@ -54,10 +55,11 @@ def pre_op(self, params):
         )

         # prefetch
-        for chunk in chunks_fetch_async:
-            maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
-            if maybe_work is not None:
-                self._gemini_manager.add_work(chunk, maybe_work)
+        with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
+            for chunk in chunks_fetch_async:
+                maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
+                if maybe_work is not None:
+                    self._gemini_manager.add_work(chunk, maybe_work)

         # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()
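
Usage note (not part of the patch): a minimal sketch of constructing a `ChunkManager` with the new `max_prefetch` keyword threaded through as in the hunks above. Only the keyword arguments mirror the changed signature; the `config_dict` layout and device choice below are illustrative assumptions.

# Hypothetical sketch, assuming ChunkManager is importable from colossalai.zero.gemini.chunk.
import torch
from colossalai.zero.gemini.chunk import ChunkManager

# Assumed chunk configuration layout for illustration only.
config_dict = {1: {"chunk_size": 1024 * 1024, "keep_gathered": False}}

chunk_manager = ChunkManager(
    config_dict,
    init_device=torch.device("cuda"),
    reuse_fp16_chunk=True,
    max_prefetch=2,  # any value > 0 makes the manager allocate its dedicated prefetch stream
)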