seems to work, but the first rank is stuck
xrsrke committed Sep 27, 2023
1 parent 3be30ed commit 62b6680
Showing 6 changed files with 73 additions and 47 deletions.
pipegoose/nn/pipeline_parallel2/_job/creator.py (7 additions, 1 deletion)
@@ -14,6 +14,7 @@
     SendBackwardPackageCallback,
 )
 from pipegoose.nn.pipeline_parallel2._job.forward import (
+    ConfirmCompleteATaskToProgressTracker,
     CreateForwardOutputPackageCallback,
     ForwardJob,
     SaveActivationIfTrainingCallback,
@@ -37,7 +38,12 @@ def create(self) -> Job:
 class _ForwardJobCreator(JobCreator):
     """Put a forward job into job queue for a worker to execute."""

-    CBS = [CreateForwardOutputPackageCallback, SaveActivationIfTrainingCallback, SendForwardPackageCallback]
+    CBS = [
+        CreateForwardOutputPackageCallback,
+        SaveActivationIfTrainingCallback,
+        SendForwardPackageCallback,
+        ConfirmCompleteATaskToProgressTracker,
+    ]

     @classmethod
     def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> ForwardJob:
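The fix threads progress-tracker confirmation into the forward path by appending ConfirmCompleteATaskToProgressTracker to the callback list, so confirmation only runs after the output package has been created and sent. A minimal sketch of this callback-composed job pattern, with simplified stand-in Job/Callback classes rather than the actual pipegoose API:

```python
from typing import Callable, List


class Callback:
    """Stand-in for pipegoose's Callback: one hook that fires after compute."""

    def __init__(self, job: "Job"):
        self.job = job

    def after_compute(self) -> None:
        pass


class Job:
    """Stand-in Job: runs a function, then runs each callback in list order."""

    def __init__(self, function: Callable, callback_classes: List[type]):
        self.function = function
        self.callbacks = [cls(self) for cls in callback_classes]

    def compute(self, *args):
        self.output = self.function(*args)
        for cb in self.callbacks:
            cb.after_compute()
        return self.output


class SendOutput(Callback):
    def after_compute(self) -> None:
        print("send", self.job.output)


class ConfirmTask(Callback):
    def after_compute(self) -> None:
        print("confirm", self.job.output)


# Because CBS is ordered, confirmation fires only after the send callback.
job = Job(lambda x: x * 2, [SendOutput, ConfirmTask])
job.compute(21)  # prints "send 42", then "confirm 42"
```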
pipegoose/nn/pipeline_parallel2/_job/forward.py (16 additions, 16 deletions)
@@ -6,10 +6,8 @@
 from pipegoose.nn.pipeline_parallel2._comm import send_package
 from pipegoose.nn.pipeline_parallel2._job.callback import Callback
 from pipegoose.nn.pipeline_parallel2._job.job import Job
-from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
 from pipegoose.nn.pipeline_parallel2._package import Package
 from pipegoose.nn.pipeline_parallel2.sync.handshake import get_progress_tracker
-from pipegoose.nn.pipeline_parallel2.task import Task


 class ForwardJob(Job):
@@ -46,14 +44,14 @@ def _update_next_pipeline_stage(self, package: Package) -> Package:
         # TODO: take into account that a pipeline stage can have more than one task
         # in a clock cycle, then find the corresponding task to send the output to

-        print("---------- _update_next_pipeline_stage -----------")
-        print(f"rank: {self.job.pipeline_context.parallel_context.get_local_rank(ParallelMode.GLOBAL)}")
-        print(f"clock_idx: {self.job.pipeline_context.clock_idx}")
-        print(
-            f"schedules = {self.job.pipeline_context.get_schedule_from_microbatch(clock_idx=self.job.pipeline_context.clock_idx+1, microbatch_idx=microbatch_idx)}"
-        )
-        print(f"microbatch_idx: {microbatch_idx}")
-        print(f"next_schedule: {next_schedule}")
+        # print("---------- _update_next_pipeline_stage -----------")
+        # print(f"rank: {self.job.pipeline_context.parallel_context.get_local_rank(ParallelMode.GLOBAL)}")
+        # print(f"clock_idx: {self.job.pipeline_context.clock_idx}")
+        # print(
+        #     f"schedules = {self.job.pipeline_context.get_schedule_from_microbatch(clock_idx=self.job.pipeline_context.clock_idx+1, microbatch_idx=microbatch_idx)}"
+        # )
+        # print(f"microbatch_idx: {microbatch_idx}")
+        # print(f"next_schedule: {next_schedule}")

         next_partition = next_schedule[0].partition_idx
         package.metadata.partition_idx = next_partition
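For context, _update_next_pipeline_stage routes a package by asking the scheduler which partition handles this microbatch in the next clock cycle, then stamping that partition onto the package metadata. A rough sketch of that lookup, assuming a (clock_idx, microbatch_idx) -> tasks mapping (illustrative names, not the real scheduler interface):

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ScheduledTask:
    microbatch_idx: int
    partition_idx: int


def next_partition(
    schedule: Dict[Tuple[int, int], List[ScheduledTask]],
    clock_idx: int,
    microbatch_idx: int,
) -> int:
    """Where should this microbatch's output go in the next clock cycle?"""
    next_tasks = schedule[(clock_idx + 1, microbatch_idx)]
    return next_tasks[0].partition_idx  # mirrors next_schedule[0].partition_idx


# GPipe-style forward schedule for 2 partitions: microbatch 0 occupies
# partition 0 at clock 0 and partition 1 at clock 1.
schedule = {
    (0, 0): [ScheduledTask(0, 0)],
    (1, 0): [ScheduledTask(0, 1)],
}
assert next_partition(schedule, clock_idx=0, microbatch_idx=0) == 1
```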
@@ -110,9 +108,11 @@ class ConfirmCompleteATaskToProgressTracker(Callback):
     def after_compute(self):
         progress_tracker = get_progress_tracker()
         microbatch_idx = self.job.input.metadata.microbatch_idx
-        task = Task(
-            job_type=JobType.FORWARD,
-            microbatch_idx=microbatch_idx,
-            partition_idx=self.job.input.metadata.partition_idx,
-        )
-        progress_tracker.confirm(task)
+        partition_idx = self.job.input.metadata.partition_idx
+        # task = Task(
+        #     job_type=JobType.FORWARD,
+        #     microbatch_idx=microbatch_idx,
+        #     partition_idx=self.job.input.metadata.partition_idx,
+        # )
+        key = (microbatch_idx, partition_idx)
+        progress_tracker.confirm(key)
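The callback now confirms completion with a plain (microbatch_idx, partition_idx) tuple instead of building a Task, matching the keys of the progress dict the engine constructs in pipeline_engine.py below. A toy tracker showing the intended bookkeeping, under that assumption:

```python
from typing import Dict, Tuple

Key = Tuple[int, int]  # (microbatch_idx, partition_idx)


class ToyTracker:
    def __init__(self, progress: Dict[int, Dict[Key, bool]]):
        self.progress = progress
        self.clock_idx = 0

    def confirm(self, key: Key) -> None:
        self.progress[self.clock_idx][key] = True
        # advance once every task in this clock cycle is confirmed
        if all(self.progress[self.clock_idx].values()):
            self.clock_idx += 1


tracker = ToyTracker({0: {(0, 0): False}, 1: {(0, 1): False}})
tracker.confirm((0, 0))
assert tracker.clock_idx == 1
```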
pipegoose/nn/pipeline_parallel2/_utils.py (4 additions, 3 deletions)
@@ -10,6 +10,7 @@ def sleep(timeout: int = 0.05):

 def get_partition_idx(parallel_context: ParallelContext) -> int:
     rank = parallel_context.get_local_rank(ParallelMode.PIPELINE)
-    n_ranks_per_group = len(parallel_context.get_ranks_in_group(ParallelMode.PIPELINE))
-    pipeline_stage_idx = rank // n_ranks_per_group
-    return pipeline_stage_idx
+    ranks_in_group = parallel_context.get_ranks_in_group(ParallelMode.PIPELINE)
+    # pipeline_stage_idx = rank // n_ranks_per_group
+    # return pipeline_stage_idx
+    return ranks_in_group.index(rank)
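The old formula rank // n_ranks_per_group evaluates to 0 for every local rank smaller than the group size, so all stages would report partition 0; indexing the rank's position within its pipeline group yields one partition per rank (assuming local ranks enumerate the group in order). A worked example with plain ints:

```python
ranks_in_group = [0, 1, 2, 3]  # one pipeline group of 4 ranks

for rank in ranks_in_group:
    old = rank // len(ranks_in_group)  # 0, 0, 0, 0 -- every rank maps to stage 0
    new = ranks_in_group.index(rank)   # 0, 1, 2, 3 -- one stage per rank
    assert new == rank
```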
pipegoose/nn/pipeline_parallel2/pipeline_context.py (6 additions, 5 deletions)
@@ -2,7 +2,7 @@
 from typing import List

 from pipegoose.distributed.parallel_context import ParallelContext
-from pipegoose.distributed.parallel_mode import ParallelMode
+from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx
 from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler

@@ -21,9 +21,10 @@ def __init__(self, scheduler: BaseScheduler, parallel_context: ParallelContext):
     @property
     def partition_idx(self) -> int:
         parallel_context = self.parallel_context
-        rank = parallel_context.get_local_rank(ParallelMode.PIPELINE)
-        n_ranks_per_group = len(parallel_context.get_ranks_in_group(ParallelMode.PIPELINE))
-        pipeline_stage_idx = rank // n_ranks_per_group
+        # rank = parallel_context.get_local_rank(ParallelMode.PIPELINE)
+        # n_ranks_per_group = len(parallel_context.get_ranks_in_group(ParallelMode.PIPELINE))
+        # pipeline_stage_idx = rank // n_ranks_per_group
+        pipeline_stage_idx = get_partition_idx(parallel_context)
         return pipeline_stage_idx

     @property
@@ -52,7 +53,7 @@ def get_schedule(self):
         with self._wait_new_clock_cycle:
             while self.clock_idx < self.scheduler.total_clock_cycles:
                 schedules = self.get_schedule_from_partition(self.clock_idx, self.partition_idx)
-                yield from schedules
+                yield schedules

                 # NOTE: wait for the next clock cycle
                 self._wait_new_clock_cycle.wait()
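With yield schedules, the generator hands the engine one list of tasks per clock cycle instead of a flattened task stream, which is what lets the engine loop below iterate for tasks in ... and preserve the clock-cycle boundary. Schematically (generic Python, not the real API):

```python
def per_clock_schedules():
    yield [("mb0", "p0")]                 # clock 0: one task
    yield [("mb1", "p0"), ("mb0", "p1")]  # clock 1: two tasks


# `yield schedules`: the consumer sees one list per clock cycle
clock_sizes = [len(tasks) for tasks in per_clock_schedules()]
assert clock_sizes == [1, 2]

# `yield from schedules` would flatten the stream into 3 individual tasks,
# erasing the clock-cycle boundary the engine's loop relies on.
```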
pipegoose/nn/pipeline_parallel2/pipeline_engine.py (38 additions, 22 deletions)
@@ -83,8 +83,14 @@ def after_new_clock_cycle(self, progress, clock_idx):
             )

             schedules = self.pipeline_context.schedules
-            progress_tracker.initiate(schedules)
+            progress = {
+                i: {(item.microbatch_idx, item.partition_idx): False for item in sublist}
+                for i, sublist in enumerate(schedules)
+            }
+            progress_tracker.initiate(progress)
+            time.sleep(2)
         else:
+            time.sleep(2)
             progress_tracker = ProgressTracker(
                 MASTER_RANK, callbacks=callbacks, parallel_context=self.parallel_context, parallel_mode=ParallelMode.GLOBAL
             )
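The master now seeds the tracker with an explicit completion map rather than the raw schedules. On a toy two-clock schedule, the new comprehension produces the nested dict below (the namedtuple is a stand-in for the scheduler's task type):

```python
from collections import namedtuple

Task = namedtuple("Task", ["microbatch_idx", "partition_idx"])

schedules = [
    [Task(0, 0)],               # clock 0
    [Task(1, 0), Task(0, 1)],   # clock 1
]
progress = {
    i: {(t.microbatch_idx, t.partition_idx): False for t in sublist}
    for i, sublist in enumerate(schedules)
}
assert progress == {0: {(0, 0): False}, 1: {(1, 0): False, (0, 1): False}}
```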
@@ -97,28 +103,38 @@

         time.sleep(2)

-        for task in self.pipeline_context.get_schedule():
-            time.sleep(2)
-            print("[loop] ------------------")
-            print("[loop] clock_idx", self.pipeline_context.clock_idx)
-            print("[loop] rank", self.parallel_context.get_local_rank(ParallelMode.GLOBAL))
-
-            microbatch_idx = task.microbatch_idx
-            partition_idx = task.partition_idx
-            if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
-                if partition_idx == 0:
-                    batch = microbatches[microbatch_idx]
-                    package = self._construct_first_package(microbatch_idx, input=batch)
-            else:
-                package = RECV_QUEUE.get()
-
-            print("received a package", package.metadata)
-
-            job = create_job(self.partition_func, package, self.pipeline_context)
-
-            print(f"created a job: {package.metadata}")
-
-            JobQueue.PENDING_JOBS.put(job)
+        for tasks in self.pipeline_context.get_schedule():
+            time.sleep(2)
+            # print("[loop] ------------------")
+            rank = self.parallel_context.get_local_rank(ParallelMode.GLOBAL)
+            partition_idx = self.pipeline_context.partition_idx
+
+            if rank == 0:
+                assert 1 == 1
+
+            if len(tasks) > 0:
+                print(f"[loop] clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}")
+                for task in tasks:
+                    microbatch_idx = task.microbatch_idx
+                    partition_idx = task.partition_idx
+                    if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
+                        if partition_idx == 0:
+                            batch = microbatches[microbatch_idx]
+                            package = self._construct_first_package(microbatch_idx, input=batch)
+                    else:
+                        package = RECV_QUEUE.get()
+
+                    print(
+                        f"[received a package]clock_idx={self.pipeline_context.clock_idx}, rank={rank}, partition_idx={partition_idx}",
+                        package.metadata,
+                    )
+
+                    job = create_job(self.partition_func, package, self.pipeline_context)
+
+                    # print(f"created a job: {package.metadata}")
+
+                    JobQueue.PENDING_JOBS.put(job)

     # def _retrieve_package_from_received_package(self, microbatch_idx, partition_idx):
     #     # package = RECV_QUEUE[(microbatch_idx, partition_idx)]
pipegoose/nn/pipeline_parallel2/sync/handshake.py (2 additions, 0 deletions)
@@ -137,6 +137,8 @@ def confirm(self, task: Task):
         # rank = self.parallel_context.get_local_rank(self.parallel_mode)
         rpc.rpc_sync(master_worker_name, func=ProgressTracker._recv_confirm_from_worker, args=(task,))

+        # NOTE: if the master node confirms itself, there is no need for an RPC call
+
         # NOTE: a worker node should confirm itself
         ProgressTracker._update_progress(task)

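The added NOTE flags an optimization the commit does not implement yet: when the master confirms one of its own tasks, there is no need for it to RPC itself. A hypothetical shape of that short-circuit in a self-contained toy (is_master and the call counter stand in for pipegoose's actual RPC plumbing):

```python
class ToyTracker:
    def __init__(self, is_master: bool):
        self.is_master = is_master
        self.local_progress = set()
        self.rpc_calls = 0

    def _rpc_confirm_to_master(self, key) -> None:
        self.rpc_calls += 1  # stands in for rpc.rpc_sync(...)

    def confirm(self, key) -> None:
        if not self.is_master:          # the short-circuit the NOTE suggests
            self._rpc_confirm_to_master(key)
        self.local_progress.add(key)    # every node confirms itself locally


master = ToyTracker(is_master=True)
master.confirm((0, 0))
assert master.rpc_calls == 0  # master never RPCs itself
```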
