-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add broadcasting the current progress in the pipeline after all worke…
…rs finish a clock cycle
- Loading branch information
Showing
3 changed files
with
208 additions
and
224 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,128 +1,110 @@ | ||
import random | ||
from abc import ABC, abstractclassmethod | ||
from dataclasses import dataclass | ||
from queue import Queue | ||
from time import sleep | ||
from typing import Dict, List | ||
from typing import NewType, Tuple | ||
|
||
import torch.distributed.rpc as rpc | ||
|
||
from pipegoose.distributed.parallel_context import ParallelContext | ||
from pipegoose.distributed.parallel_mode import ParallelMode | ||
|
||
SessionId = str | ||
Task = NewType("Task", Tuple[int, int]) | ||
|
||
|
||
class Handshake(ABC): | ||
def __init__(self, parallel_context: ParallelContext, parallel_mode: ParallelMode): | ||
self.parallel_context = parallel_context | ||
self.parallel_mode = parallel_mode | ||
|
||
self._session_id: str = None | ||
self._queue: Dict[SessionId, Queue] = {} | ||
self._ranks_confirmed: Dict[SessionId, List[int]] = set() | ||
parallel_context = None | ||
parallel_mode = None | ||
|
||
self._data = None | ||
self._clock_idx = 0 | ||
def __init__(self, parallel_context: ParallelContext, parallel_mode: ParallelMode): | ||
Handshake.parallel_context = parallel_context | ||
Handshake.parallel_mode = parallel_mode | ||
|
||
@abstractclassmethod | ||
def initiate(self): | ||
raise NotImplementedError | ||
|
||
def _generate_session_id(self) -> int: | ||
return random.randint(0, 9999) | ||
|
||
@abstractclassmethod | ||
def confirm(self): | ||
raise NotImplementedError | ||
|
||
@abstractclassmethod | ||
def is_initiated(self): | ||
raise NotImplementedError | ||
|
||
@abstractclassmethod | ||
def is_confirmed(self): | ||
def is_initiated(self) -> bool: | ||
raise NotImplementedError | ||
|
||
@abstractclassmethod | ||
def is_all_confirmed(self): | ||
def is_confirmed(self) -> bool: | ||
raise NotImplementedError | ||
|
||
@abstractclassmethod | ||
def wait_until_all_confirmed(self): | ||
def is_all_confirmed(self) -> bool: | ||
raise NotImplementedError | ||
|
||
|
||
@dataclass | ||
class SessionMetadata: | ||
clock_idx: int | ||
parallel_mode: ParallelMode | ||
|
||
class ProgressTracker(Handshake): | ||
"""Pipeline parallelism's progress tracker.""" | ||
|
||
class SchedulerHandshake(Handshake): | ||
NUM_SECONDS_IDLE = 0.5 | ||
# TODO: make this configurable | ||
MASTER_RANK = 0 | ||
|
||
progress = None | ||
clock_idx = 0 | ||
clock_idx = None | ||
|
||
def is_initiated(self) -> bool: | ||
return self.progress is not None | ||
|
||
def initiate(self, data): | ||
rank = self.parallel_context.get_local_rank(self.parallel_mode) | ||
world_size = self.parallel_context.get_world_size(self.parallel_mode) | ||
def initiate(self, progress): | ||
INITIAL_CLOCK_IDX = 0 | ||
ProgressTracker._broadcast_tasks(progress) | ||
ProgressTracker._recv_tasks(progress, clock_idx=INITIAL_CLOCK_IDX) | ||
|
||
@staticmethod | ||
def _broadcast_tasks(progress): | ||
parallel_context = ProgressTracker.parallel_context | ||
parallel_mode = ProgressTracker.parallel_mode | ||
clock_idx = ProgressTracker.clock_idx | ||
|
||
rank = parallel_context.get_local_rank(parallel_mode) | ||
world_size = parallel_context.get_world_size(parallel_mode) | ||
|
||
for dst in range(world_size): | ||
if dst == rank: | ||
continue | ||
|
||
worker_name = self.parallel_context.get_worker_name(dst) | ||
rpc.rpc_sync(to=worker_name, func=SchedulerHandshake._recv_execution_plan, args=(data,)) | ||
|
||
SchedulerHandshake._recv_execution_plan(data) | ||
worker_name = parallel_context.get_worker_name(dst) | ||
rpc.rpc_sync(to=worker_name, func=ProgressTracker._recv_tasks, args=(progress, clock_idx)) | ||
|
||
@staticmethod | ||
def _recv_execution_plan(data): | ||
SchedulerHandshake.progress = data | ||
def _recv_tasks(progress, clock_idx): | ||
ProgressTracker.progress = progress | ||
ProgressTracker.clock_idx = clock_idx | ||
|
||
def is_confirmed(self, task, clock_idx: int) -> bool: | ||
def is_confirmed(self, task: Task, clock_idx: int) -> bool: | ||
return self.progress[clock_idx][task] is True | ||
|
||
@staticmethod | ||
def is_all_confirmed(clock_idx: int) -> bool: | ||
progress = SchedulerHandshake.progress | ||
progress = ProgressTracker.progress | ||
return all([progress[clock_idx][task] is True for task in progress[clock_idx]]) | ||
|
||
def confirm(self, task): | ||
def confirm(self, task: Task): | ||
# TODO: only non-scheduler ranks should confirm | ||
master_worker_name = self.parallel_context.get_worker_name(self.MASTER_RANK) | ||
# rank = self.parallel_context.get_local_rank(self.parallel_mode) | ||
rpc.rpc_sync(master_worker_name, func=SchedulerHandshake._recv_confirm_from_worker, args=(task,)) | ||
rpc.rpc_sync(master_worker_name, func=ProgressTracker._recv_confirm_from_worker, args=(task,)) | ||
|
||
# NOTE: a worker node should confirm itself | ||
SchedulerHandshake._update_progress(task) | ||
ProgressTracker._update_progress(task) | ||
|
||
@staticmethod | ||
def _update_progress(task): | ||
clock_idx = SchedulerHandshake.clock_idx | ||
progress = SchedulerHandshake.progress | ||
def _update_progress(task: Task): | ||
clock_idx = ProgressTracker.clock_idx | ||
progress = ProgressTracker.progress | ||
progress[clock_idx][task] = True | ||
|
||
@staticmethod | ||
def _recv_confirm_from_worker(task): | ||
SchedulerHandshake._update_progress(task) | ||
|
||
clock_idx = SchedulerHandshake.clock_idx | ||
if SchedulerHandshake.is_all_confirmed(clock_idx) is True: | ||
SchedulerHandshake.clock_idx += 1 | ||
|
||
def wait_until_all_confirmed(self): | ||
if self.parallel_context.is_first_rank() is True: | ||
while True: | ||
if self.is_all_confirmed() is True: | ||
break | ||
else: | ||
sleep(self.NUM_SECONDS_IDLE) | ||
else: | ||
pass | ||
def _recv_confirm_from_worker(task: Task): | ||
ProgressTracker._update_progress(task) | ||
|
||
clock_idx = ProgressTracker.clock_idx | ||
if ProgressTracker.is_all_confirmed(clock_idx) is True: | ||
ProgressTracker.clock_idx += 1 | ||
# broadcast the progress to all worker nodes | ||
ProgressTracker._broadcast_tasks(ProgressTracker.progress) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.