diff --git a/pipegoose/distributed/parallel_context.py b/pipegoose/distributed/parallel_context.py
index ced0095..425acb3 100644
--- a/pipegoose/distributed/parallel_context.py
+++ b/pipegoose/distributed/parallel_context.py
@@ -22,7 +22,8 @@ class ParallelContext:
-    """Inspired from OSLO's parallel context:
+    """
+    Inspired from OSLO's parallel context:
     https://github.com/EleutherAI/oslo/blob/f16c73bc5893cd6cefe65e70acf6d88428a324e1/oslo/torch/distributed/parallel_context.py#L53
     """

@@ -105,7 +106,6 @@ def __init__(
         self.set_device()

         self.rpc_worker_map = {rank: WORKER_NAME.format(rank) for rank in self.get_ranks_in_group(ParallelMode.GLOBAL)}
-        # TODO: add initialize from torch launcher
         self.init_rpc_workers(host, port)

         # self.set_seed(seed)
@@ -199,7 +199,6 @@ def _register_dist(
         self.add_local_rank(parallel_mode, local_rank)
         self.add_world_size(parallel_mode, local_world_size)
         self.add_group(parallel_mode, process_group)
-        # TODO: remove this
         self.add_ranks_in_group(parallel_mode, ranks_in_group)

     def set_device(self):
@@ -213,9 +212,8 @@ def set_seed(self, seed: int):
         torch.manual_seed(seed)

         # TODO: set GPU seed
-        if torch.cuda.is_available():
-            # parallel_seed = seed
-            pass
+        # if torch.cuda.is_available():
+        #     pass

     def is_initialized(self, parallel_mode: ParallelMode) -> bool:
         """Check if the parallel mode is initialized.
@@ -261,7 +259,6 @@ def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks_in_group: List[i
     def get_ranks_in_group(self, parallel_mode: ParallelMode) -> List[int]:
         """A list of global ranks in a given parallel mode of the local process."""
-        # return dist.get_process_group_ranks(self._groups[parallel_mode])
        return self._ranks_in_group[parallel_mode]

     def get_next_global_rank(self, parallel_mode: ParallelMode) -> int:
diff --git a/pipegoose/nn/pipeline_parallel2/_job/forward.py b/pipegoose/nn/pipeline_parallel2/_job/forward.py
index cb6e330..cc195e6 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/forward.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/forward.py
@@ -116,3 +116,6 @@ def after_compute(self):
         # )
         key = (microbatch_idx, partition_idx)
         progress_tracker.confirm(key)
+        import time
+
+        time.sleep(3)
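The `set_seed` hunk above comments out the unfinished CUDA branch but leaves the GPU-seed TODO open. As a hedged sketch only: one common scheme (assumed here, not pipegoose's actual design) offsets the base seed by the global rank so each process draws a distinct random stream.

```python
import torch

def set_seed_sketch(seed: int, global_rank: int) -> None:
    # HYPOTHETICAL scheme: derive a per-process seed from the global rank
    # so parallel workers don't generate identical random numbers.
    process_seed = seed + global_rank
    torch.manual_seed(process_seed)
    if torch.cuda.is_available():
        # Seed every CUDA device visible to this process.
        torch.cuda.manual_seed_all(process_seed)
```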
diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
index 722988a..7c220a9 100644
--- a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
+++ b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass

 import torch
+import torch.distributed as dist
 from torch import nn

 from pipegoose.distributed.parallel_context import ParallelContext
@@ -86,29 +87,35 @@ def after_new_clock_cycle(self, progress, clock_idx):
             MASTER_RANK, callbacks=callbacks, parallel_context=self.parallel_context, parallel_mode=ParallelMode.GLOBAL
         )

         # NOTE: wait for all ranks to be initiated
+        dist.barrier()
         time.sleep(1)

-        if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
+        # if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
+        if self.parallel_context.get_global_rank() == 0:
             schedules = self.pipeline_context.schedules
             progress = {
                 i: {(item.microbatch_idx, item.partition_idx): False for item in sublist} for i, sublist in enumerate(schedules)
             }
             progress_tracker.initiate(progress)
+            print(progress)

-        time.sleep(1)
+        dist.barrier()
+        time.sleep(5)

         set_progress_tracker(progress_tracker)

-        time.sleep(1)
+        dist.barrier()
+        time.sleep(2)

         # from hanging_threads import start_monitoring
         # monitoring_thread = start_monitoring()

         for tasks in self.pipeline_context.get_schedule():
-            time.sleep(2)
-            rank = self.parallel_context.get_local_rank(ParallelMode.GLOBAL)
+            dist.barrier()
+
+            rank = self.parallel_context.get_global_rank()
             partition_idx = self.pipeline_context.partition_idx

             if rank == 0:
@@ -118,7 +125,7 @@ def after_new_clock_cycle(self, progress, clock_idx):
                 assert 1 == 1

             if len(tasks) > 0:
-                print(f"[enter look] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}")
+                # print(f"[enter look] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}")
                 for task in tasks:
                     microbatch_idx = task.microbatch_idx
                     partition_idx = task.partition_idx
@@ -129,17 +136,18 @@ def after_new_clock_cycle(self, progress, clock_idx):
                     else:
                         package = RECV_QUEUE.get()
-                        print(
-                            f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}",
-                            package.metadata,
-                        )
+                        # print(
+                        #     f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}",
+                        #     package.metadata,
+                        # )

                         job = create_job(self.partition_func, package, self.pipeline_context)
                         # print(f"created a job: {package.metadata}")
                         JobQueue.PENDING_JOBS.put(job)

-            time.sleep(2)
+
+            dist.barrier()

     # def _retrieve_package_from_received_package(self, microbatch_idx, partition_idx):
     #     # package = RECV_QUEUE[(microbatch_idx, partition_idx)]
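The engine change above replaces fixed-duration sleeps with `dist.barrier()` rendezvous points and moves initiation to global rank 0. A minimal sketch of that sequencing, assuming an initialized default process group (the callables are placeholders, not pipegoose APIs):

```python
import torch.distributed as dist

def startup_sequencing_sketch(initiate_progress, run_schedule) -> None:
    if dist.get_rank() == 0:
        # Only the global master seeds the shared progress tracker.
        initiate_progress()
    # Every rank blocks here until all ranks arrive, rather than sleeping
    # for a guessed duration and hoping rank 0 has finished initiating.
    dist.barrier()
    run_schedule()
```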
diff --git a/pipegoose/nn/pipeline_parallel2/sync/handshake.py b/pipegoose/nn/pipeline_parallel2/sync/handshake.py
index 6303669..ee6b792 100644
--- a/pipegoose/nn/pipeline_parallel2/sync/handshake.py
+++ b/pipegoose/nn/pipeline_parallel2/sync/handshake.py
@@ -95,12 +95,12 @@ def is_initiated(self) -> bool:
     def initiate(self, progress: Progress):
         INITIAL_CLOCK_IDX = 0
-        ProgressTracker._broadcast_tasks(progress, clock_idx=INITIAL_CLOCK_IDX)
+        ProgressTracker._broadcast_tasks(progress, clock_idx=INITIAL_CLOCK_IDX, is_init=True)
         ProgressTracker.progress = progress
         ProgressTracker.clock_idx = INITIAL_CLOCK_IDX

     @staticmethod
-    def _broadcast_tasks(progress, clock_idx):
+    def _broadcast_tasks(progress, clock_idx, is_init=False):
         parallel_context = ProgressTracker.parallel_context
         parallel_mode = ProgressTracker.parallel_mode

@@ -108,24 +108,25 @@ def _broadcast_tasks(progress, clock_idx):
         local_world_size = parallel_context.get_world_size(parallel_mode)

         for local_dst in range(local_world_size):
-            if local_dst == local_rank:
+            if local_dst == local_rank and is_init is False:
                 # NOTE: since we skip the master node, we need to manually run the callback
                 ProgressTracker._run_callback("after_new_clock_cycle", progress=progress, clock_idx=clock_idx)
                 continue

             global_dst = parallel_context.get_global_rank_from_local_rank(local_dst, parallel_mode)
             worker_name = parallel_context.get_worker_name(global_dst)
-            rpc.rpc_sync(to=worker_name, func=ProgressTracker._recv_tasks, args=(progress, clock_idx))
+            rpc.rpc_sync(to=worker_name, func=ProgressTracker._recv_tasks, args=(progress, clock_idx, is_init))

     @staticmethod
-    def _recv_tasks(progress: Progress, clock_idx: int):
+    def _recv_tasks(progress: Progress, clock_idx: int, is_init):
         with ProgressTracker.update_progress_lock:
             ProgressTracker.progress = progress
             ProgressTracker.clock_idx = clock_idx

             # NOTE: don't increase a new clock cycle if just initializing it
             # NOTE: after a worker node receives the progress, it should run the callback
-            ProgressTracker._run_callback("after_new_clock_cycle", progress=progress, clock_idx=clock_idx)
+            if is_init is False:
+                ProgressTracker._run_callback("after_new_clock_cycle", progress=progress, clock_idx=clock_idx)

     def is_confirmed(self, task: Task, clock_idx: int) -> bool:
         return self.progress[clock_idx][task] is True
@@ -166,4 +167,4 @@ def _update_local_progress(task: Task):
         if ProgressTracker.is_all_confirmed(clock_idx) is True:
             NEXT_CLOCK_IDX = clock_idx + 1
             ProgressTracker.clock_idx = NEXT_CLOCK_IDX
-            ProgressTracker._broadcast_tasks(ProgressTracker.progress, clock_idx=NEXT_CLOCK_IDX)
+            ProgressTracker._broadcast_tasks(ProgressTracker.progress, clock_idx=NEXT_CLOCK_IDX, is_init=False)
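The handshake change threads an `is_init` flag through the broadcast so the initiation broadcast seeds state on every worker without firing the `after_new_clock_cycle` callback. A standalone toy of the receive side (no RPC, hypothetical names), assuming that semantics:

```python
from typing import Callable, Dict

class TrackerSketch:
    """Toy receiver mirroring the is_init handling in _recv_tasks."""

    def __init__(self, on_new_clock_cycle: Callable[[Dict, int], None]):
        self.progress: Dict = {}
        self.clock_idx = 0
        self.on_new_clock_cycle = on_new_clock_cycle

    def recv_tasks(self, progress: Dict, clock_idx: int, is_init: bool) -> None:
        # State is always updated, whether this is initiation or a new cycle.
        self.progress = progress
        self.clock_idx = clock_idx
        if not is_init:
            # Only a genuine clock-cycle transition triggers the callback;
            # the initiation broadcast just seeds local state.
            self.on_new_clock_cycle(progress, clock_idx)
```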
diff --git a/tests/distributed/test_functional.py b/tests/distributed/test_functional.py
index 665647f..6547e17 100644
--- a/tests/distributed/test_functional.py
+++ b/tests/distributed/test_functional.py
@@ -199,5 +199,5 @@ def test_all_reduce(world_size, tensor_parallel_size, pipeline_parallel_size, da

 @pytest.mark.skip(reason="not implemented")
-def test_reduce_scatter(parallel_context):
+def test_reduce_scatter():
     pass
diff --git a/tests/nn/pipeline_parallel_2/run_engine.py b/tests/nn/pipeline_parallel_2/run_engine.py
index f918421..5c28fdf 100644
--- a/tests/nn/pipeline_parallel_2/run_engine.py
+++ b/tests/nn/pipeline_parallel_2/run_engine.py
@@ -84,9 +84,9 @@ def forward(self, input):

 if __name__ == "__main__":
-    DATA_PARALLEL_SIZE = 1
     TENSOR_PARALLEL_SIZE = 1
     PIPELINE_PARALLEL_SIZE = 4
+    DATA_PARALLEL_SIZE = 1

     WORLD_SIZE = PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE
diff --git a/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py b/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py
index 44b9b26..dab10da 100644
--- a/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py
+++ b/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py
@@ -1,8 +1,8 @@
-import time
 from copy import deepcopy
 from typing import Dict

 import pytest
+import torch.distributed as dist

 from pipegoose.distributed.parallel_mode import ParallelMode
 from pipegoose.nn.pipeline_parallel2.sync.callback import Callback
@@ -34,7 +34,8 @@ def run_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipe
     tracker.initiate(PROGRESS)

     # NOTE: wait until the tracker is initiated
-    time.sleep(0.1)
+    dist.barrier()
+
     assert tracker.is_initiated() is True
     assert tracker.clock_idx == 0
     assert tracker.is_all_confirmed(clock_idx=0) is False
@@ -73,14 +74,14 @@ def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, p
     tracker.initiate(PROGRESS)

     # NOTE: wait until the tracker is initiated
-    time.sleep(2)
+    dist.barrier()

     for clock_idx in range(N_CLOCK_CYCLES):
         tracker.confirm(rank)
         assert tracker.is_confirmed(rank, clock_idx=clock_idx) is True

         # NOTE: wait until all workers are confirmed
-        time.sleep(2)
+        dist.barrier()
         assert tracker.is_all_confirmed(clock_idx=clock_idx) is True

         if not (clock_idx == N_CLOCK_CYCLES - 1):
@@ -89,8 +90,6 @@ def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, p
             assert tracker.clock_idx == clock_idx + 1
             assert tracker.progress != INITIAL_PROGRESS

-    time.sleep(0.1)
-
     assert tracker.progress == FINAL_PROGRESS

     parallel_context.destroy()
@@ -129,15 +128,13 @@ def after_new_clock_cycle(self, progress: Dict, clock_idx: int):
     tracker.initiate(PROGRESS)

     # NOTE: wait until the tracker is initiated
-    time.sleep(0.5)
-    assert QUEUE == [rank]
-
+    dist.barrier()
     tracker.confirm(rank)

     # NOTE: wait until all workers are confirmed
     # callback should be called again after all workers are confirmed
-    time.sleep(0.5)
-    assert QUEUE == [rank, rank]
+    dist.barrier()
+    assert QUEUE == [rank]

     parallel_context.destroy()
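The test changes swap `time.sleep(...)` guesswork for `dist.barrier()`, which gives a deterministic point at which to assert (and explains the tightened `QUEUE == [rank]` expectation: with `is_init=True`, initiation no longer invokes the callback). A self-contained toy of the assert-after-barrier pattern on the gloo backend (port and names are illustrative, not pipegoose's test helpers):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    flags = torch.zeros(world_size)
    flags[rank] = 1.0
    dist.all_reduce(flags)  # every rank contributes its flag
    dist.barrier()          # rendezvous where the old tests slept

    # Safe to assert immediately: no sleep, no race.
    assert flags.sum().item() == world_size

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```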