From 10eeda09f2de57ec5f0b7198f14c8f305d773ac4 Mon Sep 17 00:00:00 2001
From: xrsrke
Date: Sun, 1 Oct 2023 12:14:50 +0700
Subject: [PATCH] add calculating gradients in backward job

---
 .../nn/pipeline_parallel2/_job/backward.py    | 21 ++++++++------
 .../pipeline_parallel_2/job/test_backward.py  | 28 ++++++++++++++-----
 2 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/pipegoose/nn/pipeline_parallel2/_job/backward.py b/pipegoose/nn/pipeline_parallel2/_job/backward.py
index 55069ee..2b0b059 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/backward.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/backward.py
@@ -3,6 +3,7 @@
 from pipegoose.nn.pipeline_parallel2._job.callback import Callback
 from pipegoose.nn.pipeline_parallel2._job.job import Job
 from pipegoose.nn.pipeline_parallel2._package import Package
+from pipegoose.nn.pipeline_parallel2.queue import SavedActivation
 
 
 class CreateBackwardOutputPackageCallback(Callback):
@@ -26,15 +27,17 @@ class BackwardJob(Job):
     """Do backward pass."""
 
     def run_compute(self) -> torch.Tensor:
-        # key = self.job.key
-        # activations = get_saved_activations(key)
+        # print("doing backward job")
+        # return self.function(self.input.data)
 
-        # grad_output = self.input.data
+        microbatch_idx = self.input.metadata.microbatch_idx
+        partition_idx = self.input.metadata.partition_idx
+        key = SavedActivation.get_key(microbatch_idx, partition_idx)
+        outputs = SavedActivation.get_saved_activations(key)
+        inputs = self.input.data
 
-        # if activations.requires_grad:
-        #     with torch.enable_grad():
-        #         torch.autograd.backward(activations, grad_output)
+        if inputs.requires_grad:
+            with torch.enable_grad():
+                torch.autograd.backward(inputs, outputs)
 
-        # return activations.grad
-        print("doing backward job")
-        return self.function(self.input.data)
+        return inputs.grad
diff --git a/tests/nn/pipeline_parallel_2/job/test_backward.py b/tests/nn/pipeline_parallel_2/job/test_backward.py
index 9db4992..aa6699a 100644
--- a/tests/nn/pipeline_parallel_2/job/test_backward.py
+++ b/tests/nn/pipeline_parallel_2/job/test_backward.py
@@ -10,7 +10,7 @@
 from pipegoose.nn.pipeline_parallel2._job.backward import BackwardJob
 from pipegoose.nn.pipeline_parallel2._job.creator import schedule_backward_job
 from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
-from pipegoose.nn.pipeline_parallel2.queue import JobQueue
+from pipegoose.nn.pipeline_parallel2.queue import JobQueue, SavedActivation
 from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
 from pipegoose.testing.utils import init_parallel_context, spawn
 
@@ -47,8 +47,6 @@ def run_create_a_backward_job_if_a_tensor_do_backprop(
 
     dist.barrier()
 
-    # NOTE: both the forward job and backward job of the same package
-    # execute on the same node
     if rank == DST:
         # NOTE: we enqueue the backward job in the destination rank
         ORIG_FORWARD_PACKAGE = deepcopy(forward_package)
@@ -94,10 +92,26 @@ def test_create_a_backward_job_if_a_tensor_do_backprop_in_the_same_node(request,
     )
 
 
-@pytest.mark.skip
-def test_execute_a_backward_job_and_send_the_output():
-    pass
+def test_execute_a_backward_job(backward_job):
+    # NOTE: save activations
+    MICROBATCH_IDX = backward_job.input.metadata.microbatch_idx
+    PARTITION_IDX = backward_job.input.metadata.partition_idx
+    INPUT = backward_job.input.data
+
+    key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX)
+    output = torch.randn_like(INPUT)
+    SavedActivation.save_activations(key, output)
+
+    torch.autograd.backward(INPUT, output, retain_graph=True)
+    GROUND_GRADIENT = deepcopy(INPUT.grad)
+    INPUT.grad = None
 
+    grad_package = backward_job.compute()
 
-def test_execute_a_backward_job():
+    assert grad_package.data.shape == INPUT.shape
+    assert torch.equal(grad_package.data, GROUND_GRADIENT)
+
+
+@pytest.mark.skip
+def test_execute_a_backward_job_and_send_the_output():
     pass
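
Note: the heart of the change is torch.autograd.backward(tensors, grad_tensors). Instead of backpropagating from a scalar loss, the backward job re-enters autograd with the gradient received from the downstream pipeline stage. A minimal, self-contained sketch of that pattern follows; the two-tensor stage below is a hypothetical stand-in for a real pipeline partition, not pipegoose code.

    import torch

    # Hypothetical single pipeline stage: an input and a local weight.
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 8, requires_grad=True)
    activation = x @ w  # forward output a SavedActivation-style store would keep

    # The downstream stage later sends back dLoss/dActivation.
    grad_from_next_stage = torch.randn_like(activation)

    # Re-enter autograd with the received gradient instead of loss.backward().
    if activation.requires_grad:
        with torch.enable_grad():
            torch.autograd.backward(activation, grad_from_next_stage)

    print(x.grad.shape)  # torch.Size([4, 8]): gradient to send upstream
    print(w.grad.shape)  # torch.Size([8, 8]): gradient for the local optimizer

This is also why the test passes retain_graph=True: the same graph is walked twice, once to compute GROUND_GRADIENT directly and once inside backward_job.compute().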
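The patch also leans on the SavedActivation interface from pipegoose.nn.pipeline_parallel2.queue. Purely as a sketch of the shape its call sites imply (get_key / save_activations / get_saved_activations), here is a hypothetical keyed store; the actual pipegoose implementation may differ.

    from typing import Dict, Tuple
    import torch

    class ActivationStore:
        """Hypothetical stand-in for pipegoose's SavedActivation."""

        _activations: Dict[Tuple[int, int], torch.Tensor] = {}

        @staticmethod
        def get_key(microbatch_idx: int, partition_idx: int) -> Tuple[int, int]:
            # one entry per (microbatch, pipeline partition) pair
            return (microbatch_idx, partition_idx)

        @staticmethod
        def save_activations(key: Tuple[int, int], activations: torch.Tensor) -> None:
            ActivationStore._activations[key] = activations

        @staticmethod
        def get_saved_activations(key: Tuple[int, int]) -> torch.Tensor:
            # pop so the memory is released once the backward job consumes it
            return ActivationStore._activations.pop(key)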