Equinox Integration #2005

Draft: wants to merge 7 commits into base: master

kylejcaron
Contributor

Closes #1709.

This adds Equinox support to NumPyro's contrib module.

@kylejcaron
Contributor Author

I'm running into issues with stateful operations.

Equinox manages state functionally, i.e. the state is passed into each call and returned from it:

import equinox as eqx

class Model(eqx.Module):
    norm: eqx.nn.BatchNorm
    linear: eqx.nn.Linear

    def __init__(self, key):
        # input_size must match the Linear layer's in_features
        self.norm = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear = eqx.nn.Linear(in_features=32, out_features=32, key=key)

    def __call__(self, x, state):
        # Stateful layers take the state in and hand the updated state back
        x, state = self.norm(x, state)
        x = self.linear(x)
        return x, state

model, state = eqx.nn.make_with_state(Model)(key=rng_key)  
...

for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)

This is a bit different than the other JAX neural network libraries. I think there are two options:

  1. Provide eqx_module() with an uninitialized model class and manage the state under the hood (the current approach in this PR). I'm having a bit of trouble getting this to work.
    • This might also be a bit awkward for Equinox users, who are used to managing state themselves.
  2. Trust users to initialize the model eagerly outside the NumPyro model, together with the state when applicable. They would also have to know to pass the state through the call function. I'm not sure how registering the state with numpyro.mutable works with this.

Feedback appreciated!

) from e

rng_key = numpyro.prng_key()
nn_module, state = eqx.nn.make_with_state(nn_module)(key=rng_key, *args, **kwargs) # noqa: E1111
Member

I would suggest letting users eagerly initialize the nn_module, because there is no init_fn/apply_fn pattern here. We do the same in nnx_module.

Equinox handles state explicitly, so we don't know how the user's call function is implemented (see https://docs.kidger.site/equinox/examples/stateful/). I'm not sure what a good API for it would be. Maybe users need to do something explicit like:

nn_module, eager_state = make_with_state(...)
mutable_holder = numpyro.mutable("nn_state")
if mutable_holder is None:
    mutable_holder = numpyro.mutable("nn_state", {"state": eager_state})
nn = equinox_module("nn", nn_module)
# `in` is a reserved word in Python, so the input is named `x` here
out, new_state = nn(x, mutable_holder["state"])  # assume the last output is the new state, up to users
mutable_holder["state"] = new_state

Contributor Author

@kylejcaron kylejcaron Mar 11, 2025

I added a quick commit that gets the current approach working (I think), just to have it as an option. In this approach state is handled entirely under the hood, which is nice for end users because it is consistent with the other modules, but it's not great for Equinox users, who are used to managing their own state.

I thought of two other options as alternatives to what you suggested (to try to reduce overhead for end users).

Approach 1: pass mutable_holder to users so some of it is handled under the hood

nn, mutable_holder = equinox_module("nn", nn_module, *args, **kwargs)
out, mutable_holder["state"] = nn(x, mutable_holder["state"])  # assume the last output is the new state, up to users

Approach 2: Eager initialization, user handles state

I think this approach stays truest to how Equinox works. Some caveats:

  • There will probably have to be an apply function that updates the state under the hood (not a big deal, assuming it works).
  • I think end users will have to manage the logic for passing state through for fitting vs. prediction (I'll have to see how difficult that is).

nn_module, eager_state = make_with_state(...)(**init_kwargs)
nn, state = equinox_module("nn", nn_module, init_state=eager_state)  # state is optional
out, state = nn(x, state)  # assume the last output is the new state, up to users

Member

It would be nice to reduce some code, but I guess Equinox users will want something explicit. We could hide some of the code, but we would still have to expose the state (and the way we update it, e.g. via mutable_holder), so why not expose the full logic and give Equinox users the flexibility they want? I also don't like making equinox_module sometimes return state and sometimes not.

Contributor Author

@kylejcaron kylejcaron Mar 14, 2025

Yeah, that's fair enough, thanks for the feedback! I'll work on adjusting the implementation towards the example you provided (repeated below to confirm we're on the same page).

Pattern with state:

nn_module, eager_state = make_with_state(model)(...)
mutable_holder = numpyro.mutable("nn_state")
if mutable_holder is None:
    mutable_holder = numpyro.mutable("nn_state", {"state": eager_state})
nn = equinox_module("nn", nn_module)
out, new_state = nn(x, mutable_holder["state"])  # assume the last output is the new state, up to users
mutable_holder["state"] = new_state

Pattern when the model isn't stateful:

nn_module = model(...)
nn = equinox_module("nn", nn_module)
out = nn(x)

Member

The usage looks good to me. I guess you can also do

mutable_holder = numpyro.mutable("nn_state", {"state": eager_state})
nn = equinox_module("nn", nn_module)

without the if statement.

@fehiepsi fehiepsi added the WIP label Mar 13, 2025
Development

Successfully merging this pull request may close these issues.

Equinox models integration
2 participants