From 634d48387e9555d443fd6503d04f3f51c85c4097 Mon Sep 17 00:00:00 2001 From: xrsrke Date: Mon, 25 Sep 2023 11:06:02 +0700 Subject: [PATCH] add support to define which master rank should report progress in the pipeline's progress tracker --- pipegoose/nn/pipeline_parallel2/sync/handshake.py | 12 +++++++----- .../sync/test_progress_tracker.py | 10 ++++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pipegoose/nn/pipeline_parallel2/sync/handshake.py b/pipegoose/nn/pipeline_parallel2/sync/handshake.py index 6adeb0b..a268e7a 100644 --- a/pipegoose/nn/pipeline_parallel2/sync/handshake.py +++ b/pipegoose/nn/pipeline_parallel2/sync/handshake.py @@ -6,14 +6,18 @@ from pipegoose.distributed.parallel_context import ParallelContext from pipegoose.distributed.parallel_mode import ParallelMode +# NOTE: (microbatch_idx, partition_idx) Task = NewType("Task", Tuple[int, int]) class Handshake(ABC): + master_rank = None + parallel_context = None parallel_mode = None - def __init__(self, parallel_context: ParallelContext, parallel_mode: ParallelMode): + def __init__(self, master_rank: int, parallel_context: ParallelContext, parallel_mode: ParallelMode): + Handshake.master_rank = master_rank Handshake.parallel_context = parallel_context Handshake.parallel_mode = parallel_mode @@ -41,9 +45,6 @@ def is_all_confirmed(self, clock_idx: int) -> bool: class ProgressTracker(Handshake): """Pipeline parallelism's progress tracker.""" - # TODO: make this configurable - MASTER_RANK = 0 - progress = None clock_idx = None @@ -86,7 +87,8 @@ def is_all_confirmed(clock_idx: int) -> bool: def confirm(self, task: Task): # TODO: only non-scheduler ranks should confirm - master_worker_name = self.parallel_context.get_worker_name(self.MASTER_RANK) + global_master_rank = self.parallel_context.get_global_rank_from_local_rank(self.master_rank, self.parallel_mode) + master_worker_name = self.parallel_context.get_worker_name(global_master_rank) # rank = self.parallel_context.get_local_rank(self.parallel_mode) rpc.rpc_sync(master_worker_name, func=ProgressTracker._recv_confirm_from_worker, args=(task,)) diff --git a/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py b/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py index 6c54d91..622be9b 100644 --- a/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py +++ b/tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py @@ -31,6 +31,7 @@ def schedules_to_progress(schedules): def run_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size): N_MICROBATCHES = 4 + MASTER_RANK = 0 schedules = get_gpipe_schedules(pipeline_parallel_size, N_MICROBATCHES) PROGRESS = schedules_to_progress(schedules) @@ -38,9 +39,9 @@ def run_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipe parallel_context = init_parallel_context( rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size ) - tracker = ProgressTracker(parallel_context, ParallelMode.GLOBAL) + tracker = ProgressTracker(MASTER_RANK, parallel_context, ParallelMode.GLOBAL) - if rank == tracker.MASTER_RANK: + if rank == tracker.master_rank: tracker.initiate(PROGRESS) assert tracker.is_initiated() is True assert tracker.progress == PROGRESS @@ -75,6 +76,7 @@ def test_init_progress_tracker(): def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size): N_MICROBATCHES = 4 MICROBATCH_IDX = 0 + MASTER_RANK = 0 schedules = get_gpipe_schedules(pipeline_parallel_size, N_MICROBATCHES) PROGRESS = schedules_to_progress(schedules) @@ -83,9 +85,9 @@ def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, p parallel_context = init_parallel_context( rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size ) - tracker = ProgressTracker(parallel_context, ParallelMode.GLOBAL) + tracker = ProgressTracker(MASTER_RANK, parallel_context, ParallelMode.GLOBAL) - if rank == tracker.MASTER_RANK: + if rank == tracker.master_rank: tracker.initiate(PROGRESS) # NOTE: wait until all workers are confirmed time.sleep(5)