diff --git a/functorch/_src/custom_function.py b/functorch/_src/custom_function.py
index 028a246c6..9192e03ee 100644
--- a/functorch/_src/custom_function.py
+++ b/functorch/_src/custom_function.py
@@ -8,6 +8,8 @@ def custom_vjp(name, filter_fn, fwd_fn, bwd_fn):
     m.def_(f"{name}(Tensor[] args) -> Tensor[]")
     m.impl(f"{name}", "CompositeImplicitAutograd", fwd_fn)
 
+    m.gen_vmap_binding(f"{name}")
+
     m.def_(f"{name}_vjp(Tensor[] args) -> Tensor[]")
     m.impl(f"{name}_vjp", "CompositeImplicitAutograd", bwd_fn)
 
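For orientation, here is a minimal usage sketch of the Python-side API once `gen_vmap_binding` registers the generated batching rule. It mirrors the test added at the bottom of this patch; the operator name `'my_sin_example'` is arbitrary, the import path for `custom_vjp` is an assumption based on where this patch defines it, and the snippet assumes a functorch build that includes these changes:

```python
import torch
from functorch import vmap
# Assumption: custom_vjp imported from the module this patch edits.
from functorch._src.custom_function import custom_vjp

def my_sin_impl(args):
    # Forward pass written in terms of existing aten ops.
    # Returns the output plus any tensors that backward needs.
    x, = args
    return x.sin(), x

def my_sin_vjp(args):
    # Custom backward. The deliberate factor of 3 (copied from the test)
    # makes it observable that this function, and not a derived gradient,
    # produced x.grad; the true derivative would be grad_y * x.cos().
    grad_y, result, x = args
    return (grad_y * 3 * x.cos(),)

def filter_fn(args):
    # Picks which forward outputs are user-visible; the rest are saved.
    return args[0]

# Registers {name} and {name}_vjp plus the generated batching rule,
# so the name must not already be registered in this process.
my_sin = custom_vjp('my_sin_example', filter_fn, my_sin_impl, my_sin_vjp)

x = torch.randn(2, 3, requires_grad=True)
# vmap hits the generated batching rule: the forward runs on BatchedTensors
# (reusing each prim op's batching rule), while the autograd graph is
# recorded on the unwrapped tensors with my_sin_vjp as the backward.
y = vmap(my_sin)(x)
y.sum().backward()
print(torch.allclose(x.grad, 3 * x.cos()))  # True if the custom vjp ran
```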
diff --git a/functorch/csrc/CustomFunction.cpp b/functorch/csrc/CustomFunction.cpp
index 290689781..3debf6f99 100644
--- a/functorch/csrc/CustomFunction.cpp
+++ b/functorch/csrc/CustomFunction.cpp
@@ -1,4 +1,5 @@
 #include
+#include <functorch/csrc/BatchedTensorImpl.h>
 #include
 #include
 #include
@@ -207,7 +208,7 @@ using torch::autograd::collect_next_edges;
 using torch::autograd::deleteNode;
 using torch::autograd::flatten_tensor_args;
 
-void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool get_output_by_running_forward_pass) {
   auto tensors = torch::jit::pop(stack).toTensorList().vec();
   auto tensors_ = unpack(tensors, "tensors", 0);
   auto _any_requires_grad = compute_requires_grad(tensors);
@@ -226,12 +227,27 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack
     grad_fn->num_inputs_ = tensors_.size();
   }
 
-  auto typed_handle = op.typed<custom_function_t>();
-  std::vector<at::Tensor> _tmp = ([&]() {
-    at::AutoDispatchBelowADInplaceOrView guard;
-    return typed_handle.call(tensors_);
-  })();
-  auto result = std::move(_tmp);
+  std::vector<at::Tensor> result;
+  // When this is true, we:
+  // - run the forward pass
+  // - construct the autograd graph
+  // - return the result
+  // When this is false, we:
+  // - DONT run the forward pass (and instead, assume that the output from the forward pass
+  //   was already pushed on the stack)
+  // - construct the autograd graph, using the (unwrapped) inputs and outputs from the fwd pass
+  // - DONT return the result
+  if (get_output_by_running_forward_pass) {
+    auto typed_handle = op.typed<custom_function_t>();
+    std::vector<at::Tensor> _tmp = ([&]() {
+      at::AutoDispatchBelowADInplaceOrView guard;
+      return typed_handle.call(tensors_);
+    })();
+    result = std::move(_tmp);
+  } else {
+    result = torch::jit::pop(stack).toTensorList().vec();
+  }
+
   if (grad_fn) {
     for (auto& tensor : result) {
       // TODO: is this right?
@@ -248,9 +264,77 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack
       grad_fn->saved_tensors_.push_back(torch::autograd::SavedVariable(tensor, !is_input));
     }
   }
-  torch::jit::push(stack, result);
+  // if we computed the output ourselves, return it.
+  if (get_output_by_running_forward_pass) {
+    torch::jit::push(stack, result);
+  }
+}
+
+void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  customFunctionBoxed(op, stack, /*get_output_by_running_forward_pass=*/true);
+}
+
+
+void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKeySet ks, torch::jit::Stack* stack) {
+  // We basically simulate running the user's op in inference mode WITH the decomposition,
+  // and then separately we create the autograd graph WITHOUT the decomposition.
+  // This allows us to decompose and "get batching rules for free",
+  // while still being able to run a user's custom backward function
+  // (which might be necessary for numeric stability).
+
+  auto tensors = torch::jit::pop(stack).toTensorList().vec();
+  auto typed_handle = op.typed<custom_function_t>();
+
+  // Step (1) = run the forward using the decomposition
+  std::vector<at::Tensor> _tmp = ([&]() {
+    // NOTE: I don't think this composes with DynamicLayer very well.
+    // I want to fully turn off autograd when I call the custom python operator,
+    // but when vmap() is active, DynamicLayer will overwrite TLS and (potentially) run autograd anyway.
+    // TODO: think more about this.
+    at::AutoDispatchBelowADInplaceOrView guard;
+    // The tensor arguments should all be batched tensors at this point,
+    // so what will happen is we:
+    // (a) Skip the autograd key and go straight to the backend
+    //     (potentially running other stuff like AMP along the way)
+    // (b) Enter the user's python kernel, which runs a bunch of "prim" aten ops
+    // (c) Those prim ops each enter the dispatcher, and we'll hit each of their
+    //     batching rule kernels (because our inputs are *still* BatchedTensors)
+    // TODO better idiom for this - I just want to go straight to the python impl
+    return typed_handle.redispatch(c10::DispatchKeySet(c10::DispatchKey::CPU), tensors);
+  })();
+  auto forward_result = std::move(_tmp);
+
+  // Step (2) = Create the autograd graph without the decomposition,
+  // taking special care to "re-use" the same inputs/outputs in the autograd kernel
+  // that we got from the forward pass.
+  // This is really hacky - I'm hardcoding the boxed autograd kernel
+  // to know that when it's running in "don't run the forward pass" mode,
+  // it can assume that the arguments on the stack are
+  // from the forward pass.
+  auto unwrapped_args = std::vector<at::Tensor>();
+  for (const auto& a : tensors) {
+    TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a));
+    unwrapped_args.push_back(at::functorch::unsafeGetBatchedImpl(a)->value());
+  }
+  auto unwrapped_outs = std::vector<at::Tensor>();
+  for (const auto& a : forward_result) {
+    TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a));
+    unwrapped_outs.push_back(at::functorch::unsafeGetBatchedImpl(a)->value());
+  }
+  // relying on the fact that customFunctionBoxed will pop these off the stack.
+  torch::jit::push(stack, unwrapped_outs);
+  torch::jit::push(stack, unwrapped_args);
+  {
+    // When get_output_by_running_forward_pass is false, the autograd boxed fallback knows to:
+    // (a) add the vjp to the autograd graph
+    // (b) NOT run the forward pass
+    customFunctionBoxed(op, stack, /*get_output_by_running_forward_pass=*/false);
+  }
+
+  torch::jit::push(stack, forward_result);
 }
+
 
 void initDispatchBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 
@@ -272,6 +356,13 @@ void initDispatchBindings(PyObject* module) {
           torch::CppFunction::makeFromBoxedFunction<&customFunctionBoxed>())
       );
     }, "", py::arg("name"), py::arg("dispatch"))
+    .def("gen_vmap_binding", [](py::object self, const char* name) {
+      self.cast<torch::Library&>().impl(
+          name,
+          dispatch_str("FuncTorchBatched",
+            torch::CppFunction::makeFromBoxedFunction<&generatedCustomBatchingRule>())
+      );
+    }, "", py::arg("name"))
     .def("fallback_fallthrough", [](py::object self, const char* dispatch) {
       self.cast<torch::Library&>().fallback(
           dispatch_str(dispatch, torch::CppFunction::makeFallthrough())
diff --git a/test/test_eager_transforms.py b/test/test_eager_transforms.py
index 11666b056..ca3cb0a83 100644
--- a/test/test_eager_transforms.py
+++ b/test/test_eager_transforms.py
@@ -1905,6 +1905,57 @@ def filter_fn(args):
         assert torch.allclose(x.grad, 3 * x.cos())
 
+    @onlyCPU
+    def test_generated_batching_rule_for_custom_op(self, device):
+        called_impl = False
+        called_vjp = False
+        called_impl_with_batched_args = None
+        called_vjp_with_batched_args = None
+
+        def my_sin_impl(args):
+            x, = args
+            nonlocal called_impl
+            nonlocal called_impl_with_batched_args
+            called_impl = True
+            called_impl_with_batched_args = functorch._C.is_batchedtensor(x)
+            return x.sin(), x
+
+        def my_sin_vjp(args):
+            grad_y, result, x = args
+            nonlocal called_vjp
+            nonlocal called_vjp_with_batched_args
+            called_vjp = True
+            called_vjp_with_batched_args = all(functorch._C.is_batchedtensor(a) for a in [grad_y, result, x])
+            return (grad_y * 3 * x.cos(),)
+
+        def filter_fn(args):
+            return args[0]
+
+        my_sin = custom_vjp('my_sin2', filter_fn, my_sin_impl, my_sin_vjp)
+
+        x = torch.tensor([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]], requires_grad=True, device=device)
+        x_copy = torch.tensor([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]], requires_grad=True, device=device)
+
+        vmap_my_sin = vmap(my_sin)
+        y = vmap_my_sin(x)
+        self.assertTrue(called_impl)
+        # We expect to run the custom forward with batched tensors, so when
+        # it decomposes into base ops we run the batching rule on each base op.
+        self.assertTrue(called_impl_with_batched_args)
+
+        y.sum().backward()
+        self.assertTrue(called_vjp)
+        # We expect to run the custom backward with non-batched tensors,
+        # because we didn't explicitly vmap over the backward() call.
+        self.assertFalse(called_vjp_with_batched_args)
+
+        assert torch.allclose(x.grad, 3 * x.cos())
+
+        y_copy = my_sin(x_copy)
+        y_copy.sum().backward()
+        assert torch.allclose(y_copy, y)
+        assert torch.allclose(x_copy.grad, x.grad)
+
 
 class TestComposability(TestCase):
 
     def test_grad_grad(self, device):