Skip to content

Commit

Permalink
add support to define which master rank should report progress in the pipeline's progress tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Sep 25, 2023
1 parent 2cca21a commit 634d483
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
12 changes: 7 additions & 5 deletions pipegoose/nn/pipeline_parallel2/sync/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

# NOTE: (microbatch_idx, partition_idx)
Task = NewType("Task", Tuple[int, int])


class Handshake(ABC):
master_rank = None

parallel_context = None
parallel_mode = None

def __init__(self, parallel_context: ParallelContext, parallel_mode: ParallelMode):
def __init__(self, master_rank: int, parallel_context: ParallelContext, parallel_mode: ParallelMode):
Handshake.master_rank = master_rank
Handshake.parallel_context = parallel_context
Handshake.parallel_mode = parallel_mode

Expand Down Expand Up @@ -41,9 +45,6 @@ def is_all_confirmed(self, clock_idx: int) -> bool:
class ProgressTracker(Handshake):
"""Pipeline parallelism's progress tracker."""

# TODO: make this configurable
MASTER_RANK = 0

progress = None
clock_idx = None

Expand Down Expand Up @@ -86,7 +87,8 @@ def is_all_confirmed(clock_idx: int) -> bool:

def confirm(self, task: Task):
# TODO: only non-scheduler ranks should confirm
master_worker_name = self.parallel_context.get_worker_name(self.MASTER_RANK)
global_master_rank = self.parallel_context.get_global_rank_from_local_rank(self.master_rank, self.parallel_mode)
master_worker_name = self.parallel_context.get_worker_name(global_master_rank)
# rank = self.parallel_context.get_local_rank(self.parallel_mode)
rpc.rpc_sync(master_worker_name, func=ProgressTracker._recv_confirm_from_worker, args=(task,))

Expand Down
10 changes: 6 additions & 4 deletions tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@ def schedules_to_progress(schedules):

def run_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
N_MICROBATCHES = 4
MASTER_RANK = 0

schedules = get_gpipe_schedules(pipeline_parallel_size, N_MICROBATCHES)
PROGRESS = schedules_to_progress(schedules)

parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
tracker = ProgressTracker(parallel_context, ParallelMode.GLOBAL)
tracker = ProgressTracker(MASTER_RANK, parallel_context, ParallelMode.GLOBAL)

if rank == tracker.MASTER_RANK:
if rank == tracker.master_rank:
tracker.initiate(PROGRESS)
assert tracker.is_initiated() is True
assert tracker.progress == PROGRESS
Expand Down Expand Up @@ -75,6 +76,7 @@ def test_init_progress_tracker():
def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
N_MICROBATCHES = 4
MICROBATCH_IDX = 0
MASTER_RANK = 0

schedules = get_gpipe_schedules(pipeline_parallel_size, N_MICROBATCHES)
PROGRESS = schedules_to_progress(schedules)
Expand All @@ -83,9 +85,9 @@ def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, p
parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
tracker = ProgressTracker(parallel_context, ParallelMode.GLOBAL)
tracker = ProgressTracker(MASTER_RANK, parallel_context, ParallelMode.GLOBAL)

if rank == tracker.MASTER_RANK:
if rank == tracker.master_rank:
tracker.initiate(PROGRESS)
# NOTE: wait until all workers are confirmed
time.sleep(5)
Expand Down

0 comments on commit 634d483

Please sign in to comment.