From c5885230fbbfa5a5a4021baf9e4ced7821e7fdca Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 17 Sep 2024 09:54:15 -0700 Subject: [PATCH] Clean up and fix primal type to tangent type mapping This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types. Changes: 1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself. 2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion. 3. Add `to_tangent_type` calls in various other places they're missing. 4. Remove non-support for float0 in custom deriviatives? 5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.) PiperOrigin-RevId: 675606346 --- oryx/core/primitive.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/oryx/core/primitive.py b/oryx/core/primitive.py index 2ae0b2a..d841b2f 100644 --- a/oryx/core/primitive.py +++ b/oryx/core/primitive.py @@ -142,8 +142,6 @@ def _jvp(primals, tangents, **params): primals_out, tangents_out = ad.jvp(lu.wrap_init(self.impl, params)).call_wrapped( primals, tangents) - tangents_out = jax_util.safe_map(ad.recast_to_float0, primals_out, - tangents_out) return primals_out, tangents_out ad.primitive_jvps[self] = _jvp