Commit: refactor

xrsrke committed Sep 18, 2023
1 parent 64e619c commit 6018e82

Showing 3 changed files with 12 additions and 17 deletions.
pipegoose/nn/pipeline_parallel2/_comm.py (13 changes: 7 additions & 6 deletions)

@@ -5,21 +5,22 @@
 
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.nn.pipeline_parallel2._package import Package
+from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
 
 RECV_QUEUE = Queue()
 
 # TODO: refactor to a singleton class
 # NOTE: save parallel context for backward job
-PIPELINE_CONTEXT = None
+_PIPELINE_CONTEXT = None
 
 
-def set_pipeline_context(pipeline_context):
-    global PIPELINE_CONTEXT
-    PIPELINE_CONTEXT = pipeline_context
+def set_pipeline_context(pipeline_context: PipelineContext):
+    global _PIPELINE_CONTEXT
+    _PIPELINE_CONTEXT = pipeline_context
 
 
-def get_pipeline_context():
-    return PIPELINE_CONTEXT
+def get_pipeline_context() -> PipelineContext:
+    return _PIPELINE_CONTEXT
 
 
 def _send_data(data: Any, src: int, dst: int, parallel_context: ParallelContext):
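The "TODO: refactor to a singleton class" comment above hints at replacing the module-level global with a singleton. A minimal sketch of one way to do that, assuming the repository's PipelineContext import; the _PipelineContextHolder name and its methods are hypothetical, not code from this commit:

    from typing import Optional

    from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext


    class _PipelineContextHolder:
        """Hypothetical singleton replacing the _PIPELINE_CONTEXT module global."""

        _instance: Optional["_PipelineContextHolder"] = None

        def __new__(cls) -> "_PipelineContextHolder":
            # Create the single shared instance on first use.
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._pipeline_context = None
            return cls._instance

        def set(self, pipeline_context: PipelineContext) -> None:
            self._pipeline_context = pipeline_context

        def get(self) -> Optional[PipelineContext]:
            return self._pipeline_context

With something like this in place, set_pipeline_context and get_pipeline_context could simply delegate to _PipelineContextHolder().set(...) and .get(), leaving callers untouched.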
pipegoose/nn/pipeline_parallel2/_job/creator.py (12 changes: 3 additions & 9 deletions)

@@ -71,14 +71,8 @@ def create_job(function: Callable, package: Package, pipeline_context: PipelineContext)
     return job
 
 
-# def create_backwardable_forward_job(function: Callable, package: Package, pipeline_context: PipelineContext) -> ForwardJob:
-#     """
-#     Create a forward job that automatically schedules
-#     a backward job if you call forward(input).backward()
-#     """
-
-
-def create_backward_job_and_put_to_pending_queue(grad_input: torch.Tensor, metadata: Metadata):
+def _create_backward_job_and_put_to_pending_queue(grad_input: torch.Tensor, metadata: Metadata):
+    """Create a backward job and put it to pending queue."""
     # NOTE: construct backward package
     data = torch.randn(2, 4)
     package = Package(data, metadata)

@@ -121,7 +115,7 @@ def backward(ctx: Any, grad_input: torch.Tensor):
         # NOTE: the backward job create in the same node
        # as the forward job
         to=parallel_context.get_worker_name(metadata.src),
-        func=create_backward_job_and_put_to_pending_queue,
+        func=_create_backward_job_and_put_to_pending_queue,
         args=(grad_input, metadata),
     )
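For orientation (an inference, not shown in this hunk): the to= / func= / args= keywords match the signature of torch.distributed.rpc.rpc_sync, so the dispatch in backward() presumably looks roughly like the sketch below. dispatch_backward_job is a hypothetical wrapper; the argument names and the called function come from the hunk above.

    import torch
    import torch.distributed.rpc as rpc

    from pipegoose.distributed.parallel_context import ParallelContext
    from pipegoose.nn.pipeline_parallel2._job.creator import (
        _create_backward_job_and_put_to_pending_queue,
    )
    from pipegoose.nn.pipeline_parallel2._package import Metadata


    def dispatch_backward_job(
        grad_input: torch.Tensor, metadata: Metadata, parallel_context: ParallelContext
    ) -> None:
        # Run the job-creation function on the worker that ran the forward
        # pass (metadata.src), so the backward job is created on the same
        # node as its forward job.
        rpc.rpc_sync(
            to=parallel_context.get_worker_name(metadata.src),
            func=_create_backward_job_and_put_to_pending_queue,
            args=(grad_input, metadata),
        )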
tests/nn/pipeline_parallel_2/job/test_creator.py (4 changes: 2 additions & 2 deletions)

@@ -130,11 +130,11 @@ def run_create_a_backward_job_if_a_tensor_do_backprop(
     rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, forward_package
 ):
     SRC = forward_package.metadata.src
-    N_PARTITIONS = 3
-    N_MICROBATCHES = 5
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
+    N_PARTITIONS = 3
+    N_MICROBATCHES = 5
 
     scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
     pipeline_context = PipelineContext(scheduler, parallel_context)
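As background on the constants used in this test (an illustration, not code from the repository): a GPipe forward schedule with N_MICROBATCHES microbatches and N_PARTITIONS pipeline stages fills and drains the pipeline in N_MICROBATCHES + N_PARTITIONS - 1 clock cycles.

    N_PARTITIONS = 3
    N_MICROBATCHES = 5

    # In clock cycle t, partition p processes microbatch t - p (when that
    # index is in range), so the forward pass finishes after
    # N_MICROBATCHES + N_PARTITIONS - 1 cycles.
    n_clock_cycles = N_MICROBATCHES + N_PARTITIONS - 1
    assert n_clock_cycles == 7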
