fixed synchronization in the forward pass of the pipeline engine
xrsrke committed Sep 30, 2023
1 parent c9647a3 commit 821a5fb
Showing 7 changed files with 44 additions and 38 deletions.
11 changes: 4 additions & 7 deletions pipegoose/distributed/parallel_context.py
@@ -22,7 +22,8 @@


class ParallelContext:
"""Inspired from OSLO's parallel context:
"""
Inspired from OSLO's parallel context:
https://github.com/EleutherAI/oslo/blob/f16c73bc5893cd6cefe65e70acf6d88428a324e1/oslo/torch/distributed/parallel_context.py#L53
"""

@@ -105,7 +106,6 @@ def __init__(
self.set_device()

self.rpc_worker_map = {rank: WORKER_NAME.format(rank) for rank in self.get_ranks_in_group(ParallelMode.GLOBAL)}
# TODO: add initialize from torch launcher
self.init_rpc_workers(host, port)

# self.set_seed(seed)
@@ -199,7 +199,6 @@ def _register_dist(
self.add_local_rank(parallel_mode, local_rank)
self.add_world_size(parallel_mode, local_world_size)
self.add_group(parallel_mode, process_group)
# TODO: remove this
self.add_ranks_in_group(parallel_mode, ranks_in_group)

def set_device(self):
@@ -213,9 +212,8 @@ def set_seed(self, seed: int):
torch.manual_seed(seed)

# TODO: set GPU seed
if torch.cuda.is_available():
# parallel_seed = seed
pass
# if torch.cuda.is_available():
# pass

def is_initialized(self, parallel_mode: ParallelMode) -> bool:
"""Check if the parallel mode is initialized.
@@ -261,7 +259,6 @@ def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks_in_group: List[i

def get_ranks_in_group(self, parallel_mode: ParallelMode) -> List[int]:
"""A list of global ranks in a given parallel mode of the local process."""
# return dist.get_process_group_ranks(self._groups[parallel_mode])
return self._ranks_in_group[parallel_mode]

def get_next_global_rank(self, parallel_mode: ParallelMode) -> int:
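The set_seed hunk above still leaves the GPU seed as a TODO and only comments out the CUDA branch. A minimal sketch of how that branch might be filled in, assuming one seed per process is enough; the use of torch.cuda.manual_seed_all here is my assumption, not pipegoose's implementation:

import torch

def set_seed(seed: int) -> None:
    # Seed the CPU generator.
    torch.manual_seed(seed)
    # Assumed completion of the "set GPU seed" TODO: seed every visible GPU.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)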
3 changes: 3 additions & 0 deletions pipegoose/nn/pipeline_parallel2/_job/forward.py
@@ -116,3 +116,6 @@ def after_compute(self):
# )
key = (microbatch_idx, partition_idx)
progress_tracker.confirm(key)
import time

time.sleep(3)
30 changes: 19 additions & 11 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass

import torch
import torch.distributed as dist
from torch import nn

from pipegoose.distributed.parallel_context import ParallelContext
@@ -86,29 +87,35 @@ def after_new_clock_cycle(self, progress, clock_idx):
MASTER_RANK, callbacks=callbacks, parallel_context=self.parallel_context, parallel_mode=ParallelMode.GLOBAL
)
# NOTE: wait for all ranks to be initiated
dist.barrier()
time.sleep(1)

if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
# if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
if self.parallel_context.get_global_rank() == 0:
schedules = self.pipeline_context.schedules
progress = {
i: {(item.microbatch_idx, item.partition_idx): False for item in sublist}
for i, sublist in enumerate(schedules)
}
progress_tracker.initiate(progress)
print(progress)

time.sleep(1)
dist.barrier()
time.sleep(5)

set_progress_tracker(progress_tracker)

time.sleep(1)
dist.barrier()
time.sleep(2)

# from hanging_threads import start_monitoring
# monitoring_thread = start_monitoring()

for tasks in self.pipeline_context.get_schedule():

time.sleep(2)
rank = self.parallel_context.get_local_rank(ParallelMode.GLOBAL)
dist.barrier()

rank = self.parallel_context.get_global_rank()
partition_idx = self.pipeline_context.partition_idx

if rank == 0:
@@ -118,7 +125,7 @@ def after_new_clock_cycle(self, progress, clock_idx):
assert 1 == 1

if len(tasks) > 0:
print(f"[enter look] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}")
# print(f"[enter look] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}")
for task in tasks:
microbatch_idx = task.microbatch_idx
partition_idx = task.partition_idx
@@ -129,17 +136,18 @@ def after_new_clock_cycle(self, progress, clock_idx):
else:
package = RECV_QUEUE.get()

print(
f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}",
package.metadata,
)
# print(
# f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}",
# package.metadata,
# )

job = create_job(self.partition_func, package, self.pipeline_context)

# print(f"created a job: {package.metadata}")

JobQueue.PENDING_JOBS.put(job)
time.sleep(2)

dist.barrier()

# def _retrieve_package_from_received_package(self, microbatch_idx, partition_idx):
# # package = RECV_QUEUE[(microbatch_idx, partition_idx)]
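The engine changes above replace fixed time.sleep delays with dist.barrier() calls, so every rank lines up before the progress tracker is initiated and again at the end of each clock cycle. A minimal sketch of that pattern, assuming the default torch.distributed process group is already initialized; schedules and run_task are placeholders rather than pipegoose APIs:

import torch.distributed as dist

def run_clock_cycles(schedules, run_task):
    # Make sure every rank has finished setup before the first cycle.
    dist.barrier()
    for clock_idx, tasks in enumerate(schedules):
        for task in tasks:
            run_task(task)
        # No rank starts clock cycle `clock_idx + 1` until all ranks
        # have finished cycle `clock_idx`.
        dist.barrier()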
15 changes: 8 additions & 7 deletions pipegoose/nn/pipeline_parallel2/sync/handshake.py
@@ -95,37 +95,38 @@ def is_initiated(self) -> bool:

def initiate(self, progress: Progress):
INITIAL_CLOCK_IDX = 0
ProgressTracker._broadcast_tasks(progress, clock_idx=INITIAL_CLOCK_IDX)
ProgressTracker._broadcast_tasks(progress, clock_idx=INITIAL_CLOCK_IDX, is_init=True)
ProgressTracker.progress = progress
ProgressTracker.clock_idx = INITIAL_CLOCK_IDX

@staticmethod
def _broadcast_tasks(progress, clock_idx):
def _broadcast_tasks(progress, clock_idx, is_init=False):
parallel_context = ProgressTracker.parallel_context
parallel_mode = ProgressTracker.parallel_mode

local_rank = parallel_context.get_local_rank(parallel_mode)
local_world_size = parallel_context.get_world_size(parallel_mode)

for local_dst in range(local_world_size):
if local_dst == local_rank:
if local_dst == local_rank and is_init is False:
# NOTE: since we skip the master node, we need to manually run the callback
ProgressTracker._run_callback("after_new_clock_cycle", progress=progress, clock_idx=clock_idx)
continue

global_dst = parallel_context.get_global_rank_from_local_rank(local_dst, parallel_mode)
worker_name = parallel_context.get_worker_name(global_dst)
rpc.rpc_sync(to=worker_name, func=ProgressTracker._recv_tasks, args=(progress, clock_idx))
rpc.rpc_sync(to=worker_name, func=ProgressTracker._recv_tasks, args=(progress, clock_idx, is_init))

@staticmethod
def _recv_tasks(progress: Progress, clock_idx: int):
def _recv_tasks(progress: Progress, clock_idx: int, is_init):
with ProgressTracker.update_progress_lock:
ProgressTracker.progress = progress
ProgressTracker.clock_idx = clock_idx

# NOTE: don't increase a new clock cycle if just initializing it
# NOTE: after a worker node receives the progress, it should run the callback
ProgressTracker._run_callback("after_new_clock_cycle", progress=progress, clock_idx=clock_idx)
if is_init is False:
ProgressTracker._run_callback("after_new_clock_cycle", progress=progress, clock_idx=clock_idx)

def is_confirmed(self, task: Task, clock_idx: int) -> bool:
return self.progress[clock_idx][task] is True
@@ -166,4 +167,4 @@ def _update_local_progress(task: Task):
if ProgressTracker.is_all_confirmed(clock_idx) is True:
NEXT_CLOCK_IDX = clock_idx + 1
ProgressTracker.clock_idx = NEXT_CLOCK_IDX
ProgressTracker._broadcast_tasks(ProgressTracker.progress, clock_idx=NEXT_CLOCK_IDX)
ProgressTracker._broadcast_tasks(ProgressTracker.progress, clock_idx=NEXT_CLOCK_IDX, is_init=False)
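The handshake change threads an is_init flag from initiate() through _broadcast_tasks() and _recv_tasks(), so the after_new_clock_cycle callback fires only when the clock actually advances, not while the initial progress is being seeded. A minimal sketch of that control flow with injected stand-ins (send in place of rpc.rpc_sync, on_new_cycle in place of the callback hook); the names are illustrative, not pipegoose's API:

def broadcast_tasks(progress, clock_idx, rank, world_size, send, on_new_cycle, is_init=False):
    for dst in range(world_size):
        if dst == rank and not is_init:
            # The sender skips RPC to itself, so it runs the callback locally,
            # but only for a real clock-cycle advance.
            on_new_cycle(progress, clock_idx)
            continue
        send(dst, progress, clock_idx, is_init)

def recv_tasks(state, progress, clock_idx, on_new_cycle, is_init):
    state["progress"], state["clock_idx"] = progress, clock_idx
    # The initial broadcast only seeds local state; it is not a new clock cycle.
    if not is_init:
        on_new_cycle(progress, clock_idx)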
2 changes: 1 addition & 1 deletion tests/distributed/test_functional.py
@@ -199,5 +199,5 @@ def test_all_reduce(world_size, tensor_parallel_size, pipeline_parallel_size, da


@pytest.mark.skip(reason="not implemented")
def test_reduce_scatter(parallel_context):
def test_reduce_scatter():
pass
2 changes: 1 addition & 1 deletion tests/nn/pipeline_parallel_2/run_engine.py
@@ -84,9 +84,9 @@ def forward(self, input):


if __name__ == "__main__":
DATA_PARALLEL_SIZE = 1
TENSOR_PARALLEL_SIZE = 1
PIPELINE_PARALLEL_SIZE = 4
DATA_PARALLEL_SIZE = 1

WORLD_SIZE = PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE

19 changes: 8 additions & 11 deletions tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py
@@ -1,8 +1,8 @@
import time
from copy import deepcopy
from typing import Dict

import pytest
import torch.distributed as dist

from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.pipeline_parallel2.sync.callback import Callback
@@ -34,7 +34,8 @@ def run_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipe
tracker.initiate(PROGRESS)

# NOTE: wait until the tracker is initiated
time.sleep(0.1)
dist.barrier()

assert tracker.is_initiated() is True
assert tracker.clock_idx == 0
assert tracker.is_all_confirmed(clock_idx=0) is False
@@ -73,14 +74,14 @@ def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, p
tracker.initiate(PROGRESS)

# NOTE: wait until the tracker is initiated
time.sleep(2)
dist.barrier()

for clock_idx in range(N_CLOCK_CYCLES):
tracker.confirm(rank)
assert tracker.is_confirmed(rank, clock_idx=clock_idx) is True

# NOTE: wait until all workers are confirmed
time.sleep(2)
dist.barrier()
assert tracker.is_all_confirmed(clock_idx=clock_idx) is True

if not (clock_idx == N_CLOCK_CYCLES - 1):
@@ -89,8 +90,6 @@ def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, p
assert tracker.clock_idx == clock_idx + 1
assert tracker.progress != INITIAL_PROGRESS

time.sleep(0.1)

assert tracker.progress == FINAL_PROGRESS

parallel_context.destroy()
@@ -129,15 +128,13 @@ def after_new_clock_cycle(self, progress: Dict, clock_idx: int):
tracker.initiate(PROGRESS)

# NOTE: wait until the tracker is initiated
time.sleep(0.5)
assert QUEUE == [rank]

dist.barrier()
tracker.confirm(rank)

# NOTE: wait until all workers are confirmed
# callback should be called again after all workers are confirmed
time.sleep(0.5)
assert QUEUE == [rank, rank]
dist.barrier()
assert QUEUE == [rank]

parallel_context.destroy()

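The test changes make the same swap: fixed time.sleep waits become dist.barrier() calls, so assertions only run once every spawned rank has reached the same point. A minimal sketch of the pattern in a standalone spawn-style test, assuming each worker initializes the default process group itself; the gloo backend, address, and port below are placeholder choices, not the project's test fixtures:

import os

import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # ... rank 0 would initiate shared state here (e.g. a progress tracker) ...

    # No rank runs the assertion below until every rank has reached this point,
    # so there is no sleep duration to guess.
    dist.barrier()
    assert dist.get_world_size() == world_size

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)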
