diff --git a/pipegoose/nn/pipeline_parallel2/_job/creator.py b/pipegoose/nn/pipeline_parallel2/_job/creator.py index 2fac89b..a3c4208 100644 --- a/pipegoose/nn/pipeline_parallel2/_job/creator.py +++ b/pipegoose/nn/pipeline_parallel2/_job/creator.py @@ -14,6 +14,7 @@ SendBackwardPackageCallback, ) from pipegoose.nn.pipeline_parallel2._job.forward import ( + ConfirmCompleteATaskToProgressTracker, CreateForwardOutputPackageCallback, ForwardJob, SaveActivationIfTrainingCallback, @@ -37,7 +38,12 @@ def create(self) -> Job: class _ForwardJobCreator(JobCreator): """Put a forward job into job queue for a worker to execute.""" - CBS = [CreateForwardOutputPackageCallback, SaveActivationIfTrainingCallback, SendForwardPackageCallback] + CBS = [ + CreateForwardOutputPackageCallback, + SaveActivationIfTrainingCallback, + SendForwardPackageCallback, + ConfirmCompleteATaskToProgressTracker, + ] @classmethod def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> ForwardJob: diff --git a/pipegoose/nn/pipeline_parallel2/_job/forward.py b/pipegoose/nn/pipeline_parallel2/_job/forward.py index 6ac8296..cb6e330 100644 --- a/pipegoose/nn/pipeline_parallel2/_job/forward.py +++ b/pipegoose/nn/pipeline_parallel2/_job/forward.py @@ -6,10 +6,8 @@ from pipegoose.nn.pipeline_parallel2._comm import send_package from pipegoose.nn.pipeline_parallel2._job.callback import Callback from pipegoose.nn.pipeline_parallel2._job.job import Job -from pipegoose.nn.pipeline_parallel2._job.job_type import JobType from pipegoose.nn.pipeline_parallel2._package import Package from pipegoose.nn.pipeline_parallel2.sync.handshake import get_progress_tracker -from pipegoose.nn.pipeline_parallel2.task import Task class ForwardJob(Job): @@ -46,14 +44,14 @@ def _update_next_pipeline_stage(self, package: Package) -> Package: # TODO: take into account that a pipeline stage can has more than one task # in a clock cycle, then find the correspond task to send the output to - print("---------- _update_next_pipeline_stage -----------") - print(f"rank: {self.job.pipeline_context.parallel_context.get_local_rank(ParallelMode.GLOBAL)}") - print(f"clock_idx: {self.job.pipeline_context.clock_idx}") - print( - f"schedules = {self.job.pipeline_context.get_schedule_from_microbatch(clock_idx=self.job.pipeline_context.clock_idx+1, microbatch_idx=microbatch_idx)}" - ) - print(f"microbatch_idx: {microbatch_idx}") - print(f"next_schedule: {next_schedule}") + # print("---------- _update_next_pipeline_stage -----------") + # print(f"rank: {self.job.pipeline_context.parallel_context.get_local_rank(ParallelMode.GLOBAL)}") + # print(f"clock_idx: {self.job.pipeline_context.clock_idx}") + # print( + # f"schedules = {self.job.pipeline_context.get_schedule_from_microbatch(clock_idx=self.job.pipeline_context.clock_idx+1, microbatch_idx=microbatch_idx)}" + # ) + # print(f"microbatch_idx: {microbatch_idx}") + # print(f"next_schedule: {next_schedule}") next_partition = next_schedule[0].partition_idx package.metadata.partition_idx = next_partition @@ -110,9 +108,11 @@ class ConfirmCompleteATaskToProgressTracker(Callback): def after_compute(self): progress_tracker = get_progress_tracker() microbatch_idx = self.job.input.metadata.microbatch_idx - task = Task( - job_type=JobType.FORWARD, - microbatch_idx=microbatch_idx, - partition_idx=self.job.input.metadata.partition_idx, - ) - progress_tracker.confirm(task) + partition_idx = self.job.input.metadata.partition_idx + # task = Task( + # job_type=JobType.FORWARD, + # microbatch_idx=microbatch_idx, + # partition_idx=self.job.input.metadata.partition_idx, + # ) + key = (microbatch_idx, partition_idx) + progress_tracker.confirm(key) diff --git a/pipegoose/nn/pipeline_parallel2/_utils.py b/pipegoose/nn/pipeline_parallel2/_utils.py index 0182ed1..a6923c0 100644 --- a/pipegoose/nn/pipeline_parallel2/_utils.py +++ b/pipegoose/nn/pipeline_parallel2/_utils.py @@ -10,6 +10,7 @@ def sleep(timeout: int = 0.05): def get_partition_idx(parallel_context: ParallelContext) -> int: rank = parallel_context.get_local_rank(ParallelMode.PIPELINE) - n_ranks_per_group = len(parallel_context.get_ranks_in_group(ParallelMode.PIPELINE)) - pipeline_stage_idx = rank // n_ranks_per_group - return pipeline_stage_idx + ranks_in_group = parallel_context.get_ranks_in_group(ParallelMode.PIPELINE) + # pipeline_stage_idx = rank // n_ranks_per_group + # return pipeline_stage_idx + return ranks_in_group.index(rank) diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_context.py b/pipegoose/nn/pipeline_parallel2/pipeline_context.py index 9d1a8d8..0c25c85 100644 --- a/pipegoose/nn/pipeline_parallel2/pipeline_context.py +++ b/pipegoose/nn/pipeline_parallel2/pipeline_context.py @@ -2,7 +2,7 @@ from typing import List from pipegoose.distributed.parallel_context import ParallelContext -from pipegoose.distributed.parallel_mode import ParallelMode +from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler @@ -21,9 +21,10 @@ def __init__(self, scheduler: BaseScheduler, parallel_context: ParallelContext): @property def partition_idx(self) -> int: parallel_context = self.parallel_context - rank = parallel_context.get_local_rank(ParallelMode.PIPELINE) - n_ranks_per_group = len(parallel_context.get_ranks_in_group(ParallelMode.PIPELINE)) - pipeline_stage_idx = rank // n_ranks_per_group + # rank = parallel_context.get_local_rank(ParallelMode.PIPELINE) + # n_ranks_per_group = len(parallel_context.get_ranks_in_group(ParallelMode.PIPELINE)) + # pipeline_stage_idx = rank // n_ranks_per_group + pipeline_stage_idx = get_partition_idx(parallel_context) return pipeline_stage_idx @property @@ -52,7 +53,7 @@ def get_schedule(self): with self._wait_new_clock_cycle: while self.clock_idx < self.scheduler.total_clock_cycles: schedules = self.get_schedule_from_partition(self.clock_idx, self.partition_idx) - yield from schedules + yield schedules # NOTE: wait for the next clock cycle self._wait_new_clock_cycle.wait() diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py index 5113186..7f61c46 100644 --- a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py +++ b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py @@ -83,8 +83,14 @@ def after_new_clock_cycle(self, progress, clock_idx): ) schedules = self.pipeline_context.schedules - progress_tracker.initiate(schedules) + progress = { + i: {(item.microbatch_idx, item.partition_idx): False for item in sublist} + for i, sublist in enumerate(schedules) + } + progress_tracker.initiate(progress) + time.sleep(2) else: + time.sleep(2) progress_tracker = ProgressTracker( MASTER_RANK, callbacks=callbacks, parallel_context=self.parallel_context, parallel_mode=ParallelMode.GLOBAL ) @@ -97,28 +103,38 @@ def after_new_clock_cycle(self, progress, clock_idx): time.sleep(2) - for task in self.pipeline_context.get_schedule(): - time.sleep(2) - print("[loop] ------------------") - print("[loop] clock_idx", self.pipeline_context.clock_idx) - print("[loop] rank", self.parallel_context.get_local_rank(ParallelMode.GLOBAL)) - - microbatch_idx = task.microbatch_idx - partition_idx = task.partition_idx - if self.parallel_context.is_first_rank(ParallelMode.PIPELINE): - if partition_idx == 0: - batch = microbatches[microbatch_idx] - package = self._construct_first_package(microbatch_idx, input=batch) - else: - package = RECV_QUEUE.get() + for tasks in self.pipeline_context.get_schedule(): - print("received a package", package.metadata) - - job = create_job(self.partition_func, package, self.pipeline_context) - - print(f"created a job: {package.metadata}") - - JobQueue.PENDING_JOBS.put(job) + time.sleep(2) + # print("[loop] ------------------") + rank = self.parallel_context.get_local_rank(ParallelMode.GLOBAL) + partition_idx = self.pipeline_context.partition_idx + + if rank == 0: + assert 1 == 1 + + if len(tasks) > 0: + print(f"[loop] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}") + for task in tasks: + microbatch_idx = task.microbatch_idx + partition_idx = task.partition_idx + if self.parallel_context.is_first_rank(ParallelMode.PIPELINE): + if partition_idx == 0: + batch = microbatches[microbatch_idx] + package = self._construct_first_package(microbatch_idx, input=batch) + else: + package = RECV_QUEUE.get() + + print( + f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}", + package.metadata, + ) + + job = create_job(self.partition_func, package, self.pipeline_context) + + # print(f"created a job: {package.metadata}") + + JobQueue.PENDING_JOBS.put(job) # def _retrieve_package_from_received_package(self, microbatch_idx, partition_idx): # # package = RECV_QUEUE[(microbatch_idx, partition_idx)] diff --git a/pipegoose/nn/pipeline_parallel2/sync/handshake.py b/pipegoose/nn/pipeline_parallel2/sync/handshake.py index 49da4e5..cbb81f0 100644 --- a/pipegoose/nn/pipeline_parallel2/sync/handshake.py +++ b/pipegoose/nn/pipeline_parallel2/sync/handshake.py @@ -137,6 +137,8 @@ def confirm(self, task: Task): # rank = self.parallel_context.get_local_rank(self.parallel_mode) rpc.rpc_sync(master_worker_name, func=ProgressTracker._recv_confirm_from_worker, args=(task,)) + # NOTE: if master node confirm itself, then no need rpc call + # NOTE: a worker node should confirm itself ProgressTracker._update_progress(task)