diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 9e297c2a8a19..bf5faa0fe884 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -55,6 +55,15 @@ def pre_op(self, params): ) # prefetch + if self._gemini_manager.chunk_manager._prefetch_stream is not None: + # This is when prefetch happens the first time and there is no dist.Work to sync, + # there is possibility that the optimizer haven't finish computation on default stream, + # thus we might prefetch outdated chunks there. + # + # Other than that, self._gemini_manager.wait_chunks will have synced with default stream + # by calling dist.Work.wait() and this line makes no diff. + self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream()) + with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream): for chunk in chunks_fetch_async: maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)