v0.0.5 - Rename JaxModule->WrappedJaxFunction, add wrapper for jax scalar-valued functions
- Rename `JaxModule` -> `WrappedJaxFunction`
- Add `WrappedJaxScalarFunction`:
  - Offered as an alternative to `WrappedJaxFunction` for scalar-valued functions (although both will work).
  - This is potentially more efficient, since it uses a `jax.jit`-ed `jax.value_and_grad` to do a fused forward and backward pass, compared with `WrappedJaxFunction`, which uses `jax.vjp`.
- Add more tests
- Add more doctests, which serve as both documentation and unit tests.
- Add some examples in the README.
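To illustrate the two JAX transformations mentioned in the `WrappedJaxScalarFunction` entry, here is a minimal pure-JAX sketch. It is not the library's implementation and does not call `WrappedJaxFunction` or `WrappedJaxScalarFunction`; the names `loss_fn`, `forward_backward_vjp`, and `fused_value_and_grad` are hypothetical and used only for illustration. It contrasts a `jax.vjp`-based forward/backward split with a single `jax.jit`-compiled `jax.value_and_grad` call.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x):
    # Hypothetical scalar-valued function: sum of squares of a linear map.
    return jnp.sum((x @ params) ** 2)


def forward_backward_vjp(params, x):
    # vjp-style: the forward pass returns the output plus a pullback closure.
    # In an autograd wrapper, the pullback would typically be kept from the
    # forward pass and called later, during the backward pass.
    out, pullback = jax.vjp(loss_fn, params, x)
    # For a scalar output, the cotangent is simply 1.0 (same shape as `out`).
    grad_params, grad_x = pullback(jnp.ones_like(out))
    return out, (grad_params, grad_x)


# Fused: one jitted call computes the value and the gradients w.r.t. both args.
fused_value_and_grad = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1)))
```

Both functions return the same value and gradients for the same inputs; they differ only in how the work is split and compiled. Note that in this sketch the fused version always computes the gradients, even if they end up unused.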
Full Changelog: v0.0.4...v0.0.5