From 69606dbf97d799e5f43d5f886f2c921a93920957 Mon Sep 17 00:00:00 2001
From: xrsrke
Date: Thu, 21 Sep 2023 10:32:11 +0700
Subject: [PATCH] WIP pipeline engine

---
 .vscode/launch.json                           | 17 +++++
 pipegoose/constants.py                        |  4 +-
 .../nn/pipeline_parallel2/_job/creator.py     |  4 +-
 .../nn/pipeline_parallel2/pipeline_engine.py  | 60 +++++++++++++-----
 pipegoose/nn/pipeline_parallel2/scheduler.py  |  4 +-
 .../pipeline_parallel_2/test_partitioner.py   |  1 +
 .../test_pipeline_engine.py                   | 62 ++++++++++++++-----
 tests/nn/pipeline_parallel_2/test_worker2.py  |  1 +
 8 files changed, 117 insertions(+), 36 deletions(-)
 create mode 100644 .vscode/launch.json

diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..8b52002
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,17 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "subProcess": true
+        }
+    ]
+}
\ No newline at end of file
diff --git a/pipegoose/constants.py b/pipegoose/constants.py
index 9c9ca8b..712d5cd 100644
--- a/pipegoose/constants.py
+++ b/pipegoose/constants.py
@@ -11,7 +11,7 @@

 # NOTE: the minimum number of concurrent worker threads that execute jobs
 # in the background of pipeline parallelism
-PIPELINE_MIN_WORKERS = 16
-PIPELINE_MAX_WORKERS = 32
+PIPELINE_MIN_WORKERS = 3
+PIPELINE_MAX_WORKERS = 4

 JOB_KEY_LENGTH = 15
diff --git a/pipegoose/nn/pipeline_parallel2/_job/creator.py b/pipegoose/nn/pipeline_parallel2/_job/creator.py
index 2fac89b..7d7b2b4 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/creator.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/creator.py
@@ -17,7 +17,6 @@
     CreateForwardOutputPackageCallback,
     ForwardJob,
     SaveActivationIfTrainingCallback,
-    SendForwardPackageCallback,
 )
 from pipegoose.nn.pipeline_parallel2._job.job import Job
 from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
@@ -37,7 +36,8 @@ def create(self) -> Job:
 class _ForwardJobCreator(JobCreator):
     """Put a forward job into job queue for a worker to execute."""

-    CBS = [CreateForwardOutputPackageCallback, SaveActivationIfTrainingCallback, SendForwardPackageCallback]
+    CBS = [CreateForwardOutputPackageCallback, SaveActivationIfTrainingCallback]
+    # SendForwardPackageCallback

     @classmethod
     def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> ForwardJob:
diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
index e820d82..cf3fc5e 100644
--- a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
+++ b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -1,15 +1,19 @@
 from dataclasses import dataclass
-from typing import List

 import torch
 from torch import nn

 from pipegoose.distributed.parallel_context import ParallelContext
+from pipegoose.distributed.parallel_mode import ParallelMode
+
+# from pipegoose.nn.pipeline_parallel.partrition import BasePartitioner
+from pipegoose.nn.pipeline_parallel2._job.creator import create_job
 from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
+from pipegoose.nn.pipeline_parallel2._package import Metadata, Package, TrainingMetadata
 from pipegoose.nn.pipeline_parallel2._worker import BaseWorkerManager
 from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
+from pipegoose.nn.pipeline_parallel2.queue import JobQueue
 from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler
-from pipegoose.nn.pipeline_parallel.partrition import BasePartitioner


 @dataclass
@@ -25,7 +29,7 @@ class PipelineEngine:
     def __init__(
         self,
         module: nn.Module,
-        partitioner: BasePartitioner,
+        # partitioner: BasePartitioner,
         scheduler: BaseScheduler,
         worker_manager: BaseWorkerManager,
         parallel_context: ParallelContext,
@@ -36,26 +40,50 @@ def __init__(
         ), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}"

         self.module = module
-        self.partitioner = partitioner
+        # self.partitioner = partitioner
         self.scheduler = scheduler
         self.worker_manager = worker_manager
         self.parallel_context = parallel_context
         self.pipeline_context = PipelineContext(self.scheduler, self.parallel_context)

-    def parallelize(self):
-        pass
+    def run(self, inputs: torch.Tensor) -> torch.Tensor:
+        self.worker_manager.spawn()
+        n_microbatches = self.scheduler.n_microbatches
+
+        # microbatches = microbatch.split(inputs, n_microbatches=self.scheduler.n_microbatches)
+        microbatches = torch.chunk(inputs, chunks=n_microbatches, dim=0)
+
+        if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
+            for task in self.pipeline_context.schedule:
+                if task.partition_idx == 0:
+                    microbatch_idx = task.microbatch_idx
+
+                    batch = microbatches[microbatch_idx]
+                    forward_job = self._construct_first_job(microbatch_idx=microbatch_idx, input=batch)

-    def forward(self, *args, **kwargs) -> torch.Tensor:
-        with self.worker_manager.spawn(self.pipeline_context):
-            for schedule in self.scheduler.generate():
-                self._compute(schedule)
+                    JobQueue.PENDING_JOBS.put(forward_job)

-    def _compute(self, batches: torch.Tensor, schedule: List[Schedule]):
-        def get_current_node_schedule():
-            pass
+    def _construct_first_job(self, microbatch_idx: int, input: torch.Tensor):
+        PARTITION_IDX = 0
+        IS_TRAINING = torch.is_grad_enabled()

-        schedule = get_current_node_schedule(schedule)
+        metadata = Metadata(
+            microbatch_idx=microbatch_idx,
+            partition_idx=PARTITION_IDX,
+            job_type=JobType.FORWARD,
+            training=TrainingMetadata(
+                is_training=IS_TRAINING,
+                is_grad_enabled=IS_TRAINING,
+            ),
+            src=self.parallel_context.get_global_rank(),
+            dst=self.parallel_context.get_global_rank(),
+        )
+        package = Package(
+            data=input,
+            metadata=metadata,
+        )

-    def _construct_first_job(self):
-        pass
+        function = nn.Linear(5, 5)
+        job = create_job(function, package, self.pipeline_context)
+        return job
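Two things stand out in the new PipelineEngine.run() path. First, the batch is split with torch.chunk, which quietly yields unequal microbatches whenever the batch size is not divisible by n_microbatches (the commented-out microbatch.split call suggests a dedicated helper is planned). Second, _construct_first_job() still hard-codes nn.Linear(5, 5) as the job's function, so the enqueued forward job does not yet execute a real model partition. A minimal, runnable sketch of the splitting behaviour, using the same illustrative sizes as the test below:

    import torch

    BATCH_SIZE, SEQ_LEN, HIDDEN_DIM = 32, 10, 5
    N_MICROBATCHES = 6

    inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)

    # torch.chunk splits along dim 0 into chunks of ceil(32 / 6) = 6 samples,
    # so the last microbatch is left with only 2 samples.
    microbatches = torch.chunk(inputs, chunks=N_MICROBATCHES, dim=0)

    assert len(microbatches) == N_MICROBATCHES
    assert [mb.size(0) for mb in microbatches] == [6, 6, 6, 6, 6, 2]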
diff --git a/pipegoose/nn/pipeline_parallel2/scheduler.py b/pipegoose/nn/pipeline_parallel2/scheduler.py
index e705133..37fe009 100644
--- a/pipegoose/nn/pipeline_parallel2/scheduler.py
+++ b/pipegoose/nn/pipeline_parallel2/scheduler.py
@@ -37,7 +37,7 @@ def is_running(self):
         pass


-class _GPipeScheduler(BaseScheduler):
+class GPipeScheduler(BaseScheduler):
     """
     torchgpipe: On-the-fly Pipeline Parallelism for Training Giant Models
     https://arxiv.org/abs/2004.09910
@@ -124,7 +124,7 @@ def __init__(self, n_microbatches: int, n_partitions: int):

 def get_scheduler(scheduler_type: SchedulerType) -> BaseScheduler:
     scheduler_type_to_scheduler = {
-        SchedulerType.GPIPE: _GPipeScheduler,
+        SchedulerType.GPIPE: GPipeScheduler,
     }

     return scheduler_type_to_scheduler[scheduler_type]
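For intuition about what the newly public GPipeScheduler produces: in GPipe, partitions form a staircase in which, at clock cycle t, partition p works on microbatch t - p. Below is a rough sketch of that forward schedule, assuming tasks expose the microbatch_idx and partition_idx fields that PipelineEngine.run() reads; the scheduler's actual internals may differ from this:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Task:
        microbatch_idx: int
        partition_idx: int

    def gpipe_forward_schedule(n_microbatches: int, n_partitions: int) -> List[List[Task]]:
        # Classic GPipe staircase: microbatch m reaches partition p at cycle m + p.
        schedule = []
        for t in range(n_microbatches + n_partitions - 1):
            tasks = [
                Task(microbatch_idx=t - p, partition_idx=p)
                for p in range(n_partitions)
                if 0 <= t - p < n_microbatches
            ]
            schedule.append(tasks)
        return schedule

    # 4 microbatches over 2 partitions take 4 + 2 - 1 = 5 clock cycles.
    for cycle in gpipe_forward_schedule(4, 2):
        print([(task.microbatch_idx, task.partition_idx) for task in cycle])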
a/tests/nn/pipeline_parallel_2/test_partitioner.py
+++ b/tests/nn/pipeline_parallel_2/test_partitioner.py
@@ -17,6 +17,7 @@ def run_model_partitioner(rank, world_size, port, tensor_parallel_size, pipeline
     )
     module = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    tokenizer.pad_token = tokenizer.eos_token

     text = ["Hello world", "How are you?"]
     inputs = tokenizer(text, return_tensors="pt", padding=True)
diff --git a/tests/nn/pipeline_parallel_2/test_pipeline_engine.py b/tests/nn/pipeline_parallel_2/test_pipeline_engine.py
index 30864a8..5dc61a7 100644
--- a/tests/nn/pipeline_parallel_2/test_pipeline_engine.py
+++ b/tests/nn/pipeline_parallel_2/test_pipeline_engine.py
@@ -1,20 +1,54 @@
 import pytest
+import torch
+from torch import nn

+from pipegoose.nn.pipeline_parallel2._worker import WorkerManager
+from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine
+from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler
+from pipegoose.testing.utils import init_parallel_context, spawn

-class FakeParallelContext:
-    pass
+model = nn.Sequential(
+    nn.Linear(5, 5),
+    nn.ReLU(),
+    nn.Linear(5, 5),
+)

-@pytest.mark.skip
-def test_pipeline_engine(model):
-    # BATCH_SIZE = 32
-    # parallel_context = FakeParallelContext()
-    # torch.randn(BATCH_SIZE, 4)
+def run_pipeline_engine(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
+    BATCH_SIZE = 32
+    SEQ_LEN = 10
+    HIDDEN_DIM = 5

-    # pipeline_engine = PipelineEngine(
-    #     module=model,
-    #     scheduler=GPipeScheduler(),
-    #     worker_manager=WorkerManager(),
-    #     parallel_context=parallel_context,
-    # )
-    pass
+    N_MICROBATCHES = 6
+
+    inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)
+    scheduler = GPipeScheduler(N_MICROBATCHES, pipeline_parallel_size)
+    worker_manager = WorkerManager()
+    parallel_context = init_parallel_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+    )
+
+    pipeline_engine = PipelineEngine(
+        module=model,
+        scheduler=scheduler,
+        worker_manager=worker_manager,
+        parallel_context=parallel_context,
+    )
+
+    pipeline_engine.run(inputs)
+
+    # assert torch.allclose(outputs, model(inputs))
+
+
+@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
+def test_pipeline_engine(pipeline_parallel_size):
+    DATA_PARALLEL_SIZE = 1
+    TENSOR_PARALLEL_SIZE = 1
+
+    spawn(
+        run_pipeline_engine,
+        world_size=pipeline_parallel_size,
+        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
+        pipeline_parallel_size=pipeline_parallel_size,
+        data_parallel_size=DATA_PARALLEL_SIZE,
+    )
diff --git a/tests/nn/pipeline_parallel_2/test_worker2.py b/tests/nn/pipeline_parallel_2/test_worker2.py
index 1969757..7b2ee4b 100644
--- a/tests/nn/pipeline_parallel_2/test_worker2.py
+++ b/tests/nn/pipeline_parallel_2/test_worker2.py
@@ -59,6 +59,7 @@ def compute(self):

     PENDING_JOBS.put(job)
     assert PENDING_JOBS.qsize() == 1
+    assert SELECTED_JOBS.qsize() == 0

     # NOTE: wait for the job selector to pick up the job
     sleep(2)
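The extra assertion in test_worker2.py pins down an invariant of the job queues: a freshly enqueued job sits in PENDING_JOBS and must not appear in SELECTED_JOBS until a background worker picks it up. A toy stand-in for that handoff using plain queue.Queue and a daemon thread (pipegoose's real JobQueue and worker pool are more involved; this only mirrors the test's shape):

    import threading
    from queue import Queue
    from time import sleep

    PENDING_JOBS: Queue = Queue()
    SELECTED_JOBS: Queue = Queue()

    def job_selector():
        # Background loop: move each pending job into the selected queue.
        while True:
            SELECTED_JOBS.put(PENDING_JOBS.get())

    PENDING_JOBS.put("job")
    assert PENDING_JOBS.qsize() == 1
    assert SELECTED_JOBS.qsize() == 0

    threading.Thread(target=job_selector, daemon=True).start()
    sleep(2)  # wait for the selector to pick up the job

    assert PENDING_JOBS.qsize() == 0
    assert SELECTED_JOBS.qsize() == 1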