WIP fixing the schedule for the backward job, computing gradients in the pipeline isn't working yet
xrsrke committed Oct 2, 2023
1 parent 10eeda0 commit 096cc70
Showing 8 changed files with 118 additions and 46 deletions.
15 changes: 9 additions & 6 deletions pipegoose/nn/pipeline_parallel2/_job/backward.py
@@ -33,11 +33,14 @@ def run_compute(self) -> torch.Tensor:
microbatch_idx = self.input.metadata.microbatch_idx
partition_idx = self.input.metadata.partition_idx
key = SavedActivation.get_key(microbatch_idx, partition_idx)
outputs = SavedActivation.get_saved_activations(key)
inputs = self.input.data
output = SavedActivation.get_saved_activations(key)
prev_grad = self.input.data

if inputs.requires_grad:
with torch.enable_grad():
torch.autograd.backward(inputs, outputs)
rank = self.pipeline_context.parallel_context.get_global_rank()

return inputs.grad
print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")

with torch.enable_grad():
torch.autograd.backward(output, prev_grad)

return output.grad
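
A note on the new run_compute: it fetches the saved forward output for (microbatch_idx, partition_idx) and feeds the gradient received from the next stage into torch.autograd.backward. A minimal, self-contained sketch of that pattern (tensor names here are illustrative, not from the repo):

import torch

stage_input = torch.randn(2, 4, requires_grad=True)   # activation that entered this stage
weight = torch.randn(4, 4, requires_grad=True)

output = stage_input @ weight                          # forward output saved via SavedActivation
prev_grad = torch.ones_like(output)                    # gradient arriving from the next pipeline stage

# Same call as run_compute: backprop through the saved output using the received gradient.
torch.autograd.backward(output, prev_grad)

print(stage_input.grad.shape)                          # torch.Size([2, 4]) -- what the previous stage needs

One caveat worth flagging: on a non-leaf tensor, output.grad is only populated if output.retain_grad() was called before the backward pass, so the return output.grad above may come back as None, which could be part of why the commit is still marked WIP.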
23 changes: 15 additions & 8 deletions pipegoose/nn/pipeline_parallel2/_job/creator.py
@@ -81,18 +81,23 @@ def create_job(function: Callable, package: Package, pipeline_context: PipelineC
def _create_backward_job_and_put_to_pending_queue(grad_input: torch.Tensor, metadata: Metadata):
"""Create a backward job and put it to pending queue."""
# NOTE: construct backward package
data = torch.randn(2, 4)
package = Package(data, metadata)
package = Package(grad_input, metadata)
package.metadata.job_type = JobType.BACKWARD

# NOTE: construct backward job
def backward_function(self):
pass

# TODO: make parallel_context automatically set when it initialize
parallel_context = get_pipeline_context()
pipeline_context = get_pipeline_context()
parallel_context = pipeline_context.parallel_context

backward_job = create_job(backward_function, package, parallel_context)
rank = parallel_context.get_global_rank()
microbatch_idx = metadata.microbatch_idx

print(f"invoked create_backward_job_and_put_to_pending_queue, rank={rank}, microbatch_idx={microbatch_idx}")

backward_job = create_job(backward_function, package, pipeline_context)

# NOTE : put the backward job to pending queue
JobQueue.PENDING_JOBS.put(backward_job)
@@ -105,8 +110,8 @@ class _ScheduleBackwardJob(torch.autograd.Function):
@staticmethod
def forward(ctx, metadata: Metadata, pipeline_context: PipelineContext, input: torch.Tensor):
# NOTE: can't assign metadata attribute to ctx
# AttributeError: attribute 'metadata' of 'torch._C._FunctionBase'
# objects is not writable
# "AttributeError: attribute 'metadata' of 'torch._C._FunctionBase'
# objects is not writable"
rank = pipeline_context.parallel_context.get_global_rank()
print(f"scheduled a backward job, rank={rank}, microbatch_idx={metadata.microbatch_idx}")
ctx.package_meta = metadata
@@ -121,14 +126,16 @@ def backward(ctx: Any, grad_input: torch.Tensor):

rank = parallel_context.get_global_rank()
microbatch_idx = metadata.microbatch_idx
print(f"creating a backward job, rank={rank}, microbatch_idx={microbatch_idx}")

dst_worker_name = parallel_context.get_worker_name(metadata.dst)
print(f"creating a backward job, rank={rank}, microbatch_idx={microbatch_idx}, dst_worker_name={dst_worker_name}")

# TODO: because forward job and backward job are in the same node
# rpc isn't necessary
rpc.rpc_sync(
# NOTE: the backward job create in the same node
# as the forward job
to=parallel_context.get_worker_name(metadata.src),
to=dst_worker_name,
func=_create_backward_job_and_put_to_pending_queue,
args=(grad_input, metadata),
)
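
For readers following the RPC change above: _ScheduleBackwardJob is an identity autograd.Function whose backward intercepts the upstream gradient and ships it (now to the destination worker rather than the source) so a BackwardJob can be enqueued there. A stripped-down sketch of that trick, with the rpc_sync call replaced by a print so it runs standalone; the class and field names below are illustrative, not the repo's:

import torch


class ScheduleOnBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, metadata, tensor):
        # ctx.metadata itself is not writable on torch._C._FunctionBase, hence a custom attribute
        ctx.package_meta = metadata
        return tensor                                   # identity in the forward direction

    @staticmethod
    def backward(ctx, grad_output):
        metadata = ctx.package_meta
        # In the commit this is where rpc.rpc_sync(to=dst_worker_name,
        # func=_create_backward_job_and_put_to_pending_queue, args=(grad_output, metadata)) goes.
        print(f"would enqueue a backward job for microbatch_idx={metadata['microbatch_idx']}")
        return None, grad_output                        # one grad per forward argument; None for metadata


x = torch.randn(2, 4, requires_grad=True)
y = ScheduleOnBackward.apply({"microbatch_idx": 0}, x)
y.sum().backward()                                      # triggers backward(), i.e. the scheduling hook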
5 changes: 4 additions & 1 deletion pipegoose/nn/pipeline_parallel2/_job/forward.py
@@ -12,7 +12,8 @@

class ForwardJob(Job):
def run_compute(self) -> torch.Tensor:
return self.function(self.input.data)
with torch.set_grad_enabled(True):
return self.function(self.input.data)


class CreateForwardOutputPackageCallback(Callback):
@@ -65,6 +66,7 @@ class SaveActivationIfTrainingCallback(Callback):

def after_compute(self):
is_training = self.job.input.metadata.training.is_training

if is_training is True:
from pipegoose.nn.pipeline_parallel2.queue import SavedActivation

@@ -73,6 +75,7 @@ def after_compute(self):
partition_idx = self.job.input.metadata.partition_idx

key = SavedActivation.get_key(microbatch_idx, partition_idx)
print("saving activation, data.shape=", self.job.output.data.shape)
SavedActivation.save_activations(key, self.job.output.data)


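
On the activation-saving callback above: activations are stored under a key derived from (microbatch_idx, partition_idx) so the matching backward job on the same stage can find them later. A rough sketch of that store, assuming dictionary-like behaviour (the repo's SavedActivation may differ in details such as whether entries are popped on read):

from typing import Dict, Tuple

import torch

_SAVED_ACTIVATIONS: Dict[Tuple[int, int], torch.Tensor] = {}


def save_activations(microbatch_idx: int, partition_idx: int, output: torch.Tensor) -> None:
    _SAVED_ACTIVATIONS[(microbatch_idx, partition_idx)] = output


def get_saved_activations(microbatch_idx: int, partition_idx: int) -> torch.Tensor:
    # popping means each saved activation is consumed by exactly one backward job
    return _SAVED_ACTIVATIONS.pop((microbatch_idx, partition_idx))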
5 changes: 2 additions & 3 deletions pipegoose/nn/pipeline_parallel2/pipeline_context.py
@@ -3,7 +3,7 @@

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx
from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx, is_last_stage
from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler


@@ -100,5 +100,4 @@ def is_first_stage(self) -> bool:

@property
def is_last_stage(self) -> bool:
n_pipeline_stages = self.parallel_context.pipeline_parallel_size
return self.partition_idx == n_pipeline_stages - 1
return is_last_stage(self.parallel_context)
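
The is_last_stage property now delegates to the helper imported at the top of this file. Presumably it mirrors the inline check it replaces (this is an inference from the deleted lines, not the helper's verified body):

def is_last_stage(parallel_context) -> bool:
    # get_partition_idx is the repo helper imported alongside is_last_stage above
    return get_partition_idx(parallel_context) == parallel_context.pipeline_parallel_size - 1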
51 changes: 43 additions & 8 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -1,3 +1,4 @@
import time
from dataclasses import dataclass

import torch
@@ -7,7 +8,10 @@
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.pipeline_parallel2._comm import RECV_QUEUE
from pipegoose.nn.pipeline_parallel2._job.creator import create_job
from pipegoose.nn.pipeline_parallel2._job.creator import (
create_job,
schedule_backward_job,
)
from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
from pipegoose.nn.pipeline_parallel2._package import Metadata, Package, TrainingMetadata
from pipegoose.nn.pipeline_parallel2._worker import BaseWorkerManager
@@ -74,8 +78,16 @@ def __init__(self, pipeline_context):

def after_new_clock_cycle(self, progress, clock_idx):
parallel_context = self.pipeline_context.parallel_context
print(f"increase clock, clock_idx={clock_idx}, rank={parallel_context.get_local_rank(ParallelMode.GLOBAL)}")
self.pipeline_context.increase_a_clock_cycle()

# NOTE: suppose we have tensor_parallel_size = 3
# that means a pipeline stage is split into 3 slices
# we want only one slice to increase the clock
# here we choose the last slice to increase the clock
if parallel_context.is_last_rank(ParallelMode.TENSOR):
print(
f"increase clock, clock_idx={clock_idx}, rank={parallel_context.get_local_rank(ParallelMode.GLOBAL)}"
)
self.pipeline_context.increase_a_clock_cycle()

callbacks = [IncreasePipelineContextClockCycleCallback(self.pipeline_context)]
progress_tracker = ProgressTracker(
@@ -127,7 +139,7 @@ def after_new_clock_cycle(self, progress, clock_idx):
package = self._construct_first_package(microbatch_idx, input=batch)
else:
package = RECV_QUEUE.get()

package = schedule_backward_job(package, self.pipeline_context)
job = create_job(self.partition_func, package, self.pipeline_context)
JobQueue.PENDING_JOBS.put(job)

@@ -136,11 +148,34 @@ def after_new_clock_cycle(self, progress, clock_idx):
dist.barrier()

if self.pipeline_context.is_last_stage:
from pipegoose.nn.pipeline_parallel2.queue import _SAVED_ACTIVATIONS
from pipegoose.nn.pipeline_parallel2.queue import (
_SAVED_ACTIVATIONS,
SavedActivation,
)

# TODO: use SavedActivation.get_key()
# outputs = [SavedActivation.get_saved_activations((microbatch_idx, partition_idx)) for microbatch_idx in range(n_microbatches)]
# outputs = torch.cat(outputs, dim=0)
# print(f"outputs.shape={outputs.shape}")

print("just run output.backward()")

outputs = [_SAVED_ACTIVATIONS[(microbatch_idx, partition_idx)] for microbatch_idx in range(n_microbatches)]
outputs = torch.cat(outputs, dim=0)
return outputs
# outputs.sum().backward()

# TODO: refactor this, this only take the last activations and trigger backward
key = SavedActivation.get_key(microbatch_idx=0, partition_idx=partition_idx)
output = _SAVED_ACTIVATIONS[key]

# SavedActivation.get_saved_activations((0, partition_idx)).sum().backward()
output.sum().backward()

time.sleep(100)
assert 1 == 1
else:
# NOTE: not terminate the worker, make it wait for processing further backward jobs
time.sleep(100)

dist.barrier()

def _construct_first_package(self, microbatch_idx: int, input: torch.Tensor):
"""Construct the first forward package of a microbatch."""
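
Two things in this file deserve a gloss. First, the clock gate: when tensor parallelism replicates each pipeline stage across several ranks, only one replica per stage should advance the shared clock, otherwise the clock moves tensor_parallel_size times per real cycle. A toy, repo-independent illustration of that bookkeeping:

tensor_parallel_size = 3   # each pipeline stage is split into 3 tensor-parallel slices
clock_idx = 0

for clock_cycle in range(2):                      # two real pipeline clock cycles
    for tp_rank in range(tensor_parallel_size):   # every slice runs the callback
        is_last_tp_rank = tp_rank == tensor_parallel_size - 1
        if is_last_tp_rank:                       # mirrors parallel_context.is_last_rank(ParallelMode.TENSOR)
            clock_idx += 1                        # only the last slice advances the clock

print(clock_idx)   # 2, not 6

Second, the end-of-run block is explicitly a stopgap: it triggers backward from only the first microbatch's saved activation and then sleeps so workers stay alive for pending backward jobs, which matches the WIP framing in the commit message.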
34 changes: 22 additions & 12 deletions tests/nn/pipeline_parallel_2/job/test_backward.py
@@ -59,20 +59,23 @@ def run_create_a_backward_job_if_a_tensor_do_backprop(
data = forward_package.data
data.sum().backward()

# NOTE: wait for the backward job to be created
dist.barrier()
time.sleep(3)
time.sleep(0.1)

if rank == SRC:
# NOTE: since we don't launch any job selector workers in the background,
# after triggering the creation of a backward job,
# we expect the destination worker's job queue to have one job
assert JobQueue.PENDING_JOBS.qsize() == 1

backward_job = JobQueue.PENDING_JOBS.get()

assert isinstance(backward_job, BackwardJob)

# NOTE: wait for the backward job to be created
dist.barrier()
time.sleep(0.1)

if rank == SRC:
assert JobQueue.PENDING_JOBS.qsize() == 0


@pytest.mark.parametrize("pipeline_parallel_size", [2, 5])
@pytest.mark.parametrize("package", ["forward_package_in_same_node", "forward_package_in_different_nodes"])
@@ -93,23 +96,30 @@ def test_create_a_backward_job_if_a_tensor_do_backprop_in_the_same_node(request,


def test_execute_a_backward_job(backward_job):
# NOTE: save activations
MICROBATCH_IDX = backward_job.input.metadata.microbatch_idx
PARTITION_IDX = backward_job.input.metadata.partition_idx
INPUT = backward_job.input.data
backward_job.input.data = torch.ones_like(INPUT)

key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX)
linear = nn.Linear(INPUT.shape[1], INPUT.shape[0])
# output = linear(INPUT)

# NOTE: compute the ground truth gradient
output = torch.randn_like(INPUT)
SavedActivation.save_activations(key, output)
with torch.set_grad_enabled(True):
output = linear(output)

torch.autograd.backward(INPUT, output, retain_graph=True)
GROUND_GRADIENT = deepcopy(INPUT.grad)
INPUT.grad = None
output.sum().backward()
# GROUND_GRADIENT = deepcopy(INPUT.grad)
# INPUT.grad = None

key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX)
SavedActivation.save_activations(key, output)

grad_package = backward_job.compute()

assert grad_package.data.shape == INPUT.shape
assert torch.equal(grad_package.data, GROUND_GRADIENT)
# assert torch.allclose(grad_package.data, GROUND_GRADIENT)


@pytest.mark.skip
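
A side note on the ground-truth setup in this test: output.sum().backward() is just torch.autograd.backward(output, torch.ones_like(output)), which is exactly the call the backward job makes with the gradient it receives. A small self-check of that equivalence (arbitrary sizes, not the fixture's):

import torch
from torch import nn

x = torch.randn(4, 2, requires_grad=True)
linear = nn.Linear(2, 2)

out = linear(x)
out.sum().backward()
grad_via_sum = x.grad.clone()

x.grad = None
out = linear(x)
torch.autograd.backward(out, torch.ones_like(out))   # what BackwardJob.run_compute does

assert torch.allclose(grad_via_sum, x.grad)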
23 changes: 17 additions & 6 deletions tests/nn/pipeline_parallel_2/job/test_forward.py
@@ -13,6 +13,14 @@
function = nn.Linear(2, 4)


# OUTPUT_TO_SRC_DST_MAPPING = {
# 0: (0, 1),
# 1: (1, 1),
# 2: (1, 2),
# 3: (2, 3)
# }


def test_the_output_package_of_a_forward_job(forward_package, pipeline_context):
# NOTE: (microbatch_idx, partition_idx) -> (microbatch_idx, next_partition_idx)
OUTPUT_DESTINATION = {
@@ -28,6 +36,10 @@ def test_the_output_package_of_a_forward_job(forward_package, pipeline_context):
(4, 1): (4, 2),
}

OUTPUT_SRC_DST_RANK_MAPPING = {
(0): (0, 1),
}

forward_job = create_job(function, forward_package, pipeline_context)
ORIG_MICROBATCH_IDX = forward_job.input.metadata.microbatch_idx
ORIG_PARTITION_IDX = forward_job.input.metadata.partition_idx
@@ -44,15 +56,14 @@ def test_the_output_package_of_a_forward_job(forward_package, pipeline_context):
output.metadata.partition_idx,
)
for key in vars(output.metadata.training).keys():
# TODO: add test automatically switch to create new package
# for different mix precision training
assert getattr(output.metadata.training, key) == getattr(forward_job.input.metadata.training, key)

# NOTE: we expect the metadata of the output package to
# indicate which node executed it
# TODO: update source rank and destination rank based on pipeline context
assert isinstance(output.metadata.src, int)
assert isinstance(output.metadata.dst, int)
# indicate which node executed it, and the destination node
src, dst = output.metadata.src, output.metadata.dst
assert isinstance(src, int)
assert isinstance(dst, int)
assert (src, dst) == OUTPUT_SRC_DST_RANK_MAPPING[ORIG_MICROBATCH_IDX]


def test_forward_job_save_activations_for_backward_pass(forward_package, pipeline_context):
8 changes: 6 additions & 2 deletions tests/nn/pipeline_parallel_2/test_pipeline_engine.py
@@ -37,10 +37,10 @@ def forward(self, input):

forward_timeline = []

scheduler = GPipeScheduler(n_microbatches, pipeline_parallel_size)
parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
scheduler = GPipeScheduler(n_microbatches, pipeline_parallel_size)
worker_manager = WorkerManager()
partition_idx = get_partition_idx(parallel_context)
partition_func = Function(partition_idx)
@@ -59,12 +59,14 @@ def forward(self, input):
if is_last_stage(parallel_context):
assert torch.allclose(p_outputs, outputs)
assert forward_timeline == EXPECTED_FORWARD_TIMELINE

# p_outputs.sum().backward()
else:
# NOTE: earlier stages should not return the final output
assert p_outputs is None


@pytest.mark.parametrize("tensor_parallel_size, pipeline_parallel_size, data_parallel_size", [(1, 4, 1)])
@pytest.mark.parametrize("tensor_parallel_size, pipeline_parallel_size, data_parallel_size", [(1, 4, 1), (2, 4, 2)])
def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
BATCH_SIZE = 32
N_MICROBATCHES = 6
@@ -76,6 +78,8 @@ def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_para
model = nn.ModuleList([nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(pipeline_parallel_size)])
outputs = reduce(lambda inputs, layer: layer(inputs), model, inputs)

# outputs.sum().backward()

spawn(
run_pipeline_engine,
world_size=WORLD_SIZE,
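
On the reference model in this test: the functools.reduce over the ModuleList is simply sequential application of the partitions, i.e. the single-process output a correct pipeline run should reproduce. A quick standalone check of that equivalence (sizes arbitrary):

from functools import reduce

import torch
from torch import nn

hidden_dim = 5
model = nn.ModuleList([nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) for _ in range(4)])
inputs = torch.randn(32, hidden_dim)

via_reduce = reduce(lambda x, layer: layer(x), model, inputs)

x = inputs
for partition in model:   # same partitions applied one after another
    x = partition(x)

assert torch.equal(via_reduce, x)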
