
Jax-based implementation of Neural Additive Models


Habush/nam_jax


Neural Additive Models in JAX

This repo contains a JAX-based implementation of the model introduced in "Neural Additive Models: Interpretable Machine Learning with Neural Nets" by R. Agarwal et al. (2021).

Figure: NAM architecture
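As described in the paper, a NAM learns one small neural network f_i per input feature x_i and adds their outputs to a learned bias β. With g as the link function (identity for regression, logit for binary classification), the model for K features is:

```math
g\big(\mathbb{E}[y]\big) = \beta + f_1(x_1) + f_2(x_2) + \cdots + f_K(x_K)
```

Because each f_i sees only its own feature, plotting f_i against x_i shows exactly how that feature contributes to a prediction, which is what makes the model interpretable.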

Dependencies

  • jax
  • optax
  • haiku (installed as dm-haiku), used to implement the neural network model
  • torch, used to create mini-batches
  • numpy
  • scikit-learn
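The list above only says that torch is used to create mini-batches. A common way to do this for JAX code is a torch `DataLoader` with a custom collate function that returns NumPy arrays instead of torch tensors. The sketch below assumes the data is held in NumPy arrays; `ArrayDataset`, `numpy_collate` and `make_loader` are hypothetical names for illustration, not this repo's code.

```python
import numpy as np
import jax.numpy as jnp
import torch
from torch.utils.data import DataLoader, Dataset


class ArrayDataset(Dataset):
    """Hypothetical wrapper exposing (features, target) rows of NumPy arrays."""

    def __init__(self, X: np.ndarray, y: np.ndarray):
        assert len(X) == len(y), "X and y must have the same number of rows"
        self.X = X.astype(np.float32)
        self.y = y.astype(np.float32)

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int):
        return self.X[idx], self.y[idx]


def numpy_collate(batch):
    """Stack a list of (x, y) samples into NumPy arrays (no torch tensors)."""
    xs, ys = zip(*batch)
    return np.stack(xs), np.stack(ys)


def make_loader(X, y, batch_size=32, shuffle=True, seed=0):
    """Build a DataLoader whose shuffling is reproducible via a seeded generator."""
    gen = torch.Generator().manual_seed(seed)
    return DataLoader(
        ArrayDataset(X, y),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=numpy_collate,
        generator=gen,
    )


# Usage: iterate over mini-batches and hand them to JAX as jnp arrays.
X = np.random.default_rng(0).normal(size=(100, 8))
y = X.sum(axis=1)
loader = make_loader(X, y, batch_size=32)
batches = [(jnp.asarray(xb), jnp.asarray(yb)) for xb, yb in loader]
```

Keeping batches as NumPy arrays avoids a torch-to-JAX tensor conversion step; torch only handles indexing and shuffling.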

Examples

Check out the nam_regression_example.ipynb notebook for an example of using the model on the California housing dataset.
