add calculating gradients in backward job
xrsrke committed Oct 1, 2023
1 parent ca3be06 commit 10eeda0
Showing 2 changed files with 33 additions and 16 deletions.
21 changes: 12 additions & 9 deletions pipegoose/nn/pipeline_parallel2/_job/backward.py
@@ -3,6 +3,7 @@
 from pipegoose.nn.pipeline_parallel2._job.callback import Callback
 from pipegoose.nn.pipeline_parallel2._job.job import Job
 from pipegoose.nn.pipeline_parallel2._package import Package
+from pipegoose.nn.pipeline_parallel2.queue import SavedActivation
 
 
 class CreateBackwardOutputPackageCallback(Callback):
@@ -26,15 +27,17 @@ class BackwardJob(Job):
"""Do backward pass."""

def run_compute(self) -> torch.Tensor:
# key = self.job.key
# activations = get_saved_activations(key)
# print("doing backward job")
# return self.function(self.input.data)

# grad_output = self.input.data
microbatch_idx = self.input.metadata.microbatch_idx
partition_idx = self.input.metadata.partition_idx
key = SavedActivation.get_key(microbatch_idx, partition_idx)
outputs = SavedActivation.get_saved_activations(key)
inputs = self.input.data

# if activations.requires_grad:
# with torch.enable_grad():
# torch.autograd.backward(activations, grad_output)
if inputs.requires_grad:
with torch.enable_grad():
torch.autograd.backward(inputs, outputs)

# return activations.grad
print("doing backward job")
return self.function(self.input.data)
return inputs.grad
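
For reference, a standalone sketch of the autograd call run_compute now makes, in plain PyTorch. The dict-backed store and the hard-coded key stand in for SavedActivation; they are purely illustrative, not pipegoose's actual implementation.

import torch

saved_activations = {}                              # illustrative stand-in for SavedActivation's storage
key = (0, 0)                                        # (microbatch_idx, partition_idx)

inputs = torch.randn(2, 4, requires_grad=True)      # plays the role of self.input.data
saved_activations[key] = torch.randn_like(inputs)   # gradient-shaped tensor saved under the key

outputs = saved_activations[key]
if inputs.requires_grad:
    with torch.enable_grad():
        torch.autograd.backward(inputs, outputs)    # same call shape as run_compute

# for a leaf tensor, backward() simply accumulates the supplied gradient into .grad,
# which is what the updated test asserts against
assert torch.equal(inputs.grad, outputs)
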
28 changes: 21 additions & 7 deletions tests/nn/pipeline_parallel_2/job/test_backward.py
@@ -10,7 +10,7 @@
 from pipegoose.nn.pipeline_parallel2._job.backward import BackwardJob
 from pipegoose.nn.pipeline_parallel2._job.creator import schedule_backward_job
 from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
-from pipegoose.nn.pipeline_parallel2.queue import JobQueue
+from pipegoose.nn.pipeline_parallel2.queue import JobQueue, SavedActivation
 from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
 from pipegoose.testing.utils import init_parallel_context, spawn

@@ -47,8 +47,6 @@ def run_create_a_backward_job_if_a_tensor_do_backprop(

     dist.barrier()
 
-    # NOTE: both the forward job and backward job of the same package
-    # execute on the same node
     if rank == DST:
         # NOTE: we enqueue the backward job in the destination rank
         ORIG_FORWARD_PACKAGE = deepcopy(forward_package)
@@ -94,10 +92,26 @@ def test_create_a_backward_job_if_a_tensor_do_backprop_in_the_same_node(request,
     )
 
 
-@pytest.mark.skip
-def test_execute_a_backward_job_and_send_the_output():
-    pass
+def test_execute_a_backward_job(backward_job):
+    # NOTE: save activations
+    MICROBATCH_IDX = backward_job.input.metadata.microbatch_idx
+    PARTITION_IDX = backward_job.input.metadata.partition_idx
+    INPUT = backward_job.input.data
+
+    key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX)
+    output = torch.randn_like(INPUT)
+    SavedActivation.save_activations(key, output)
+
+    torch.autograd.backward(INPUT, output, retain_graph=True)
+    GROUND_GRADIENT = deepcopy(INPUT.grad)
+    INPUT.grad = None
 
+    grad_package = backward_job.compute()
 
-def test_execute_a_backward_job():
+    assert grad_package.data.shape == INPUT.shape
+    assert torch.equal(grad_package.data, GROUND_GRADIENT)
+
+
+@pytest.mark.skip
+def test_execute_a_backward_job_and_send_the_output():
     pass
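
The new test backprops through the same tensor twice: once to build a reference gradient and once inside backward_job.compute(). That is why it passes retain_graph=True and clears INPUT.grad in between. A minimal plain-PyTorch sketch of that pattern, illustrative only and with no pipegoose involved:

import torch
from copy import deepcopy

x = torch.randn(3, requires_grad=True)
y = x * 2
grad_output = torch.randn_like(y)

# first backward: build the reference gradient; keep the graph alive for the second pass
torch.autograd.backward(y, grad_output, retain_graph=True)
expected_grad = deepcopy(x.grad)

# clear .grad so the second backward is not accumulated on top of the first
x.grad = None

# second backward: stands in for the gradient computed by backward_job.compute()
torch.autograd.backward(y, grad_output)
assert torch.equal(x.grad, expected_grad)
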
