Skip to content

Commit

Permalink
Clean up and fix primal type to tangent type mapping
Browse files Browse the repository at this point in the history
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 675606346
  • Loading branch information
dougalm authored and The oryx Authors committed Sep 18, 2024
1 parent 93fc9ef commit c588523
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions oryx/core/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ def _jvp(primals, tangents, **params):
primals_out, tangents_out = ad.jvp(lu.wrap_init(self.impl,
params)).call_wrapped(
primals, tangents)
tangents_out = jax_util.safe_map(ad.recast_to_float0, primals_out,
tangents_out)
return primals_out, tangents_out

ad.primitive_jvps[self] = _jvp
Expand Down

0 comments on commit c588523

Please sign in to comment.