add broadcasting the current progress in the pipeline after all workers finish a clock cycle
xrsrke committed Sep 24, 2023
1 parent 22b1544 commit 6f4a1d9
Showing 3 changed files with 208 additions and 224 deletions.
114 changes: 48 additions & 66 deletions pipegoose/nn/pipeline_parallel2/sync/handshake.py
@@ -1,128 +1,110 @@
import random
from abc import ABC, abstractclassmethod
from dataclasses import dataclass
from queue import Queue
from time import sleep
from typing import Dict, List
from typing import NewType, Tuple

import torch.distributed.rpc as rpc

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

SessionId = str
Task = NewType("Task", Tuple[int, int])


class Handshake(ABC):
def __init__(self, parallel_context: ParallelContext, parallel_mode: ParallelMode):
self.parallel_context = parallel_context
self.parallel_mode = parallel_mode

self._session_id: str = None
self._queue: Dict[SessionId, Queue] = {}
self._ranks_confirmed: Dict[SessionId, List[int]] = set()
parallel_context = None
parallel_mode = None

self._data = None
self._clock_idx = 0
def __init__(self, parallel_context: ParallelContext, parallel_mode: ParallelMode):
Handshake.parallel_context = parallel_context
Handshake.parallel_mode = parallel_mode

@abstractclassmethod
def initiate(self):
raise NotImplementedError

def _generate_session_id(self) -> int:
return random.randint(0, 9999)

@abstractclassmethod
def confirm(self):
raise NotImplementedError

@abstractclassmethod
def is_initiated(self):
raise NotImplementedError

@abstractclassmethod
def is_confirmed(self):
def is_initiated(self) -> bool:
raise NotImplementedError

@abstractclassmethod
def is_all_confirmed(self):
def is_confirmed(self) -> bool:
raise NotImplementedError

@abstractclassmethod
def wait_until_all_confirmed(self):
def is_all_confirmed(self) -> bool:
raise NotImplementedError


@dataclass
class SessionMetadata:
clock_idx: int
parallel_mode: ParallelMode

class ProgressTracker(Handshake):
"""Pipeline parallelism's progress tracker."""

class SchedulerHandshake(Handshake):
NUM_SECONDS_IDLE = 0.5
# TODO: make this configurable
MASTER_RANK = 0

progress = None
clock_idx = 0
clock_idx = None

def is_initiated(self) -> bool:
return self.progress is not None

def initiate(self, data):
rank = self.parallel_context.get_local_rank(self.parallel_mode)
world_size = self.parallel_context.get_world_size(self.parallel_mode)
def initiate(self, progress):
INITIAL_CLOCK_IDX = 0
ProgressTracker._broadcast_tasks(progress)
ProgressTracker._recv_tasks(progress, clock_idx=INITIAL_CLOCK_IDX)

@staticmethod
def _broadcast_tasks(progress):
parallel_context = ProgressTracker.parallel_context
parallel_mode = ProgressTracker.parallel_mode
clock_idx = ProgressTracker.clock_idx

rank = parallel_context.get_local_rank(parallel_mode)
world_size = parallel_context.get_world_size(parallel_mode)

for dst in range(world_size):
if dst == rank:
continue

worker_name = self.parallel_context.get_worker_name(dst)
rpc.rpc_sync(to=worker_name, func=SchedulerHandshake._recv_execution_plan, args=(data,))

SchedulerHandshake._recv_execution_plan(data)
worker_name = parallel_context.get_worker_name(dst)
rpc.rpc_sync(to=worker_name, func=ProgressTracker._recv_tasks, args=(progress, clock_idx))

@staticmethod
def _recv_execution_plan(data):
SchedulerHandshake.progress = data
def _recv_tasks(progress, clock_idx):
ProgressTracker.progress = progress
ProgressTracker.clock_idx = clock_idx

def is_confirmed(self, task, clock_idx: int) -> bool:
def is_confirmed(self, task: Task, clock_idx: int) -> bool:
return self.progress[clock_idx][task] is True

@staticmethod
def is_all_confirmed(clock_idx: int) -> bool:
progress = SchedulerHandshake.progress
progress = ProgressTracker.progress
return all([progress[clock_idx][task] is True for task in progress[clock_idx]])

def confirm(self, task):
def confirm(self, task: Task):
# TODO: only non-scheduler ranks should confirm
master_worker_name = self.parallel_context.get_worker_name(self.MASTER_RANK)
# rank = self.parallel_context.get_local_rank(self.parallel_mode)
rpc.rpc_sync(master_worker_name, func=SchedulerHandshake._recv_confirm_from_worker, args=(task,))
rpc.rpc_sync(master_worker_name, func=ProgressTracker._recv_confirm_from_worker, args=(task,))

# NOTE: a worker node should confirm itself
SchedulerHandshake._update_progress(task)
ProgressTracker._update_progress(task)

@staticmethod
def _update_progress(task):
clock_idx = SchedulerHandshake.clock_idx
progress = SchedulerHandshake.progress
def _update_progress(task: Task):
clock_idx = ProgressTracker.clock_idx
progress = ProgressTracker.progress
progress[clock_idx][task] = True

@staticmethod
def _recv_confirm_from_worker(task):
SchedulerHandshake._update_progress(task)

clock_idx = SchedulerHandshake.clock_idx
if SchedulerHandshake.is_all_confirmed(clock_idx) is True:
SchedulerHandshake.clock_idx += 1

def wait_until_all_confirmed(self):
if self.parallel_context.is_first_rank() is True:
while True:
if self.is_all_confirmed() is True:
break
else:
sleep(self.NUM_SECONDS_IDLE)
else:
pass
def _recv_confirm_from_worker(task: Task):
ProgressTracker._update_progress(task)

clock_idx = ProgressTracker.clock_idx
if ProgressTracker.is_all_confirmed(clock_idx) is True:
ProgressTracker.clock_idx += 1
# broadcast the progress to all worker nodes
ProgressTracker._broadcast_tasks(ProgressTracker.progress)
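
To make the new control flow easier to follow, here is a minimal single-process sketch of the confirm-then-advance-then-broadcast cycle that ProgressTracker implements. LocalProgressTracker and the example progress dict are hypothetical illustrations (not part of pipegoose), and the print call stands in for the RPC broadcast the real class performs via torch.distributed.rpc:

from typing import Dict, Tuple

Task = Tuple[int, int]  # (microbatch_idx, partition_idx)


class LocalProgressTracker:
    def __init__(self, progress: Dict[int, Dict[Task, bool]]):
        self.progress = progress
        self.clock_idx = 0

    def is_all_confirmed(self, clock_idx: int) -> bool:
        return all(self.progress[clock_idx].values())

    def confirm(self, task: Task) -> None:
        # mark the task as done in the current clock cycle
        self.progress[self.clock_idx][task] = True
        if self.is_all_confirmed(self.clock_idx):
            # every task in this clock cycle is confirmed: advance the clock;
            # the real tracker then re-broadcasts the updated progress to all
            # worker nodes via ProgressTracker._broadcast_tasks(...)
            self.clock_idx += 1
            print(f"clock cycle finished; new clock_idx = {self.clock_idx}")


# hypothetical two-cycle progress map with two tasks per cycle
progress = {
    0: {(0, 0): False, (0, 1): False},
    1: {(1, 0): False, (1, 1): False},
}
tracker = LocalProgressTracker(progress)
tracker.confirm((0, 0))
tracker.confirm((0, 1))  # completes clock cycle 0
assert tracker.clock_idx == 1
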
34 changes: 18 additions & 16 deletions tests/nn/test_die_trying.py
@@ -5,7 +5,7 @@

from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx
from pipegoose.nn.pipeline_parallel2.sync.handshake import SchedulerHandshake
from pipegoose.nn.pipeline_parallel2.sync.handshake import ProgressTracker
from pipegoose.testing.utils import init_parallel_context, spawn


@@ -36,55 +36,57 @@ def run_send_rcv_rpc(rank, world_size, port, tensor_parallel_size, pipeline_para

schedules = get_gpipe_schedules(pipeline_parallel_size, N_MICROBATCHES)
PROGRESS = schedules_to_progress(schedules)
INITIAL_PROGRESS = deepcopy(PROGRESS)

parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
handshake = SchedulerHandshake(parallel_context, ParallelMode.GLOBAL)
handshake = ProgressTracker(parallel_context, ParallelMode.GLOBAL)

if rank == SchedulerHandshake.MASTER_RANK:
if rank == handshake.MASTER_RANK:
handshake.initiate(PROGRESS)
assert handshake.is_initiated() is True
assert handshake.progress == PROGRESS
assert handshake.clock_idx == 0

# NOTE: wait until all workers are confirmed
time.sleep(5)
assert SchedulerHandshake.is_all_confirmed(clock_idx=0) is True
assert handshake.is_all_confirmed(clock_idx=0) is True

# NOTE: after all workers are confirmed,
# the clock index should be incremented
assert handshake.clock_idx == 1
assert handshake.progress != INITIAL_PROGRESS
else:
# NOTE: wait until the handshake is initiated
time.sleep(2)
assert handshake.is_initiated() is True
assert handshake.progress == PROGRESS
assert handshake.clock_idx == 0
# TODO: if haven't confirmed any task, clock_idx should be 0
# assert handshake.clock_idx == 0

PREV_CLOCK_IDX = deepcopy(handshake.clock_idx)
task = (MICROBATCH_IDX, get_partition_idx(parallel_context))
handshake.confirm(task)
assert handshake.is_confirmed(task, PREV_CLOCK_IDX) is True
assert handshake.is_confirmed(task, 0) is True

# NOTE: wait until all workers are confirmed
# time.sleep(5)
# assert handshake.clock_idx == PREV_CLOCK_IDX + 1
time.sleep(5)
assert handshake.clock_idx == 1
assert handshake.progress != INITIAL_PROGRESS

parallel_context.destroy()


@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [2])
def test_send_rcv_rpc(tensor_parallel_size, pipeline_parallel_size):
DATA_PARALLEL_SIZE = 1

world_size = tensor_parallel_size * pipeline_parallel_size * DATA_PARALLEL_SIZE
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
@pytest.mark.parametrize("pipeline_parallel_size", [2, 4])
@pytest.mark.parametrize("data_parallel_size", [1, 2])
def test_send_rcv_rpc(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
world_size = tensor_parallel_size * pipeline_parallel_size * data_parallel_size

spawn(
run_send_rcv_rpc,
world_size=world_size,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
data_parallel_size=DATA_PARALLEL_SIZE,
data_parallel_size=data_parallel_size,
)
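
The test above synchronizes with fixed time.sleep(...) calls before asserting on clock_idx. As a sketch only (a hypothetical helper, not part of the repository), the same wait could be expressed as a poll with a timeout:

import time


def wait_for_clock(tracker, expected_clock_idx: int, timeout: float = 10.0, interval: float = 0.1) -> None:
    # Poll the tracker until its clock index reaches the expected value,
    # instead of sleeping for a fixed number of seconds.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if tracker.clock_idx is not None and tracker.clock_idx >= expected_clock_idx:
            return
        time.sleep(interval)
    raise TimeoutError(f"clock_idx did not reach {expected_clock_idx} within {timeout}s")


# e.g. in the worker branch: wait_for_clock(handshake, PREV_CLOCK_IDX + 1) instead of time.sleep(5)
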
