diff --git a/pipegoose/nn/pipeline_parallel2/_job/forward.py b/pipegoose/nn/pipeline_parallel2/_job/forward.py
index cc195e6..946ba89 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/forward.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/forward.py
@@ -116,6 +116,6 @@ def after_compute(self):
         # )
         key = (microbatch_idx, partition_idx)
         progress_tracker.confirm(key)
-        import time
+        # import time
 
-        time.sleep(3)
+        # time.sleep(3)
diff --git a/pipegoose/nn/pipeline_parallel2/_utils.py b/pipegoose/nn/pipeline_parallel2/_utils.py
index b06bde3..40f6241 100644
--- a/pipegoose/nn/pipeline_parallel2/_utils.py
+++ b/pipegoose/nn/pipeline_parallel2/_utils.py
@@ -14,3 +14,9 @@ def get_partition_idx(parallel_context: ParallelContext) -> int:
     # pipeline_stage_idx = rank // n_ranks_per_group
     # return pipeline_stage_idx
     return ranks_in_group.index(rank)
+
+
+def is_last_stage(parallel_context: ParallelContext) -> bool:
+    partition_idx = get_partition_idx(parallel_context)
+    n_stages = parallel_context.pipeline_parallel_size
+    return partition_idx == (n_stages - 1)
diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
index 7c220a9..be52593 100644
--- a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
+++ b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -1,4 +1,3 @@
-import time
 from dataclasses import dataclass
 
 import torch
@@ -43,9 +42,9 @@ def __init__(
         partition_func,
     ):
         assert isinstance(module, nn.Module), f"module must be an instance of nn.Module, got {type(module)}"
-        # assert isinstance(
-        #     parallel_context, ParallelContext
-        # ), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}"
+        assert isinstance(
+            parallel_context, ParallelContext
+        ), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}"
 
         self.module = module
         # self.partitioner = partitioner
@@ -60,9 +59,6 @@ def __init__(
     def run(self, inputs: torch.Tensor) -> torch.Tensor:
         MASTER_RANK = 0
 
-        # from hanging_threads import start_monitoring
-        # monitoring_thread = start_monitoring()
-
         self.worker_manager.spawn()
         n_microbatches = self.scheduler.n_microbatches
@@ -80,7 +76,6 @@ def after_new_clock_cycle(self, progress, clock_idx):
                 parallel_context = self.pipeline_context.parallel_context
                 print(f"increase clock, clock_idx={clock_idx}, rank={parallel_context.get_local_rank(ParallelMode.GLOBAL)}")
                 self.pipeline_context.increase_a_clock_cycle()
-                time.sleep(1)
 
         callbacks = [IncreasePipelineContextClockCycleCallback(self.pipeline_context)]
         progress_tracker = ProgressTracker(
@@ -88,7 +83,6 @@ def after_new_clock_cycle(self, progress, clock_idx):
         )
         # NOTE: wait for all ranks to be initiated
         dist.barrier()
-        time.sleep(1)
 
         # if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
         if self.parallel_context.get_global_rank() == 0:
@@ -101,20 +95,19 @@ def after_new_clock_cycle(self, progress, clock_idx):
            print(progress)
 
         dist.barrier()
-        time.sleep(5)
 
         set_progress_tracker(progress_tracker)
 
         dist.barrier()
-        time.sleep(2)
-
-        # from hanging_threads import start_monitoring
-        # monitoring_thread = start_monitoring()
 
         for tasks in self.pipeline_context.get_schedule():
-            dist.barrier()
+            if self.pipeline_context.clock_idx == 9:
+                # TODO: remove this
+                # this is for breaking the loop once getting backward tasks
+                break
+
             rank = self.parallel_context.get_global_rank()
             partition_idx = self.pipeline_context.partition_idx
@@ -125,7 +118,6 @@ def after_new_clock_cycle(self, progress, clock_idx):
             assert 1 == 1
 
             if len(tasks) > 0:
-                # print(f"[enter look] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}")
                 for task in tasks:
                     microbatch_idx = task.microbatch_idx
                     partition_idx = task.partition_idx
@@ -136,23 +128,19 @@ def after_new_clock_cycle(self, progress, clock_idx):
                     else:
                         package = RECV_QUEUE.get()
-                        # print(
-                        #     f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}",
-                        #     package.metadata,
-                        # )
-
                     job = create_job(self.partition_func, package, self.pipeline_context)
-
-                    # print(f"created a job: {package.metadata}")
-
                     JobQueue.PENDING_JOBS.put(job)
 
             dist.barrier()
 
-    # def _retrieve_package_from_received_package(self, microbatch_idx, partition_idx):
-    #     # package = RECV_QUEUE[(microbatch_idx, partition_idx)]
-    #     package = RECV_QUEUE.get()
-    #     return package
+        dist.barrier()
+
+        if self.pipeline_context.is_last_stage:
+            from pipegoose.nn.pipeline_parallel2.queue import _SAVED_ACTIVATIONS
+
+            outputs = [_SAVED_ACTIVATIONS[(microbatch_idx, partition_idx)] for microbatch_idx in range(n_microbatches)]
+            outputs = torch.cat(outputs, dim=0)
+            return outputs
 
     def _construct_first_package(self, microbatch_idx: int, input: torch.Tensor):
         """Construct the first forward package of a microbatch."""
diff --git a/tests/nn/pipeline_parallel_2/test_pipeline_engine.py b/tests/nn/pipeline_parallel_2/test_pipeline_engine.py
index 7e42f86..26db008 100644
--- a/tests/nn/pipeline_parallel_2/test_pipeline_engine.py
+++ b/tests/nn/pipeline_parallel_2/test_pipeline_engine.py
@@ -1,15 +1,13 @@
 import time
 
-import pytest
 import torch
 from torch import nn
 
-from pipegoose.distributed.parallel_context import ParallelContext
-from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx, sleep
+from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx, is_last_stage
 from pipegoose.nn.pipeline_parallel2._worker import WorkerManager
 from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine
 from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler
-from pipegoose.testing.utils import spawn
+from pipegoose.testing.utils import init_parallel_context, spawn
 
 model = nn.Sequential(
     nn.Linear(5, 5),
@@ -18,32 +16,17 @@
 )
 
 
-def run_pipeline_engine(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, package):
+def run_pipeline_engine(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
     BATCH_SIZE = 32
     SEQ_LEN = 10
     HIDDEN_DIM = 5
-
     N_MICROBATCHES = 6
 
     inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)
     scheduler = GPipeScheduler(N_MICROBATCHES, pipeline_parallel_size)
 
-    # parallel_context = init_parallel_context(
-    #     rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
-    # )
-    parallel_context = ParallelContext(
-        rank=rank,
-        local_rank=rank,
-        world_size=world_size,
-        local_world_size=world_size,
-        host="localhost",
-        port=port,
-        seed=69,
-        backend="gloo",
-        tensor_parallel_size=tensor_parallel_size,
-        pipeline_parallel_size=pipeline_parallel_size,
-        data_parallel_size=data_parallel_size,
+    parallel_context = init_parallel_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
-
     forward_timeline = []
 
     class Function(nn.Module):
@@ -70,32 +53,28 @@ def forward(self, input):
         parallel_context=parallel_context,
         partition_func=partition_func,
     )
+    EXPECTED_FORWARD_TIMELINE = [(microbatch_idx, partition_idx) for microbatch_idx in range(N_MICROBATCHES)]
 
-    pipeline_engine.run(inputs)
+    outputs = pipeline_engine.run(inputs)
 
-    sleep(3)
+    if is_last_stage(parallel_context):
+        assert forward_timeline == EXPECTED_FORWARD_TIMELINE
+    else:
+        # NOTE: earlier stages should not return the final output
+        assert outputs is None
 
-    assert forward_timeline == [
-        (0, partition_idx),
-        (1, partition_idx),
-        (2, partition_idx),
-        (3, partition_idx),
-        (4, partition_idx),
-    ]
-
 
-@pytest.mark.parametrize("pipeline_parallel_size", [1, 2, 4])
-def test_pipeline_engine(pipeline_parallel_size, forward_package):
-    DATA_PARALLEL_SIZE = 1
+def test_pipeline_engine():
     TENSOR_PARALLEL_SIZE = 1
+    PIPELINE_PARALLEL_SIZE = 4
+    DATA_PARALLEL_SIZE = 1
 
-    WORLD_SIZE = pipeline_parallel_size * DATA_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE
+    WORLD_SIZE = PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE
 
     spawn(
         run_pipeline_engine,
         world_size=WORLD_SIZE,
         tensor_parallel_size=TENSOR_PARALLEL_SIZE,
-        pipeline_parallel_size=pipeline_parallel_size,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
         data_parallel_size=DATA_PARALLEL_SIZE,
-        package=forward_package,
     )
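Note on the change to run() in pipeline_engine.py: on the last pipeline stage, the saved per-microbatch activations are looked up by (microbatch_idx, partition_idx) and concatenated along the batch dimension to rebuild the full-batch output, while earlier stages return None. A minimal self-contained sketch of that reconstruction follows; it is not part of the patch, uses a plain dict and random tensors in place of pipegoose's _SAVED_ACTIVATIONS store, and the sizes are illustrative only.

import torch

# Hypothetical stand-in for pipegoose's _SAVED_ACTIVATIONS store:
# maps (microbatch_idx, partition_idx) -> activation tensor of that microbatch.
saved_activations = {}

n_microbatches = 5                              # assumed number of microbatches
partition_idx = 3                               # assumed index of the last pipeline stage
microbatch_size, seq_len, hidden_dim = 4, 10, 5  # assumed shapes

# Pretend the last stage has finished the forward pass of every microbatch.
for microbatch_idx in range(n_microbatches):
    saved_activations[(microbatch_idx, partition_idx)] = torch.randn(microbatch_size, seq_len, hidden_dim)

# Same idea as the patch: gather microbatches in order and concatenate on the batch dim.
outputs = torch.cat(
    [saved_activations[(i, partition_idx)] for i in range(n_microbatches)],
    dim=0,
)
assert outputs.shape == (n_microbatches * microbatch_size, seq_len, hidden_dim)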