sensitivity_jax
sensitivity_jax is a package for taking first- and second-order derivatives through optimization or any other fixed-point process.
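A standard way to differentiate through an optimization or fixed-point solve is implicit differentiation. If the solution z*(p) satisfies an optimality condition k(z*, p) = 0, the implicit function theorem gives dz*/dp = -(dk/dz)^-1 dk/dp, so the solver itself never has to be differentiated. The sketch below shows this idea in plain JAX. It does not use the sensitivity_jax API. The ridge-regression problem and the names `solve_inner`, `k`, and `implicit_dz_dp` are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical inner problem (ridge regression):
#   z*(p) = argmin_z 0.5 * ||A z - b||^2 + 0.5 * p * ||z||^2
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 0.0, 1.0])


def solve_inner(p):
    # Closed-form solve of the inner problem; an iterative solver would
    # work equally well because the implicit gradient never looks inside it.
    n = A.shape[1]
    return jnp.linalg.solve(A.T @ A + p * jnp.eye(n), A.T @ b)


def k(z, p):
    # First-order optimality condition: gradient of the inner objective in z.
    return A.T @ (A @ z - b) + p * z


def implicit_dz_dp(p):
    z = solve_inner(p)                        # optimum, treated as a constant
    dk_dz = jax.jacobian(k, argnums=0)(z, p)  # shape (n, n)
    dk_dp = jax.jacobian(k, argnums=1)(z, p)  # shape (n,) for scalar p
    # Implicit function theorem: dk/dz @ dz/dp + dk/dp = 0
    return -jnp.linalg.solve(dk_dz, dk_dp)


p = 0.5
dz_dp_implicit = implicit_dz_dp(p)
# Reference: autodiff straight through the closed-form solve.
dz_dp_direct = jax.jacobian(solve_inner)(p)
print(dz_dp_implicit, dz_dp_direct)
```

The two Jacobians should agree to floating-point tolerance. The implicit version only needs the optimality condition `k` at the solution, which is what makes the approach practical when the inner solver is iterative or not differentiable.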
This package builds on top of JAX. We also maintain a PyTorch implementation.
Documentation is available online.
Install using pip:
$ pip install git+https://github.com/rdyro/sensitivity_jax.git
or from source:
$ git clone git@github.com:rdyro/sensitivity_jax.git
$ cd sensitivity_jax
$ python3 setup.py install --user
Run all unit tests using:
$ python3 setup.py test