From 7eeb681ccca2dd49329131459a33373fdc17e95f Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 8 Mar 2022 16:26:08 -0800 Subject: [PATCH 1/4] [PROTOTYPE] generated batching rules for custom dispatcher ops --- functorch/_src/custom_function.py | 2 + functorch/csrc/CustomFunction.cpp | 111 ++++++++++++++++++++++++++++-- functorch/csrc/DynamicLayer.cpp | 13 +++- test/test_eager_transforms.py | 40 +++++++++++ 4 files changed, 158 insertions(+), 8 deletions(-) diff --git a/functorch/_src/custom_function.py b/functorch/_src/custom_function.py index 028a246c6..9192e03ee 100644 --- a/functorch/_src/custom_function.py +++ b/functorch/_src/custom_function.py @@ -8,6 +8,8 @@ def custom_vjp(name, filter_fn, fwd_fn, bwd_fn): m.def_(f"{name}(Tensor[] args) -> Tensor[]") m.impl(f"{name}", "CompositeImplicitAutograd", fwd_fn) + m.gen_vmap_binding(f"{name}") + m.def_(f"{name}_vjp(Tensor[] args) -> Tensor[]") m.impl(f"{name}_vjp", "CompositeImplicitAutograd", bwd_fn) diff --git a/functorch/csrc/CustomFunction.cpp b/functorch/csrc/CustomFunction.cpp index 290689781..05f0f3105 100644 --- a/functorch/csrc/CustomFunction.cpp +++ b/functorch/csrc/CustomFunction.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -200,6 +201,18 @@ variable_list GenericPythonBackward::apply(variable_list&& grads) { return grad_inputs; } +thread_local bool come_up_with_a_better_name = true; + +struct SwitchGuard { + public: + SwitchGuard() { + come_up_with_a_better_name = false; + } + ~SwitchGuard() { + come_up_with_a_better_name = true; + } +}; + typedef TensorList (*custom_python_function_t)(TensorList); using torch::autograd::compute_requires_grad; @@ -226,12 +239,26 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack grad_fn->num_inputs_ = tensors_.size(); } - auto typed_handle = op.typed(); - std::vector _tmp = ([&]() { - at::AutoDispatchBelowADInplaceOrView guard; - return typed_handle.call(tensors_); - })(); - auto result = std::move(_tmp); + std::vector result; + // When this is true, we: + // - run the forward pass + // - construct the autograd graph + // - return the result + // When this is false, we: + // - DONT run the forward pass + // - construct the autograd graph, using the (unwrapped) inputs and outputs from the fwd pass + // - DONT return the result + if (come_up_with_a_better_name) { + auto typed_handle = op.typed(); + std::vector _tmp = ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return typed_handle.call(tensors_); + })(); + result = std::move(_tmp); + } else { + result = torch::jit::pop(stack).toTensorList().vec(); + } + if (grad_fn) { for (auto& tensor : result) { // TODO: is this right? @@ -248,9 +275,72 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack grad_fn->saved_tensors_.push_back(torch::autograd::SavedVariable(tensor, !is_input)); } } - torch::jit::push(stack, result); + if (come_up_with_a_better_name) { + torch::jit::push(stack, result); + } +} + +void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKeySet ks, torch::jit::Stack* stack) { + // We basically simulate running the user's op in inference mode WITH the decomposition + // And then separately we create the autograd graph WITHOUT the decomposition. + // This allows us to decompose and "get batching rules for free", + // while still being able to run a user's custom backward function + // (which might be necessary for numeric stability). 
+ + auto tensors = torch::jit::pop(stack).toTensorList().vec(); + auto typed_handle = op.typed(); + + // Step (1) = run the forward using the decomposition + std::vector _tmp = ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + // The tensor arguments should all be batched tensors at this point, + // so what will happen is we: + // (a) Skip the autograd key and go straight to the backend + // (potentially running other stuff like AMP along the way) + // (b) Enter the user's python kernel, which runs a bunch of "prim" aten ops + // (c) Those prim ops each enter the dispatcher, and we'll hit each of their + // batching rule kernels (because our inputs are *still* BatchedTensors) + constexpr DispatchKeySet after_vmap_keyset = DispatchKeySet( + DispatchKeySet::FULL_AFTER, + c10::DispatchKey::FuncTorchBatched); + // See the comment in DynamicLayer.cpp + auto final_ks = after_vmap_keyset.remove(kDynamicLayerBackModeKey); + return typed_handle.redispatch(ks & final_ks, tensors); + })(); + auto forward_result = std::move(_tmp); + + // Step (2) = Create the autograd graph without the decomposition. + // Taking special care to "re-use" the same inputs/outputs in the autograd kernel + // that we got from the forward pass. + // This is really hacky - I'm hardcoding the boxed autograd kernel + // to know that when it's running in "don't run the forward pass" mode, + // it can assume that the arguments on the stack are + // from the forward pass. + auto unwrapped_args = std::vector(); + for (const auto& a : tensors) { + TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a)); + unwrapped_args.push_back(at::functorch::unsafeGetBatchedImpl(a)->value()); + } + auto unwrapped_outs = std::vector(); + for (const auto& a : forward_result) { + TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a)); + unwrapped_outs.push_back(at::functorch::unsafeGetBatchedImpl(a)->value()); + } + // relying on customFunctionBoxed will push these off the stack. + torch::jit::push(stack, unwrapped_outs); + torch::jit::push(stack, unwrapped_args); + { + // When the guard is set, the autograd boxed fallback knows to: + // (a) add the vjp to the autograd graph + // (b) NOT run the forward pass + SwitchGuard guard; + customFunctionBoxed(op, stack); + } + + torch::jit::push(stack, forward_result); } + void initDispatchBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -272,6 +362,13 @@ void initDispatchBindings(PyObject* module) { torch::CppFunction::makeFromBoxedFunction<&customFunctionBoxed>()) ); }, "", py::arg("name"), py::arg("dispatch")) + .def("gen_vmap_binding", [](py::object self, const char* name) { + self.cast().impl( + name, + dispatch_str("FuncTorchBatched", + torch::CppFunction::makeFromBoxedFunction<&generatedCustomBatchingRule>()) + ); + }, "", py::arg("name")) .def("fallback_fallthrough", [](py::object self, const char* dispatch) { self.cast().fallback( dispatch_str(dispatch, torch::CppFunction::makeFallthrough()) diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp index 9ed7ab9cd..689e04805 100644 --- a/functorch/csrc/DynamicLayer.cpp +++ b/functorch/csrc/DynamicLayer.cpp @@ -435,7 +435,18 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* } #endif if (dynamicLayerStack.size() == 0) { - sanityCheckStack(op, stack); + // total hack: for now, the logic I added to generate the vmap rule + // doesn't play well with DynamicLayer (only one layer of vmap works right now). + // Why? 
In the generated batching rule, I effectively want to treat it as a "composite kernel",
+    // and have it run the python-defined forward function. But:
+    // (1) I want to go there through the dispatcher so other functionalities can run (e.g. AMP).
+    // That means I need to re-enter the dispatcher, calling the *same* operator.
+    // (2) I DONT want to unwrap the batched tensors, since when we decompose I want to run the batching rule
+    // on the base ops
+    // (3) I can't use dispatcher::call(), since given the above two constraints I'll infinite loop
+    // (I can't add the batched key to the TLS exclude set)
+    // (4) I have to ::redispatch() then. But that plays poorly with dynamicLayer.
+    //sanityCheckStack(op, stack);
     c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
     op.callBoxed(stack);
     return;
diff --git a/test/test_eager_transforms.py b/test/test_eager_transforms.py
index 11666b056..2f696c414 100644
--- a/test/test_eager_transforms.py
+++ b/test/test_eager_transforms.py
@@ -1905,6 +1905,45 @@ def filter_fn(args):
         assert torch.allclose(x.grad, 3 * x.cos())
+
+    @onlyCPU
+    def test_generated_batching_rule_for_custom_op(self, device):
+        called_impl = False
+        called_vjp = False
+
+        def my_sin_impl(args):
+            x, = args
+            nonlocal called_impl
+            called_impl = True
+            return x.sin(), x
+
+        def my_sin_vjp(args):
+            grad_y, result, x = args
+            nonlocal called_vjp
+            called_vjp = True
+            return (grad_y * 3 * x.cos(),)
+
+        def filter_fn(args):
+            return args[0]
+
+        my_sin = custom_vjp('my_sin', filter_fn, my_sin_impl, my_sin_vjp)
+
+        x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True, device=device)
+        x_copy = x.clone()
+
+        vmap_my_sin = vmap(my_sin)
+        y = vmap_my_sin(x)
+        self.assertTrue(called_impl)
+
+        y.sum().backward()
+        self.assertTrue(called_vjp)
+
+        assert torch.allclose(x.grad, 3 * x.cos())
+
+        y_copy = my_sin(x_copy)
+        y_copy.sum().backward()
+        assert torch.allclose(y_copy, y)
+        assert torch.allclose(x_copy.grad, x)
+
 
 class TestComposability(TestCase):
     def test_grad_grad(self, device):

From 105a5d96097c2874ec13509dad409e3d39f7a8d3 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Tue, 8 Mar 2022 16:56:02 -0800
Subject: [PATCH 2/4] fix test

---
 test/test_eager_transforms.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/test_eager_transforms.py b/test/test_eager_transforms.py
index 2f696c414..d53cfa04e 100644
--- a/test/test_eager_transforms.py
+++ b/test/test_eager_transforms.py
@@ -1926,10 +1926,10 @@ def my_sin_vjp(args):
         def filter_fn(args):
             return args[0]
 
-        my_sin = custom_vjp('my_sin', filter_fn, my_sin_impl, my_sin_vjp)
+        my_sin = custom_vjp('my_sin2', filter_fn, my_sin_impl, my_sin_vjp)
 
         x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True, device=device)
-        x_copy = x.clone()
+        x_copy = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True, device=device)
 
         vmap_my_sin = vmap(my_sin)
         y = vmap_my_sin(x)
         self.assertTrue(called_impl)
@@ -1943,7 +1943,7 @@ def filter_fn(args):
         y_copy = my_sin(x_copy)
         y_copy.sum().backward()
         assert torch.allclose(y_copy, y)
-        assert torch.allclose(x_copy.grad, x)
+        assert torch.allclose(x_copy.grad, x.grad)
 
 
 class TestComposability(TestCase):

From daa0cc777f656fab762f9a41575fe34facbefecf Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Tue, 8 Mar 2022 21:05:49 -0800
Subject: [PATCH 3/4] remove unnecessary tls

---
 functorch/csrc/CustomFunction.cpp | 32 +++++++++++++------------------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/functorch/csrc/CustomFunction.cpp b/functorch/csrc/CustomFunction.cpp
index 05f0f3105..1992ab9e3 100644 --- a/functorch/csrc/CustomFunction.cpp +++ b/functorch/csrc/CustomFunction.cpp @@ -201,18 +201,6 @@ variable_list GenericPythonBackward::apply(variable_list&& grads) { return grad_inputs; } -thread_local bool come_up_with_a_better_name = true; - -struct SwitchGuard { - public: - SwitchGuard() { - come_up_with_a_better_name = false; - } - ~SwitchGuard() { - come_up_with_a_better_name = true; - } -}; - typedef TensorList (*custom_python_function_t)(TensorList); using torch::autograd::compute_requires_grad; @@ -220,7 +208,7 @@ using torch::autograd::collect_next_edges; using torch::autograd::deleteNode; using torch::autograd::flatten_tensor_args; -void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) { +void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool get_output_by_running_forward_pass) { auto tensors = torch::jit::pop(stack).toTensorList().vec(); auto tensors_ = unpack(tensors, "tensors", 0); auto _any_requires_grad = compute_requires_grad(tensors); @@ -245,10 +233,11 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack // - construct the autograd graph // - return the result // When this is false, we: - // - DONT run the forward pass + // - DONT run the forward pass (and instead, assume that the output from the forward pass + // was already pushed on the stack) // - construct the autograd graph, using the (unwrapped) inputs and outputs from the fwd pass // - DONT return the result - if (come_up_with_a_better_name) { + if (get_output_by_running_forward_pass) { auto typed_handle = op.typed(); std::vector _tmp = ([&]() { at::AutoDispatchBelowADInplaceOrView guard; @@ -275,11 +264,17 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack grad_fn->saved_tensors_.push_back(torch::autograd::SavedVariable(tensor, !is_input)); } } - if (come_up_with_a_better_name) { + // if we computed the output ourselves, return it. + if (get_output_by_running_forward_pass) { torch::jit::push(stack, result); } } +void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + customFunctionBoxed(op, stack, /*get_output_by_running_forward_pass=*/true); +} + + void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKeySet ks, torch::jit::Stack* stack) { // We basically simulate running the user's op in inference mode WITH the decomposition // And then separately we create the autograd graph WITHOUT the decomposition. 
@@ -330,11 +325,10 @@ void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKey
   torch::jit::push(stack, unwrapped_outs);
   torch::jit::push(stack, unwrapped_args);
   {
-    // When the guard is set, the autograd boxed fallback knows to:
+    // When get_output_by_running_forward_pass is false, the autograd boxed fallback knows to:
     // (a) add the vjp to the autograd graph
     // (b) NOT run the forward pass
-    SwitchGuard guard;
-    customFunctionBoxed(op, stack);
+    customFunctionBoxed(op, stack, /*get_output_by_running_forward_pass=*/false);
   }
 
   torch::jit::push(stack, forward_result);

From 49f3cbd5bc78cf8028f77122d5be9449a2907f5a Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Wed, 9 Mar 2022 08:57:48 -0800
Subject: [PATCH 4/4] partially fix DynamicLayer logic, added a comment about it

---
 functorch/csrc/CustomFunction.cpp | 12 ++++++------
 functorch/csrc/DynamicLayer.cpp   | 13 +------------
 test/test_eager_transforms.py     | 17 ++++++++++++++---
 3 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/functorch/csrc/CustomFunction.cpp b/functorch/csrc/CustomFunction.cpp
index 1992ab9e3..3debf6f99 100644
--- a/functorch/csrc/CustomFunction.cpp
+++ b/functorch/csrc/CustomFunction.cpp
@@ -287,6 +287,10 @@ void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKey
 
   // Step (1) = run the forward using the decomposition
   std::vector _tmp = ([&]() {
+    // NOTE: I don't think this composes with DynamicLayer very well.
+    // I want to fully turn off autograd when I call the custom python operator,
+    // but when vmap() is active, DynamicLayer will overwrite TLS and (potentially) run autograd anyway.
+    // TODO: think more about this.
     at::AutoDispatchBelowADInplaceOrView guard;
     // The tensor arguments should all be batched tensors at this point,
     // so what will happen is we:
@@ -295,12 +299,8 @@ void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKey
     // (b) Enter the user's python kernel, which runs a bunch of "prim" aten ops
     // (c) Those prim ops each enter the dispatcher, and we'll hit each of their
     // batching rule kernels (because our inputs are *still* BatchedTensors)
-    constexpr DispatchKeySet after_vmap_keyset = DispatchKeySet(
-        DispatchKeySet::FULL_AFTER,
-        c10::DispatchKey::FuncTorchBatched);
-    // See the comment in DynamicLayer.cpp
-    auto final_ks = after_vmap_keyset.remove(kDynamicLayerBackModeKey);
-    return typed_handle.redispatch(ks & final_ks, tensors);
+    // TODO better idiom for this - I just want to go straight to the python impl
+    return typed_handle.redispatch(c10::DispatchKeySet(c10::DispatchKey::CPU), tensors);
   })();
   auto forward_result = std::move(_tmp);
 
diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp
index 689e04805..9ed7ab9cd 100644
--- a/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/csrc/DynamicLayer.cpp
@@ -435,18 +435,7 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
   }
 #endif
   if (dynamicLayerStack.size() == 0) {
-    // total hack: for now, the logic I added to generate the vmap rule
-    // doesn't play well with DynamicLayer (only one layer of vmap works right now).
-    // Why? In the generated batching rule, I effectively want to treat it as a "composite kernel",
-    // and have it run the python-defined forward function. But:
-    // (1) I want to go there through the dispatcher so other functionalities can run (e.g. AMP).
-    // That means I need to re-enter the dispatcher, calling the *same* operator.
-    // (2) I DONT want to unwrap the batched tensors, since when we decompose I want to run the batching rule
-    // on the base ops
-    // (3) I can't use dispatcher::call(), since given the above two constraints I'll infinite loop
-    // (I can't add the batched key to the TLS exclude set)
-    // (4) I have to ::redispatch() then. But that plays poorly with dynamicLayer.
-    //sanityCheckStack(op, stack);
+    sanityCheckStack(op, stack);
     c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
     op.callBoxed(stack);
     return;
diff --git a/test/test_eager_transforms.py b/test/test_eager_transforms.py
index d53cfa04e..ca3cb0a83 100644
--- a/test/test_eager_transforms.py
+++ b/test/test_eager_transforms.py
@@ -1909,18 +1909,23 @@ def filter_fn(args):
     def test_generated_batching_rule_for_custom_op(self, device):
         called_impl = False
         called_vjp = False
+        called_impl_with_batched_args = None
+        called_vjp_with_batched_args = None
 
         def my_sin_impl(args):
             x, = args
             nonlocal called_impl
+            nonlocal called_impl_with_batched_args
             called_impl = True
+            called_impl_with_batched_args = functorch._C.is_batchedtensor(x)
             return x.sin(), x
 
         def my_sin_vjp(args):
             grad_y, result, x = args
             nonlocal called_vjp
+            nonlocal called_vjp_with_batched_args
             called_vjp = True
+            called_vjp_with_batched_args = all(functorch._C.is_batchedtensor(a) for a in [grad_y, result, x])
             return (grad_y * 3 * x.cos(),)
 
         def filter_fn(args):
@@ -1928,15 +1933,21 @@ def filter_fn(args):
         my_sin = custom_vjp('my_sin2', filter_fn, my_sin_impl, my_sin_vjp)
 
-        x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True, device=device)
-        x_copy = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True, device=device)
+        x = torch.tensor([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]], requires_grad=True, device=device)
+        x_copy = torch.tensor([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]], requires_grad=True, device=device)
 
         vmap_my_sin = vmap(my_sin)
         y = vmap_my_sin(x)
         self.assertTrue(called_impl)
+        # We expect to run the custom forward with batched tensors, so when
+        # it decomposes into base ops we run the batching rule on each base op.
+        self.assertTrue(called_impl_with_batched_args)
 
         y.sum().backward()
         self.assertTrue(called_vjp)
+        # We expect to run the custom backward (vjp) with non-batched tensors,
+        # because we didn't explicitly vmap over the backward() call.
+        self.assertFalse(called_vjp_with_batched_args)
 
         assert torch.allclose(x.grad, 3 * x.cos())
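
A minimal usage sketch of the API this series adds, mirroring the new test
(the operator name and input shapes are illustrative, and custom_vjp is
assumed to be importable from functorch._src.custom_function as in this
prototype):

    import torch
    from functorch import vmap
    from functorch._src.custom_function import custom_vjp

    # The forward decomposes into prim aten ops. Under vmap, the generated
    # batching rule re-runs this decomposition with batched inputs, so each
    # prim op hits its own batching rule.
    def my_op_impl(args):
        x, = args
        return x.sin(), x

    # The custom backward is what gets recorded in the autograd graph,
    # without the decomposition.
    def my_op_vjp(args):
        grad_y, result, x = args
        return (grad_y * x.cos(),)

    def filter_fn(args):
        return args[0]

    my_op = custom_vjp('my_op_example', filter_fn, my_op_impl, my_op_vjp)

    x = torch.rand(4, 3, requires_grad=True)
    y = vmap(my_op)(x)   # forward runs through the generated batching rule
    y.sum().backward()   # backward runs my_op_vjp on unwrapped tensors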