Skip to content

Commit

Permalink
refactor pipeline's progress tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Sep 25, 2023
1 parent aaf7799 commit 2cca21a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
4 changes: 2 additions & 2 deletions pipegoose/nn/pipeline_parallel2/sync/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def is_initiated(self) -> bool:
raise NotImplementedError

@abstractclassmethod
def is_confirmed(self) -> bool:
def is_confirmed(self, clock_idx: int) -> bool:
raise NotImplementedError

@abstractclassmethod
def is_all_confirmed(self) -> bool:
def is_all_confirmed(self, clock_idx: int) -> bool:
raise NotImplementedError


Expand Down
65 changes: 50 additions & 15 deletions tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def get_gpipe_schedules(n_partitions, n_microbatches):
for clock_idx in range(n_clock_cycles):
start_partrition = max(clock_idx + 1 - n_microbatches, 0)
end_partition = min(clock_idx + 1, n_partitions)

tasks = []
for partition_idx in range(start_partrition, end_partition):
microbatch_idx = clock_idx - partition_idx
Expand All @@ -30,13 +29,11 @@ def schedules_to_progress(schedules):
return {i: {item: False for item in sublist} for i, sublist in enumerate(schedules)}


def run_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
def run_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
N_MICROBATCHES = 4
MICROBATCH_IDX = 0

schedules = get_gpipe_schedules(pipeline_parallel_size, N_MICROBATCHES)
PROGRESS = schedules_to_progress(schedules)
INITIAL_PROGRESS = deepcopy(PROGRESS)

parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
Expand All @@ -48,10 +45,52 @@ def run_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_
assert tracker.is_initiated() is True
assert tracker.progress == PROGRESS
assert tracker.clock_idx == 0
assert tracker.is_all_confirmed(clock_idx=0) is False
else:
# NOTE: wait until the tracker is initiated
time.sleep(2)
assert tracker.is_initiated() is True
# TODO: if haven't confirmed any task, clock_idx should be 0
assert tracker.progress == PROGRESS
assert tracker.clock_idx == 0

parallel_context.destroy()


def test_init_progress_tracker():
    """Spawn a full 2x2x2 parallel topology and verify that the progress
    tracker is initiated consistently on every rank (see run_init_progress_tracker).
    """
    # NOTE: fixed 2-way tensor / pipeline / data parallelism => 8 workers total.
    parallel_dims = {
        "tensor_parallel_size": 2,
        "pipeline_parallel_size": 2,
        "data_parallel_size": 2,
    }

    # world size is the product of all three parallelism degrees
    n_workers = 1
    for degree in parallel_dims.values():
        n_workers *= degree

    spawn(run_init_progress_tracker, world_size=n_workers, **parallel_dims)


def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
N_MICROBATCHES = 4
MICROBATCH_IDX = 0

schedules = get_gpipe_schedules(pipeline_parallel_size, N_MICROBATCHES)
PROGRESS = schedules_to_progress(schedules)
INITIAL_PROGRESS = deepcopy(PROGRESS)

parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
tracker = ProgressTracker(parallel_context, ParallelMode.GLOBAL)

if rank == tracker.MASTER_RANK:
tracker.initiate(PROGRESS)
# NOTE: wait until all workers are confirmed
time.sleep(5)
assert tracker.is_all_confirmed(clock_idx=0) is True
assert tracker.is_all_confirmed(clock_idx=1) is False

# NOTE: after all workers are confirmed,
# the clock index should be incremented
Expand All @@ -60,19 +99,15 @@ def run_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_
else:
# NOTE: wait until the tracker is initiated
time.sleep(2)
assert tracker.is_initiated() is True
# NOTE: other workers may updated the progress
# so the progress should be updated
# TODO: if haven't confirmed any task, clock_idx should be 0
# assert tracker.progress == PROGRESS
# assert handshake.clock_idx == 0

task = (MICROBATCH_IDX, get_partition_idx(parallel_context))
partition_idx = get_partition_idx(parallel_context)
task = (MICROBATCH_IDX, partition_idx)
tracker.confirm(task)
assert tracker.is_confirmed(task, 0) is True
assert tracker.is_confirmed(task, clock_idx=0) is True

# NOTE: wait until all workers are confirmed
time.sleep(5)
assert tracker.is_all_confirmed(clock_idx=0) is True
assert tracker.is_all_confirmed(clock_idx=1) is False
assert tracker.clock_idx == 1
assert tracker.progress != INITIAL_PROGRESS

Expand All @@ -82,11 +117,11 @@ def run_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
@pytest.mark.parametrize("pipeline_parallel_size", [2, 4])
@pytest.mark.parametrize("data_parallel_size", [1, 2])
def test_progress_tracker(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
def test_confirm_progress_tracker(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
world_size = tensor_parallel_size * pipeline_parallel_size * data_parallel_size

spawn(
run_progress_tracker,
run_confirm_progress_tracker,
world_size=world_size,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
Expand Down

0 comments on commit 2cca21a

Please sign in to comment.