[PROTOTYPE] generated batching rules for custom dispatcher ops #578

Open · wants to merge 4 commits into base: main
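
At a high level, this change makes vmap work over operators registered through functorch's custom_vjp helper by generating a batching rule from the op's forward decomposition, while backward still runs the user-supplied vjp. A minimal usage sketch, modeled on the new test at the bottom of this diff (the function names come from that test; the custom_vjp import path is an assumption):

import torch
from functorch import vmap
from functorch._src.custom_function import custom_vjp  # import path is an assumption

def my_sin_impl(args):
    x, = args
    # The forward decomposes into prim aten ops (sin), which already have batching rules.
    return x.sin(), x

def my_sin_vjp(args):
    grad_y, result, x = args
    # Deliberately 3 * cos rather than the true derivative, matching the test, so it is
    # observable that this custom vjp ran instead of autograd's decomposition.
    return (grad_y * 3 * x.cos(),)

def filter_fn(args):
    return args[0]

my_sin = custom_vjp('my_sin', filter_fn, my_sin_impl, my_sin_vjp)

x = torch.randn(4, 3, requires_grad=True)
y = vmap(my_sin)(x)   # forward goes through the generated FuncTorchBatched binding
y.sum().backward()    # backward runs my_sin_vjp
assert torch.allclose(x.grad, 3 * x.cos())
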
2 changes: 2 additions & 0 deletions functorch/_src/custom_function.py
@@ -8,6 +8,8 @@ def custom_vjp(name, filter_fn, fwd_fn, bwd_fn):
m.def_(f"{name}(Tensor[] args) -> Tensor[]")
m.impl(f"{name}", "CompositeImplicitAutograd", fwd_fn)

m.gen_vmap_binding(f"{name}")

m.def_(f"{name}_vjp(Tensor[] args) -> Tensor[]")
m.impl(f"{name}_vjp", "CompositeImplicitAutograd", bwd_fn)

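Both operators are registered with a (Tensor[] args) -> Tensor[] schema, so the forward and backward callables each take and return a flat sequence of tensors. A sketch of the implied calling convention, inferred from how the new test below unpacks its arguments (the ordering of the saved values is an assumption based on that test):

from typing import List, Tuple
import torch

def fwd_fn(args: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
    # args: the op's inputs. Return the visible output(s) followed by any tensors
    # the backward will need saved (here, x itself).
    x, = args
    return x.sin(), x

def bwd_fn(args: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
    # args: gradients of the filtered output(s), then everything fwd_fn returned.
    grad_y, result, x = args
    return (grad_y * x.cos(),)
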
111 changes: 104 additions & 7 deletions functorch/csrc/CustomFunction.cpp
@@ -1,4 +1,5 @@
#include <functorch/csrc/CustomFunction.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
@@ -200,6 +201,18 @@ variable_list GenericPythonBackward::apply(variable_list&& grads) {
return grad_inputs;
}

thread_local bool come_up_with_a_better_name = true;

struct SwitchGuard {
public:
SwitchGuard() {
come_up_with_a_better_name = false;
}
~SwitchGuard() {
come_up_with_a_better_name = true;
}
};

typedef TensorList (*custom_python_function_t)(TensorList);

using torch::autograd::compute_requires_grad;
@@ -226,12 +239,26 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack
grad_fn->num_inputs_ = tensors_.size();
}

auto typed_handle = op.typed<custom_function_t>();
std::vector<Tensor> _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return typed_handle.call(tensors_);
})();
auto result = std::move(_tmp);
std::vector<at::Tensor> result;
// When this is true, we:
// - run the forward pass
// - construct the autograd graph
// - return the result
// When this is false, we:
// - DONT run the forward pass
// - construct the autograd graph, using the (unwrapped) inputs and outputs from the fwd pass
// - DONT return the result
if (come_up_with_a_better_name) {
auto typed_handle = op.typed<custom_function_t>();
std::vector<Tensor> _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return typed_handle.call(tensors_);
})();
result = std::move(_tmp);
} else {
result = torch::jit::pop(stack).toTensorList().vec();
}

if (grad_fn) {
for (auto& tensor : result) {
// TODO: is this right?
@@ -248,9 +275,72 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack
grad_fn->saved_tensors_.push_back(torch::autograd::SavedVariable(tensor, !is_input));
}
}
torch::jit::push(stack, result);
if (come_up_with_a_better_name) {
torch::jit::push(stack, result);
}
}

void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKeySet ks, torch::jit::Stack* stack) {
// We basically simulate running the user's op in inference mode WITH the decomposition
// And then separately we create the autograd graph WITHOUT the decomposition.
// This allows us to decompose and "get batching rules for free",
// while still being able to run a user's custom backward function
// (which might be necessary for numeric stability).

auto tensors = torch::jit::pop(stack).toTensorList().vec();
auto typed_handle = op.typed<custom_function_t>();

// Step (1) = run the forward using the decomposition
std::vector<Tensor> _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
// The tensor arguments should all be batched tensors at this point,
// so what will happen is we:
// (a) Skip the autograd key and go straight to the backend
// (potentially running other stuff like AMP along the way)
// (b) Enter the user's python kernel, which runs a bunch of "prim" aten ops
// (c) Those prim ops each enter the dispatcher, and we'll hit each of their
// batching rule kernels (because our inputs are *still* BatchedTensors)
constexpr DispatchKeySet after_vmap_keyset = DispatchKeySet(
DispatchKeySet::FULL_AFTER,
c10::DispatchKey::FuncTorchBatched);
// See the comment in DynamicLayer.cpp
auto final_ks = after_vmap_keyset.remove(kDynamicLayerBackModeKey);
return typed_handle.redispatch(ks & final_ks, tensors);
})();
auto forward_result = std::move(_tmp);

// Step (2) = Create the autograd graph without the decomposition.
// Taking special care to "re-use" the same inputs/outputs in the autograd kernel
// that we got from the forward pass.
// This is really hacky - I'm hardcoding the boxed autograd kernel
// to know that when it's running in "don't run the forward pass" mode,
// it can assume that the arguments on the stack are <unwrapped_output, unwrapped_inputs...>
// from the forward pass.
auto unwrapped_args = std::vector<Tensor>();
for (const auto& a : tensors) {
TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a));
unwrapped_args.push_back(at::functorch::unsafeGetBatchedImpl(a)->value());
}
auto unwrapped_outs = std::vector<Tensor>();
for (const auto& a : forward_result) {
TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a));
unwrapped_outs.push_back(at::functorch::unsafeGetBatchedImpl(a)->value());
}
// relying on customFunctionBoxed to pop these off the stack.
torch::jit::push(stack, unwrapped_outs);
torch::jit::push(stack, unwrapped_args);
{
// When the guard is set, the autograd boxed fallback knows to:
// (a) add the vjp to the autograd graph
// (b) NOT run the forward pass
SwitchGuard guard;
customFunctionBoxed(op, stack);
}

torch::jit::push(stack, forward_result);
}


void initDispatchBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();

@@ -272,6 +362,13 @@ void initDispatchBindings(PyObject* module) {
torch::CppFunction::makeFromBoxedFunction<&customFunctionBoxed>())
);
}, "", py::arg("name"), py::arg("dispatch"))
.def("gen_vmap_binding", [](py::object self, const char* name) {
self.cast<torch::Library&>().impl(
name,
dispatch_str("FuncTorchBatched",
torch::CppFunction::makeFromBoxedFunction<&generatedCustomBatchingRule>())
);
}, "", py::arg("name"))
.def("fallback_fallthrough", [](py::object self, const char* dispatch) {
self.cast<torch::Library&>().fallback(
dispatch_str(dispatch, torch::CppFunction::makeFallthrough())
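
Putting the two pieces of this file together, the flow for a custom op under vmap looks roughly like the following Python-flavored pseudocode (a sketch only, not runnable: stack and dispatch-key manipulation exist only in the C++ boxed-kernel API, and the helper names simply mirror the C++ above):

def generatedCustomBatchingRule(op, ks, stack):
    tensors = stack.pop()  # BatchedTensor inputs

    # Step (1): run the forward via the CompositeImplicitAutograd decomposition,
    # redispatching past FuncTorchBatched so each prim op hits its own batching rule.
    keys = keys_after(FuncTorchBatched).remove(DynamicLayerBackMode)
    forward_result = op.redispatch(ks & keys, tensors)

    # Step (2): build the autograd graph for the custom op itself (no decomposition),
    # reusing the unwrapped inputs/outputs from step (1).
    unwrapped_ins = [unwrap_batched(t) for t in tensors]
    unwrapped_outs = [unwrap_batched(t) for t in forward_result]
    stack.push(unwrapped_outs)
    stack.push(unwrapped_ins)
    with SwitchGuard():                 # flip the thread-local flag
        customFunctionBoxed(op, stack)  # graph-only mode: skip the forward, push nothing

    # The caller of the batching rule sees the batched forward result.
    stack.push(forward_result)
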
13 changes: 12 additions & 1 deletion functorch/csrc/DynamicLayer.cpp
@@ -435,7 +435,18 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
}
#endif
if (dynamicLayerStack.size() == 0) {
sanityCheckStack(op, stack);
// total hack: for now, the logic I added to generate the vmap rule
// doesn't play well with DynamicLayer (only one layer of vmap works right now).
// Why? In the generated batching rule, I effectively want to treat it as a "composite kernel",
// and have it dispatch to the python-defined forward function. But:
// (1) I want to go there through the dispatcher so other functionalities can run (e.g. AMP).
[Inline review comment (Contributor Author)]: Actually, maybe I should just be treating this the same way that DynamicLayer already treats composite ops - just directly call into the composite function. That means that stuff like AMP will run on the base ops and not the composite ops, but maybe that's the right behavior (unless the user wants to write a custom "AMP rule" for their op)

[Inline review comment (Contributor Author)]: (ended up doing this, although there's another issue with disabling autograd that I left a comment about)

// That means I need to re-enter the dispatcher, calling the *same* operator.
// (2) I DONT want to unwrap the batched tensors, since when we decompose I want to run the batching rule
// on the base ops
// (3) I can't use dispatcher::call(), since given the above two constraints I'll infinite loop
// (I can't add the batched key to the TLS exclude set)
// (4) I have to ::redispatch() then. But that plays poorly with dynamicLayer.
//sanityCheckStack(op, stack);
c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
op.callBoxed(stack);
return;
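
In other words, for ops that go through the generated rule only a single level of vmap is expected to work for now; nesting vmap is the unsupported case. A small illustration of what "one layer of vmap" means (torch.sin stands in for any op with a real, hand-written batching rule; my_sin refers to the custom op from the test below):

import torch
from functorch import vmap

x = torch.randn(2, 3)
xs = torch.randn(2, 3, 4)

vmap(torch.sin)(x)               # one layer of vmap: what the generated rule supports for my_sin
out = vmap(vmap(torch.sin))(xs)  # nested vmap: fine for ops with real batching rules,
                                 # but not expected to work yet through the generated rule
assert torch.allclose(out, xs.sin())
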
40 changes: 40 additions & 0 deletions test/test_eager_transforms.py
@@ -1905,6 +1905,46 @@ def filter_fn(args):

assert torch.allclose(x.grad, 3 * x.cos())

@onlyCPU
def test_generated_batching_rule_for_custom_op(self, device):
called_impl = False
called_vjp = False

def my_sin_impl(args):
x, = args
nonlocal called_impl
called_impl = True
return x.sin(), x

def my_sin_vjp(args):
grad_y, result, x = args
nonlocal called_vjp
called_vjp = True
return (grad_y * 3 * x.cos(),)

def filter_fn(args):
return args[0]

my_sin = custom_vjp('my_sin2', filter_fn, my_sin_impl, my_sin_vjp)

x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True, device=device)
x_copy = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True, device=device)

vmap_my_sin = vmap(my_sin)
y = vmap_my_sin(x)
self.assertTrue(called_impl)

y.sum().backward()
self.assertTrue(called_vjp)

assert torch.allclose(x.grad, 3 * x.cos())

y_copy = my_sin(x_copy)
y_copy.sum().backward()
assert torch.allclose(y_copy, y)
assert torch.allclose(x_copy.grad, x.grad)


class TestComposability(TestCase):
def test_grad_grad(self, device):