add the forward pass of pipeline engine
xrsrke committed Sep 30, 2023
1 parent 821a5fb commit 8c83d49
Showing 4 changed files with 41 additions and 68 deletions.
4 changes: 2 additions & 2 deletions pipegoose/nn/pipeline_parallel2/_job/forward.py
@@ -116,6 +116,6 @@ def after_compute(self):
         # )
         key = (microbatch_idx, partition_idx)
         progress_tracker.confirm(key)
-        import time
+        # import time
 
-        time.sleep(3)
+        # time.sleep(3)
6 changes: 6 additions & 0 deletions pipegoose/nn/pipeline_parallel2/_utils.py
@@ -14,3 +14,9 @@ def get_partition_idx(parallel_context: ParallelContext) -> int:
     # pipeline_stage_idx = rank // n_ranks_per_group
     # return pipeline_stage_idx
     return ranks_in_group.index(rank)
+
+
+def is_last_stage(parallel_context: ParallelContext) -> bool:
+    partition_idx = get_partition_idx(parallel_context)
+    n_stages = parallel_context.pipeline_parallel_size
+    return partition_idx == (n_stages - 1)
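
Note: the new `is_last_stage` helper depends only on the rank's position inside its pipeline group. A minimal runnable sketch of the same logic, using a hypothetical `FakeContext` stub in place of the real `ParallelContext`:

```python
# Sketch of the new helpers, assuming only that a context exposes its rank,
# the ranks in its pipeline group, and pipeline_parallel_size.
# FakeContext is hypothetical; the real class is ParallelContext.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeContext:
    rank: int
    ranks_in_group: List[int]
    pipeline_parallel_size: int

def get_partition_idx(ctx: FakeContext) -> int:
    # partition index = position of this rank within its pipeline group
    return ctx.ranks_in_group.index(ctx.rank)

def is_last_stage(ctx: FakeContext) -> bool:
    return get_partition_idx(ctx) == ctx.pipeline_parallel_size - 1

ctx = FakeContext(rank=3, ranks_in_group=[0, 1, 2, 3], pipeline_parallel_size=4)
assert get_partition_idx(ctx) == 3 and is_last_stage(ctx)
```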
44 changes: 16 additions & 28 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -1,4 +1,3 @@
-import time
 from dataclasses import dataclass
 
 import torch
@@ -43,9 +42,9 @@ def __init__(
         partition_func,
     ):
         assert isinstance(module, nn.Module), f"module must be an instance of nn.Module, got {type(module)}"
-        # assert isinstance(
-        #     parallel_context, ParallelContext
-        # ), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}"
+        assert isinstance(
+            parallel_context, ParallelContext
+        ), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}"
 
         self.module = module
         # self.partitioner = partitioner
Expand All @@ -60,9 +59,6 @@ def __init__(
def run(self, inputs: torch.Tensor) -> torch.Tensor:
MASTER_RANK = 0

# from hanging_threads import start_monitoring
# monitoring_thread = start_monitoring()

self.worker_manager.spawn()
n_microbatches = self.scheduler.n_microbatches

Expand All @@ -80,15 +76,13 @@ def after_new_clock_cycle(self, progress, clock_idx):
parallel_context = self.pipeline_context.parallel_context
print(f"increase clock, clock_idx={clock_idx}, rank={parallel_context.get_local_rank(ParallelMode.GLOBAL)}")
self.pipeline_context.increase_a_clock_cycle()
time.sleep(1)

callbacks = [IncreasePipelineContextClockCycleCallback(self.pipeline_context)]
progress_tracker = ProgressTracker(
MASTER_RANK, callbacks=callbacks, parallel_context=self.parallel_context, parallel_mode=ParallelMode.GLOBAL
)
# NOTE: wait for all ranks to be initiated
dist.barrier()
time.sleep(1)

# if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
if self.parallel_context.get_global_rank() == 0:
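
Note: `after_new_clock_cycle` is what drives each rank's pipeline clock forward once the tracker confirms a cycle. A minimal sketch of that callback pattern, with a stub standing in for the real `PipelineContext` (class names other than the callback hook are hypothetical):

```python
# Sketch of the clock-advancing callback; StubPipelineContext is a stand-in.
class StubPipelineContext:
    def __init__(self):
        self.clock_idx = 0

    def increase_a_clock_cycle(self):
        self.clock_idx += 1

class IncreaseClockCallback:
    def __init__(self, pipeline_context):
        self.pipeline_context = pipeline_context

    def after_new_clock_cycle(self, progress, clock_idx):
        # called by the progress tracker after every rank confirmed the cycle
        self.pipeline_context.increase_a_clock_cycle()

ctx = StubPipelineContext()
IncreaseClockCallback(ctx).after_new_clock_cycle(progress={}, clock_idx=0)
assert ctx.clock_idx == 1
```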
Expand All @@ -101,20 +95,19 @@ def after_new_clock_cycle(self, progress, clock_idx):
print(progress)

dist.barrier()
time.sleep(5)

set_progress_tracker(progress_tracker)

dist.barrier()
time.sleep(2)

# from hanging_threads import start_monitoring
# monitoring_thread = start_monitoring()

for tasks in self.pipeline_context.get_schedule():

dist.barrier()

if self.pipeline_context.clock_idx == 9:
# TODO: remove this
# this is for breaking the loop once getting backward tasks
break

rank = self.parallel_context.get_global_rank()
partition_idx = self.pipeline_context.partition_idx

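
Note: the loop above walks `self.pipeline_context.get_schedule()` one clock cycle at a time. A minimal sketch of a GPipe-style forward schedule, under the assumption that each task is a `(microbatch_idx, partition_idx)` pair (the real scheduler is `GPipeScheduler` in `pipegoose.nn.pipeline_parallel2.scheduler`):

```python
# Forward-only GPipe schedule sketch: microbatch m runs on partition p during
# clock cycle m + p, so there are n_microbatches + n_partitions - 1 cycles.
def forward_schedule(n_microbatches: int, n_partitions: int):
    n_clock_cycles = n_microbatches + n_partitions - 1
    for clock_idx in range(n_clock_cycles):
        yield [
            (mb, clock_idx - mb)
            for mb in range(n_microbatches)
            if 0 <= clock_idx - mb < n_partitions
        ]

# With 6 microbatches over 4 stages this yields clock cycles 0..8, which is
# consistent with the clock_idx == 9 cutoff above marking the end of forward.
for clock_idx, tasks in enumerate(forward_schedule(6, 4)):
    print(clock_idx, tasks)
```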
Expand All @@ -125,7 +118,6 @@ def after_new_clock_cycle(self, progress, clock_idx):
assert 1 == 1

if len(tasks) > 0:
# print(f"[enter look] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}")
for task in tasks:
microbatch_idx = task.microbatch_idx
partition_idx = task.partition_idx
Expand All @@ -136,23 +128,19 @@ def after_new_clock_cycle(self, progress, clock_idx):
else:
package = RECV_QUEUE.get()

# print(
# f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}",
# package.metadata,
# )

job = create_job(self.partition_func, package, self.pipeline_context)

# print(f"created a job: {package.metadata}")

JobQueue.PENDING_JOBS.put(job)

dist.barrier()

# def _retrieve_package_from_received_package(self, microbatch_idx, partition_idx):
# # package = RECV_QUEUE[(microbatch_idx, partition_idx)]
# package = RECV_QUEUE.get()
# return package
dist.barrier()

if self.pipeline_context.is_last_stage:
from pipegoose.nn.pipeline_parallel2.queue import _SAVED_ACTIVATIONS

outputs = [_SAVED_ACTIVATIONS[(microbatch_idx, partition_idx)] for microbatch_idx in range(n_microbatches)]
outputs = torch.cat(outputs, dim=0)
return outputs

def _construct_first_package(self, microbatch_idx: int, input: torch.Tensor):
"""Construct the first forward package of a microbatch."""
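
Note: only the last stage returns a tensor from `run`; it stitches the full batch back together from per-microbatch activations. A minimal sketch, assuming microbatches were split along dim 0 and saved under `(microbatch_idx, partition_idx)` keys (shapes here are illustrative only):

```python
# Sketch of the last-stage output gathering.
import torch

n_microbatches, partition_idx = 6, 3
_SAVED_ACTIVATIONS = {
    (mb, partition_idx): torch.randn(8, 10, 5)  # one microbatch of batch size 8
    for mb in range(n_microbatches)
}

outputs = [_SAVED_ACTIVATIONS[(mb, partition_idx)] for mb in range(n_microbatches)]
outputs = torch.cat(outputs, dim=0)
assert outputs.shape == (48, 10, 5)  # 6 microbatches x 8 = full batch of 48
```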
55 changes: 17 additions & 38 deletions tests/nn/pipeline_parallel_2/test_pipeline_engine.py
@@ -1,15 +1,13 @@
-import time
 
 import pytest
 import torch
 from torch import nn
 
-from pipegoose.distributed.parallel_context import ParallelContext
-from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx, sleep
+from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx, is_last_stage
 from pipegoose.nn.pipeline_parallel2._worker import WorkerManager
 from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine
 from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler
-from pipegoose.testing.utils import spawn
+from pipegoose.testing.utils import init_parallel_context, spawn
 
 model = nn.Sequential(
     nn.Linear(5, 5),
Expand All @@ -18,32 +16,17 @@
)


def run_pipeline_engine(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, package):
def run_pipeline_engine(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
BATCH_SIZE = 32
SEQ_LEN = 10
HIDDEN_DIM = 5

N_MICROBATCHES = 6

inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)
scheduler = GPipeScheduler(N_MICROBATCHES, pipeline_parallel_size)
# parallel_context = init_parallel_context(
# rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
# )
parallel_context = ParallelContext(
rank=rank,
local_rank=rank,
world_size=world_size,
local_world_size=world_size,
host="localhost",
port=port,
seed=69,
backend="gloo",
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
data_parallel_size=data_parallel_size,
parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)

forward_timeline = []

class Function(nn.Module):
Expand All @@ -70,32 +53,28 @@ def forward(self, input):
parallel_context=parallel_context,
partition_func=partition_func,
)
EXPECTED_FORWARD_TIMELINE = [(microbatch_idx, partition_idx) for microbatch_idx in range(N_MICROBATCHES)]

pipeline_engine.run(inputs)
outputs = pipeline_engine.run(inputs)

sleep(3)
if is_last_stage(parallel_context):
assert forward_timeline == EXPECTED_FORWARD_TIMELINE
else:
# NOTE: earlier stages should not return the final output
assert outputs is None

assert forward_timeline == [
(0, partition_idx),
(1, partition_idx),
(2, partition_idx),
(3, partition_idx),
(4, partition_idx),
]


@pytest.mark.parametrize("pipeline_parallel_size", [1, 2, 4])
def test_pipeline_engine(pipeline_parallel_size, forward_package):
DATA_PARALLEL_SIZE = 1
def test_pipeline_engine():
TENSOR_PARALLEL_SIZE = 1
PIPELINE_PARALLEL_SIZE = 4
DATA_PARALLEL_SIZE = 1

WORLD_SIZE = pipeline_parallel_size * DATA_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE
WORLD_SIZE = PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE

spawn(
run_pipeline_engine,
world_size=WORLD_SIZE,
tensor_parallel_size=TENSOR_PARALLEL_SIZE,
pipeline_parallel_size=pipeline_parallel_size,
pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
data_parallel_size=DATA_PARALLEL_SIZE,
package=forward_package,
)

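
Note: under the GPipe forward schedule every stage processes microbatches strictly in order, so the timeline recorded on stage `p` is just `[(0, p), ..., (N_MICROBATCHES - 1, p)]`. A small sketch of the expectation the test asserts on the last stage:

```python
# For N_MICROBATCHES = 6 and PIPELINE_PARALLEL_SIZE = 4, the last stage
# (partition_idx 3) should record this forward timeline:
N_MICROBATCHES = 6
partition_idx = 3
EXPECTED_FORWARD_TIMELINE = [(mb, partition_idx) for mb in range(N_MICROBATCHES)]
print(EXPECTED_FORWARD_TIMELINE)
# [(0, 3), (1, 3), (2, 3), (3, 3), (4, 3), (5, 3)]
```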