Skip to content

Commit

Permalink
WIP pipeline engine
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Sep 21, 2023
1 parent ed0408e commit 69606db
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 36 deletions.
17 changes: 17 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true,
"subProcess": true
}
]
}
4 changes: 2 additions & 2 deletions pipegoose/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# NOTE: the minimum number of cocurrent worker threads that execute jobs
# in the background of pipeline parallelism
PIPELINE_MIN_WORKERS = 16
PIPELINE_MAX_WORKERS = 32
PIPELINE_MIN_WORKERS = 3
PIPELINE_MAX_WORKERS = 4

JOB_KEY_LENGTH = 15
4 changes: 2 additions & 2 deletions pipegoose/nn/pipeline_parallel2/_job/creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
CreateForwardOutputPackageCallback,
ForwardJob,
SaveActivationIfTrainingCallback,
SendForwardPackageCallback,
)
from pipegoose.nn.pipeline_parallel2._job.job import Job
from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
Expand All @@ -37,7 +36,8 @@ def create(self) -> Job:
class _ForwardJobCreator(JobCreator):
"""Put a forward job into job queue for a worker to execute."""

CBS = [CreateForwardOutputPackageCallback, SaveActivationIfTrainingCallback, SendForwardPackageCallback]
CBS = [CreateForwardOutputPackageCallback, SaveActivationIfTrainingCallback]
# SendForwardPackageCallback

@classmethod
def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> ForwardJob:
Expand Down
60 changes: 44 additions & 16 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from dataclasses import dataclass
from typing import List

import torch
from torch import nn

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

# from pipegoose.nn.pipeline_parallel.partrition import BasePartitioner
from pipegoose.nn.pipeline_parallel2._job.creator import create_job
from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
from pipegoose.nn.pipeline_parallel2._package import Metadata, Package, TrainingMetadata
from pipegoose.nn.pipeline_parallel2._worker import BaseWorkerManager
from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
from pipegoose.nn.pipeline_parallel2.queue import JobQueue
from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler
from pipegoose.nn.pipeline_parallel.partrition import BasePartitioner


@dataclass
Expand All @@ -25,7 +29,7 @@ class PipelineEngine:
def __init__(
self,
module: nn.Module,
partitioner: BasePartitioner,
# partitioner: BasePartitioner,
scheduler: BaseScheduler,
worker_manager: BaseWorkerManager,
parallel_context: ParallelContext,
Expand All @@ -36,26 +40,50 @@ def __init__(
), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}"

self.module = module
self.partitioner = partitioner
# self.partitioner = partitioner
self.scheduler = scheduler
self.worker_manager = worker_manager
self.parallel_context = parallel_context

self.pipeline_context = PipelineContext(self.scheduler, self.parallel_context)

def parallelize(self):
pass
def run(self, inputs: torch.Tensor) -> torch.Tensor:
self.worker_manager.spawn()
n_microbatches = self.scheduler.n_microbatches

# microbatches = microbatch.split(inputs, n_microbatches=self.scheduler.n_microbatches)
microbatches = torch.chunk(inputs, chunks=n_microbatches, dim=0)

if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
for task in self.pipeline_context.schedule:
if task.partition_idx == 0:
microbatch_idx = task.microbatch_idx

batch = microbatches[microbatch_idx]
forward_job = self._construct_first_job(microbatch_idx=microbatch_idx, input=batch)

def forward(self, *args, **kwargs) -> torch.Tensor:
with self.worker_manager.spawn(self.pipeline_context):
for schedule in self.scheduler.generate():
self._compute(schedule)
JobQueue.PENDING_JOBS.put(forward_job)

def _compute(self, batches: torch.Tensor, schedule: List[Schedule]):
def get_current_node_schedule():
pass
def _construct_first_job(self, microbatch_idx: int, input: torch.Tensor):
PARTITION_IDX = 0
IS_TRAINING = torch.is_grad_enabled()

