
Commit

add synchronous scheduler
xrsrke committed Sep 26, 2023
1 parent 5690337 commit 6d5390d
Showing 4 changed files with 68 additions and 45 deletions.
15 changes: 14 additions & 1 deletion pipegoose/nn/pipeline_parallel2/pipeline_context.py
@@ -1,3 +1,4 @@
+import threading
 from typing import List
 
 from pipegoose.distributed.parallel_context import ParallelContext
@@ -13,6 +14,7 @@ def __init__(self, scheduler: BaseScheduler, parallel_context: ParallelContext):
         self.parallel_context = parallel_context
 
         self._clock_idx = 0
+        self._wait_new_clock_cycle = threading.Condition()
 
     @property
     def partition_idx(self) -> int:
@@ -30,7 +32,9 @@ def clock_idx(self) -> int:
     def increase_a_clock_cycle(self):
         """Increase the current clock cycle in the pipeline by 1."""
         # TODO: add assert maximum clock cycles
-        self._clock_idx += 1
+        with self._wait_new_clock_cycle:
+            self._clock_idx += 1
+            self._wait_new_clock_cycle.notify_all()
 
     @property
     def schedule(self) -> List:
@@ -42,6 +46,15 @@ def schedules(self) -> List:
         """Get the schedule for entire training run."""
         return self.scheduler.get_schedules()
 
+    def get_schedule(self):
+        with self._wait_new_clock_cycle:
+            while self.clock_idx < self.scheduler.total_clock_cycles:
+                schedules = self.get_schedule_from_partition(self.clock_idx, self.partition_idx)
+                yield from schedules
+
+                # NOTE: wait for the next clock cycle
+                self._wait_new_clock_cycle.wait()
+
     def get_schedule_from_partition(self, clock_idx: int, partition_idx: int):
         """Get the schedule of a partition at a certain clock cycle."""
         assert clock_idx >= 0, "Clock cycle index must be greater than or equal to 0."
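
Note: the two changes above form a producer/consumer pair. increase_a_clock_cycle bumps the clock under the condition variable and wakes every waiter; get_schedule yields the current cycle's tasks for this partition, then blocks until the next notification. A minimal, self-contained sketch of the same pattern (a hypothetical Clock class, not pipegoose API; it adds a wait_for handshake so the demo is deterministic, whereas the real code relies on its callers' timing):

import threading


class Clock:
    """Hypothetical stand-in for PipelineContext's clock synchronization."""

    def __init__(self, total_cycles: int):
        self.total_cycles = total_cycles
        self.clock_idx = 0
        self.processed = -1  # last cycle the consumer finished
        self.cond = threading.Condition()

    def tick(self):
        # Like increase_a_clock_cycle: advance the clock, wake all waiters.
        with self.cond:
            # Handshake (not in the original): wait until the current cycle
            # was consumed, so the demo never skips a cycle.
            self.cond.wait_for(lambda: self.processed == self.clock_idx)
            self.clock_idx += 1
            self.cond.notify_all()

    def schedule(self):
        # Like get_schedule: yield this cycle's work, then sleep until the
        # clock thread announces a new cycle.
        with self.cond:
            while self.clock_idx < self.total_cycles:
                yield f"tasks for clock cycle {self.clock_idx}"
                self.processed = self.clock_idx
                self.cond.notify_all()
                self.cond.wait_for(lambda: self.clock_idx > self.processed)


clock = Clock(total_cycles=3)

def run_clock():
    for _ in range(clock.total_cycles):
        clock.tick()

threading.Thread(target=run_clock).start()
for work in clock.schedule():
    print(work)  # one line per clock cycle, in order

Run as-is, this prints the three cycles in order and then both threads exit.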
4 changes: 4 additions & 0 deletions pipegoose/nn/pipeline_parallel2/scheduler.py
@@ -103,6 +103,10 @@ def generate_backward_schedule(forward_schedule):
     def start(self):
         self._schedules = self.get_schedules()
 
+    @property
+    def total_clock_cycles(self) -> int:
+        return len(self.get_schedules())
+
 
 class JobTracker:
     def __init__(self, n_microbatches: int, n_partitions: int):
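
Note: total_clock_cycles is the bound that get_schedule's while-loop above terminates on. For the GPipe scheduler the expected value follows from the tests below: the forward pass takes n_microbatches + n_partitions - 1 clock cycles and the backward pass takes the same again. Quick arithmetic for the sizes the tests use:

N_MICROBATCHES, N_PARTITIONS = 5, 4

forward_cycles = N_MICROBATCHES + N_PARTITIONS - 1  # 5 + 4 - 1 = 8
total_clock_cycles = 2 * forward_cycles             # 16

# scheduler.total_clock_cycles is expected to return 16 for these sizes.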
47 changes: 47 additions & 0 deletions tests/nn/pipeline_parallel_2/test_pipeline_context.py
@@ -1,7 +1,10 @@
+import threading
+
 import pytest
 
 from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
 from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
+from pipegoose.nn.pipeline_parallel2.task import Task
 from pipegoose.testing.utils import init_parallel_context, spawn


@@ -50,3 +53,47 @@ def test_run_pipeline_context(pipeline_parallel_size):
         pipeline_parallel_size=pipeline_parallel_size,
         data_parallel_size=DATA_PARALLEL_SIZE,
     )
+
+
+def run_get_synchronous_schedule(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
+    N_PARTITIONS = 4
+    N_MICROBATCHES = 5
+
+    parallel_context = init_parallel_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+    )
+    scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
+    TOTAL_SCHEDULES = scheduler.total_clock_cycles
+
+    def increase_clock_every_second(pipeline_context):
+        for _ in range(TOTAL_SCHEDULES):
+            # time.sleep(1)
+            pipeline_context.increase_a_clock_cycle()
+
+    pipeline_context = PipelineContext(scheduler, parallel_context)
+    clock_thread = threading.Thread(target=increase_clock_every_second, args=(pipeline_context,))
+    clock_thread.start()
+
+    prev_clock_idx = -1
+    for task in pipeline_context.get_schedule():
+        assert isinstance(task, Task)
+        assert pipeline_context.clock_idx == prev_clock_idx + 1
+        prev_clock_idx = pipeline_context.clock_idx
+
+    assert pipeline_context.clock_idx == TOTAL_SCHEDULES
+
+
+@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
+def test_get_synchronous_schedule(pipeline_parallel_size):
+    TENSOR_PARALLEL_SIZE = 1
+    DATA_PARALLEL_SIZE = 1
+
+    world_size = pipeline_parallel_size * TENSOR_PARALLEL_SIZE * DATA_PARALLEL_SIZE
+
+    spawn(
+        run_get_synchronous_schedule,
+        world_size=world_size,
+        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
+        pipeline_parallel_size=pipeline_parallel_size,
+        data_parallel_size=DATA_PARALLEL_SIZE,
+    )
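
Note: the new test drives the clock from a plain Python thread inside each spawned process. spawn itself is unchanged by this commit; judging from the call sites, it launches world_size worker processes and invokes the target as fn(rank, world_size, port, **kwargs). A rough, hypothetical sketch of such a harness (the real helper lives in pipegoose.testing.utils and may differ):

import multiprocessing


def spawn_sketch(fn, world_size, **kwargs):
    # Hypothetical stand-in for pipegoose.testing.utils.spawn: run fn on
    # every rank in its own process, all sharing one rendezvous port.
    port = 29500  # assumption: any free port shared by all ranks
    processes = [
        multiprocessing.Process(target=fn, args=(rank, world_size, port), kwargs=kwargs)
        for rank in range(world_size)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
        assert p.exitcode == 0, "a rank failed"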
47 changes: 3 additions & 44 deletions tests/nn/pipeline_parallel_2/test_scheduler2.py
@@ -1,8 +1,5 @@
-import pytest
-
 from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
 from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
-from pipegoose.testing.utils import spawn
 
 
 def test_generate_schedules_using_gpipe_scheduler():
@@ -13,11 +10,11 @@ def test_generate_schedules_using_gpipe_scheduler():
     TOTAL_CLOCK_CYCLES_IN_FORWARD = N_MICROBATCHES + N_PARTITIONS - 1
     TOTAL_CLOCK_CYCLES = TOTAL_CLOCK_CYCLES_IN_FORWARD * 2
 
-    scheduler = get_scheduler(SchedulerType.GPIPE)
-    schedules = scheduler(N_MICROBATCHES, N_PARTITIONS).get_schedules()
+    scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
 
-    # schedules = GPipeScheduler(N_MICROBATCHES, N_PARTITIONS).get_schedules()
+    assert scheduler.total_clock_cycles == TOTAL_CLOCK_CYCLES
 
+    schedules = scheduler.get_schedules()
     assert len(schedules) == TOTAL_CLOCK_CYCLES
 
     for tasks in schedules:
@@ -27,41 +24,3 @@
            assert task.job_type in JOB_TYPES
            assert isinstance(task.partition_idx, int)
            assert isinstance(task.microbatch_idx, int)
-
-
-def run_syncronous_gpipe_scheduler(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
-    # parallel_context = init_parallel_context(
-    #     rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
-    # )
-
-    # local_clock_idx = 0
-
-    # scheduler = GPipeScheduler(parallel_context)
-    # scheduler.start()
-
-    # assert scheduler.clock_idx == local_clock_idx
-
-    # for _ in range(5):
-    #     # NOTE: simulate that different nodes have different processing times
-    #     sleep(random.uniform(1, 5))
-
-    #     scheduler.confirm()
-
-    #     assert scheduler.clock_idx == local_clock_idx + 1
-    #     local_clock_idx += 1
-    pass
-
-
-@pytest.mark.skip
-@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
-def test_syncronous_scheduler(pipeline_parallel_size):
-    TENSOR_PARALLEL_SIZE = 1
-    DATA_PARALLEL_SIZE = 1
-
-    spawn(
-        run_syncronous_gpipe_scheduler,
-        world_size=pipeline_parallel_size,
-        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
-        pipeline_parallel_size=pipeline_parallel_size,
-        data_parallel_size=DATA_PARALLEL_SIZE,
-    )
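
Note: the shape the surviving test asserts, (n_microbatches + n_partitions - 1) * 2 clock cycles of Tasks carrying job_type, microbatch_idx, and partition_idx, is the classic GPipe staggering: microbatch m reaches partition p at clock cycle m + p. A hedged reimplementation of just the forward half, for intuition only (the real GPipeScheduler lives in pipegoose/nn/pipeline_parallel2/scheduler.py; the Task here is simplified, e.g. job_type as a plain string rather than a JobType enum):

from dataclasses import dataclass


@dataclass
class Task:
    job_type: str
    microbatch_idx: int
    partition_idx: int


def forward_schedule(n_microbatches: int, n_partitions: int):
    # Microbatch m runs on partition p at clock cycle m + p, so the forward
    # pass spans n_microbatches + n_partitions - 1 clock cycles in total.
    n_clock_cycles = n_microbatches + n_partitions - 1
    return [
        [
            Task("FORWARD", clock_idx - p, p)
            for p in range(n_partitions)
            if 0 <= clock_idx - p < n_microbatches
        ]
        for clock_idx in range(n_clock_cycles)
    ]


assert len(forward_schedule(5, 4)) == 8  # matches TOTAL_CLOCK_CYCLES_IN_FORWARD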
