
v0.0.5 - Rename JaxModule->WrappedJaxFunction, add wrapper for jax scalar-valued functions

@lebrice released this 13 Jun 17:37
Commit: 3ec8451
  • Rename JaxModule -> WrappedJaxFunction
  • Add WrappedJaxScalarFunction:
    • Offered as an alternative to WrappedJaxFunction for scalar-valued functions (both classes work for such functions).
    • This is potentially more efficient: it uses a jax.jit-compiled jax.value_and_grad to perform a fused forward and backward pass, whereas WrappedJaxFunction uses jax.vjp.
  • Add more tests
  • Add more doctests, which serve as both documentation and unit tests.
  • Add some examples in the README.
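
The two differentiation strategies mentioned above can be sketched in plain JAX, independent of this package's own API. This is an illustrative sketch only: the loss function, the `vjp_value_and_grad` helper, and the inputs are hypothetical and are not taken from the library. It shows a jit-compiled `jax.value_and_grad` returning the loss and gradients in one call, next to a `jax.vjp`-based version that runs the forward pass first and the backward pass later through the returned closure.

```python
import jax
import jax.numpy as jnp


# Hypothetical scalar-valued function: sum of squares of an affine map.
def loss_fn(params, x):
    w, b = params
    return jnp.sum((x @ w + b) ** 2)


# Fused forward + backward: a single jit-compiled call returns the loss and
# the gradients w.r.t. `params` (argument 0, the default for argnums).
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))


# vjp-based alternative (hypothetical helper): the forward pass returns the
# output plus a closure that computes the backward pass from a cotangent.
def vjp_value_and_grad(params, x):
    out, vjp_fn = jax.vjp(loss_fn, params, x)
    # For a scalar output, a cotangent of 1.0 yields the plain gradient.
    grad_params, _grad_x = vjp_fn(jnp.ones_like(out))
    return out, grad_params


params = (jnp.ones((3, 2)), jnp.zeros(2))
x = jnp.arange(6.0).reshape(2, 3)

loss_a, grads_a = value_and_grad_fn(params, x)  # loss_a == 306.0
loss_b, grads_b = vjp_value_and_grad(params, x)  # same values via jax.vjp
```

Both paths compute the same values; the difference is that the `jax.vjp` path keeps the backward pass as a separate closure (convenient when the backward step is triggered later, e.g. from another framework's autograd), while `value_and_grad` under `jax.jit` lets XLA compile the forward and backward computation together.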

Full Changelog: v0.0.4...v0.0.5