[#42] pp data comm func (WIP) #47

Merged · 10 commits · Sep 30, 2022
28 changes: 27 additions & 1 deletion oslo/torch/nn/parallel/pipeline_parallel/_buffers.py
@@ -1,3 +1,29 @@
+from oslo.torch.nn.parallel.pipeline_parallel._sync import (
+    register_location_for_forward_counter,
+)
+
+
+# original forward dictionary
+_ORIGINAL_FORWARDS = dict()
+
+# module device locations
+_MODULE_DEVICE_LOCATIONS = dict()
+
+
+def register_original_forward_function(location, func, device):
+    _ORIGINAL_FORWARDS[location] = func
+    _MODULE_DEVICE_LOCATIONS[location] = device
+    register_location_for_forward_counter(location)
+
+
+def get_original_forward_function(location):
+    return _ORIGINAL_FORWARDS[location]
+
+
+def get_module_device_location(location):
+    return _MODULE_DEVICE_LOCATIONS[location]
+
+
 # Activations
 _ACTIVATIONS = dict()

@@ -7,4 +33,4 @@ def save_activation(key, activation):


 def pop_activation(key):
-    return _ACTIVATIONS.pop(key)
+    return _ACTIVATIONS.pop(key, [])  # TODO; okay?
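For reference, a minimal usage sketch of the registry helpers added above. It is illustrative only and not part of the diff; the location string "decoder.block_0", the toy nn.Linear module, and the CPU device are made-up placeholders.

import torch
import torch.nn as nn

from oslo.torch.nn.parallel.pipeline_parallel._buffers import (
    register_original_forward_function,
    get_original_forward_function,
    get_module_device_location,
    save_activation,
    pop_activation,
)

# stand-in module; any nn.Module's forward can be registered
module = nn.Linear(4, 4)

# register the original forward once, keyed by the module's location string
register_original_forward_function("decoder.block_0", module.forward, torch.device("cpu"))

# later lookups by the same location key, e.g. when serving a remote forward request
forward_fn = get_original_forward_function("decoder.block_0")
device = get_module_device_location("decoder.block_0")
out = forward_fn(torch.randn(2, 4, device=device))

# activations are stashed and popped by a unique job key; with this PR,
# pop_activation returns [] instead of raising when the key is absent
save_activation("job-0", [out])
acts = pop_activation("job-0")    # -> [out]
empty = pop_activation("job-0")   # -> [] (key already consumed)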
99 changes: 55 additions & 44 deletions oslo/torch/nn/parallel/pipeline_parallel/_functional.py
@@ -2,91 +2,102 @@
 from torch.cuda.amp import custom_fwd, custom_bwd
 from torch.distributed import rpc

-from oslo.torch.nn.parallel.pipeline_parallel._buffers import _ACTIVATIONS
-
-_FORWARD_MARKER = set()
-
-_LOCAL_BACKWARD_DONE = False
-
-_NUM_BACKWARD_DONE = 0
-
-
-def add_forward_marker(mark):
-    _FORWARD_MARKER.add(mark)
-
-
-def remove_forward_marker(mark):
-    _FORWARD_MARKER.remove(mark)
-
-
-def len_forward_marker():
-    return len(_FORWARD_MARKER)
-
-
-def increase_num_backward_done():
-    global _NUM_BACKWARD_DONE
-    _NUM_BACKWARD_DONE += 1
+from oslo.torch.nn.parallel.pipeline_parallel._buffers import (
+    get_original_forward_function,
+    save_activation,
+    pop_activation,
+)
+from oslo.torch.nn.parallel.pipeline_parallel._sync import (
+    register_job_requires_backward,
+    notify_backward_job_done,
+)
+from oslo.torch.nn.parallel.pipeline_parallel._messages import (
+    pack_tensor_stub,
+    unpack_tensor_stub,
+)
+
+
+def remote_module_forward(
+    caller,
+    location,
+    unique_key,
+    args_stub,
+    kwargs_stub,
+    requires_redirection,
+    is_training,
+    is_grad_enabled,
+    *tensors
+):
+    if requires_redirection and is_training and is_grad_enabled:
+        # prepare backward redirection to caller
+        tensors = apply_backward_redirection(
+            caller,
+            unique_key,
+            *tensors,
+        )
+
+    (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors)

-def get_num_backward_done():
-    global _NUM_BACKWARD_DONE
-    return _NUM_BACKWARD_DONE
+    forward_fn = get_original_forward_function(location)
+    with torch.set_grad_enabled(is_grad_enabled):
+        result = forward_fn(*args, **kwargs)

+    result_stub, tensors = pack_tensor_stub(result, [])
+    need_activation_save = (
+        any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled
+    )
+    if need_activation_save:
+        save_activation(unique_key, tensors)

-def reset_num_backward_done():
-    global _NUM_BACKWARD_DONE, _LOCAL_BACKWARD_DONE
-    _NUM_BACKWARD_DONE = 0
-    _LOCAL_BACKWARD_DONE = False
+    return result_stub, tensors, need_activation_save


 def launch_remote_backward(unique_key, *grad_outputs):
-    activation = _ACTIVATIONS.pop(unique_key)
+    activation = pop_activation(unique_key)

     # TODO; some output contains tuple of tuple..
     #       better way to deal with this?
     new_act = []
     new_grad = []
     for act, grad in zip(activation, grad_outputs):
         if act is not None and grad is not None and act.requires_grad:
             new_act.append(act)
             new_grad.append(grad)

-    torch.autograd.backward(tuple(new_act), tuple(new_grad))
-    remove_forward_marker(unique_key)
+    if len(new_act) > 0 and len(new_grad) > 0:
+        torch.autograd.backward(tuple(new_act), tuple(new_grad))
+    notify_backward_job_done(unique_key)


-# TODO; why
+# why
 #   forward(ctx, req, *args, **kwargs)
 #   ...
 #   return args, kwargs
 #   does not work???
-#  ->
+#
+# because that is the design of Pytorch
+# see: github.com/pytorch/pytorch/issues/16940
+#
 # based on https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/pipe/rpc.py#L53
 class _PipeBackwardRedirection(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, to, unique_key, *args):
         ctx.to = to
         ctx.unique_key = unique_key
-        ctx.num_nones = 2 + len(args)  # counting req
+        ctx.num_nones = 2 + len(args)

-        # mark
-        # TODO; do this before remote_forward
-        rpc.rpc_sync(to=to, func=add_forward_marker, args=(unique_key,))
+        # TODO; can we do this before remote_forward
+        #       without rpc call?
+        rpc.rpc_sync(to=to, func=register_job_requires_backward, args=(unique_key,))

         return args

     @staticmethod
     @custom_bwd
     @rpc.functions.async_execution
     def backward(ctx, *grad_outputs):
         to = ctx.to
         unique_key = ctx.unique_key

         # print(f'backward: {to=}, {unique_key=}')

         rpc.rpc_async(
             to=to,
             func=launch_remote_backward,
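Note on apply_backward_redirection: it is called inside remote_module_forward above but defined outside the visible part of this diff. Based on the fairscale reference in the comment and the signature of _PipeBackwardRedirection.forward, it presumably reduces to something like the sketch below; this is an assumption for readability, not code taken from the PR.

def apply_backward_redirection(to, unique_key, *tensors):
    # assumed helper: torch.autograd.Function.apply cannot take keyword
    # arguments, which is why destination, key and tensors travel as a
    # flat positional list (see the comment above _PipeBackwardRedirection)
    return _PipeBackwardRedirection.apply(to, unique_key, *tensors)

On the backward pass, the Function fires rpc_async back to the recorded destination (ctx.to), so gradients computed on the caller side are replayed via launch_remote_backward against the activations previously saved on the remote worker.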