schedule = get_current_node_schedule(schedule)
metadata = Metadata(
microbatch_idx=microbatch_idx,
partition_idx=PARTITION_IDX,
job_type=JobType.FORWARD,
training=TrainingMetadata(
is_training=IS_TRAINING,
is_grad_enabled=IS_TRAINING,
),
src=self.parallel_context.get_global_rank(),
dst=self.parallel_context.get_global_rank(),
)
package = Package(
data=input,
metadata=metadata,
)

def _construct_first_job(self):
pass
function = nn.Linear(5, 5)
job = create_job(function, package, self.pipeline_context)
return job
4 changes: 2 additions & 2 deletions pipegoose/nn/pipeline_parallel2/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def is_running(self):
pass


class _GPipeScheduler(BaseScheduler):
class GPipeScheduler(BaseScheduler):
"""
torchgpipe: On-the-fly Pipeline Parallelism for Training Giant Models
https://arxiv.org/abs/2004.09910
Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(self, n_microbatches: int, n_partitions: int):

def get_scheduler(scheduler_type: SchedulerType) -> BaseScheduler:
scheduler_type_to_scheduler = {
SchedulerType.GPIPE: _GPipeScheduler,
SchedulerType.GPIPE: GPipeScheduler,
}

return scheduler_type_to_scheduler[scheduler_type]
1 change: 1 addition & 0 deletions tests/nn/pipeline_parallel_2/test_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def run_model_partitioner(rank, world_size, port, tensor_parallel_size, pipeline
)
module = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

text = ["Hello world", "How are you?"]
inputs = tokenizer(text, return_tensors="pt", padding=True)
Expand Down
62 changes: 48 additions & 14 deletions tests/nn/pipeline_parallel_2/test_pipeline_engine.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,54 @@
import pytest
import torch
from torch import nn

from pipegoose.nn.pipeline_parallel2._worker import WorkerManager
from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine
from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler
from pipegoose.testing.utils import init_parallel_context, spawn

class FakeParallelContext:
pass
model = nn.Sequential(
nn.Linear(5, 5),
nn.ReLU(),
nn.Linear(5, 5),
)


@pytest.mark.skip
def test_pipeline_engine(model):
# BATCH_SIZE = 32
# parallel_context = FakeParallelContext()
# torch.randn(BATCH_SIZE, 4)
def run_pipeline_engine(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
BATCH_SIZE = 32
SEQ_LEN = 10
HIDDEN_DIM = 5

# pipeline_engine = PipelineEngine(
# module=model,
# scheduler=GPipeScheduler(),
# worker_manager=WorkerManager(),
# parallel_context=parallel_context,
# )
pass
N_MICROBATCHES = 6

inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)
scheduler = GPipeScheduler(N_MICROBATCHES, pipeline_parallel_size)
worker_manager = WorkerManager()
parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)

pipeline_engine = PipelineEngine(
module=model,
scheduler=scheduler,
worker_manager=worker_manager,
parallel_context=parallel_context,
)

pipeline_engine.run(inputs)

# assert torch.allclose(outputs, model(inputs))


@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
def test_pipeline_engine(pipeline_parallel_size):
DATA_PARALLEL_SIZE = 1
TENSOR_PARALLEL_SIZE = 1

spawn(
run_pipeline_engine,
world_size=pipeline_parallel_size,
tensor_parallel_size=TENSOR_PARALLEL_SIZE,
pipeline_parallel_size=pipeline_parallel_size,
data_parallel_size=DATA_PARALLEL_SIZE,
)
1 change: 1 addition & 0 deletions tests/nn/pipeline_parallel_2/test_worker2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def compute(self):

PENDING_JOBS.put(job)
assert PENDING_JOBS.qsize() == 1
assert SELECTED_JOBS.qsize() == 0

# NOTE: wait for job selector picks up the job
sleep(2)
Expand Down

0 comments on commit 69606db

Please sign in to comment.