Skip to content

Commit

Permalink
refactor progress tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Sep 25, 2023
1 parent 634d483 commit 962d7b2
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 20 deletions.
9 changes: 1 addition & 8 deletions pipegoose/nn/pipeline_parallel2/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
from abc import ABC, abstractclassmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import List

from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
from pipegoose.nn.pipeline_parallel2.task import Task


class SchedulerType(Enum):
    """Identifier for the pipeline-schedule algorithm a scheduler implements."""

    # GPipe-style schedule — the only variant declared here.
    GPIPE = auto()


@dataclass
class Task:
    """One unit of pipeline-parallel work: a job of `job_type` for a given
    microbatch on a given pipeline partition."""

    # Which kind of job to run (a JobType enum member).
    job_type: JobType
    # Index of the microbatch this task applies to.
    microbatch_idx: int
    # Index of the pipeline partition (stage) that executes this task.
    partition_idx: int


class BaseScheduler(ABC):
@abstractclassmethod
def get_schedules(self):
Expand Down
Empty file.
20 changes: 10 additions & 10 deletions pipegoose/nn/pipeline_parallel2/sync/handshake.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from abc import ABC, abstractclassmethod
from typing import NewType, Tuple
from typing import Dict, NewType

import torch.distributed.rpc as rpc

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.pipeline_parallel2.task import Task

# NOTE: (microbatch_idx, partition_idx)
Task = NewType("Task", Tuple[int, int])
Progress = NewType("Progress", Dict[int, Dict[Task, bool]])


class Handshake(ABC):
master_rank = None
master_rank: int = None

parallel_context = None
parallel_mode = None
parallel_context: ParallelContext = None
parallel_mode: ParallelMode = None

def __init__(self, master_rank: int, parallel_context: ParallelContext, parallel_mode: ParallelMode):
Handshake.master_rank = master_rank
Expand Down Expand Up @@ -45,13 +45,13 @@ def is_all_confirmed(self, clock_idx: int) -> bool:
class ProgressTracker(Handshake):
"""Pipeline parallelism's progress tracker."""

progress = None
clock_idx = None
progress: Progress = None
clock_idx: int = None

    def is_initiated(self) -> bool:
        # True once a progress dict has been installed (e.g. via initiate()/
        # _recv_tasks, which assign ProgressTracker.progress); the class-level
        # default is None until then.
        return self.progress is not None

def initiate(self, progress):
def initiate(self, progress: Progress):
INITIAL_CLOCK_IDX = 0
ProgressTracker._broadcast_tasks(progress, clock_idx=INITIAL_CLOCK_IDX)
ProgressTracker._recv_tasks(progress, clock_idx=INITIAL_CLOCK_IDX)
Expand All @@ -73,7 +73,7 @@ def _broadcast_tasks(progress, clock_idx):
rpc.rpc_sync(to=worker_name, func=ProgressTracker._recv_tasks, args=(progress, clock_idx))

@staticmethod
def _recv_tasks(progress, clock_idx):
def _recv_tasks(progress: Progress, clock_idx: int):
ProgressTracker.progress = progress
ProgressTracker.clock_idx = clock_idx

Expand Down
Empty file.
10 changes: 10 additions & 0 deletions pipegoose/nn/pipeline_parallel2/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass

from pipegoose.nn.pipeline_parallel2._job.job_type import JobType


@dataclass
class Task:
    """One unit of pipeline-parallel work: a job of `job_type` for a given
    microbatch on a given pipeline partition.

    Shared by the scheduler and the progress-tracking handshake so both
    refer to the same task record type.
    """

    # Which kind of job to run (a JobType enum member).
    job_type: JobType
    # Index of the microbatch this task applies to.
    microbatch_idx: int
    # Index of the pipeline partition (stage) that executes this task.
    partition_idx: int
9 changes: 7 additions & 2 deletions tests/nn/pipeline_parallel_2/sync/test_progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from pipegoose.testing.utils import init_parallel_context, spawn


def get_task(microbatch_idx, partition_idx):
    """Build the task key used by the progress tracker.

    Returns the (microbatch_idx, partition_idx) tuple that identifies one
    unit of pipeline work.
    """
    return (microbatch_idx, partition_idx)


def get_gpipe_schedules(n_partitions, n_microbatches):
    """Compute a GPipe fill-drain schedule.

    Args:
        n_partitions: number of pipeline partitions (stages).
        n_microbatches: number of microbatches pushed through the pipeline.

    Returns:
        A list with one entry per clock cycle (there are
        ``n_partitions + n_microbatches - 1`` cycles); each entry is the
        list of (microbatch_idx, partition_idx) tasks active in that cycle.
    """
    n_clock_cycles = n_partitions + n_microbatches - 1
    schedules = []
    for clock_idx in range(n_clock_cycles):
        # Partition p is active in this cycle iff microbatch (clock_idx - p)
        # is a valid microbatch index, which bounds the active partitions to
        # [start_partition, end_partition).
        # NOTE: local renamed from the original's misspelled "start_partrition".
        start_partition = max(clock_idx + 1 - n_microbatches, 0)
        end_partition = min(clock_idx + 1, n_partitions)
        tasks = []
        for partition_idx in range(start_partition, end_partition):
            microbatch_idx = clock_idx - partition_idx
            task = get_task(microbatch_idx, partition_idx)
            tasks.append(task)

        schedules.append(tasks)

    return schedules

Expand Down Expand Up @@ -102,7 +107,7 @@ def run_confirm_progress_tracker(rank, world_size, port, tensor_parallel_size, p
# NOTE: wait until the tracker is initiated
time.sleep(2)
partition_idx = get_partition_idx(parallel_context)
task = (MICROBATCH_IDX, partition_idx)
task = get_task(MICROBATCH_IDX, partition_idx)
tracker.confirm(task)
assert tracker.is_confirmed(task, clock_idx=0) is True

Expand Down

0 comments on commit 962d7b2

Please sign in to comment.