Equinox Integration #2005
I'm running into issues with stateful operations. Equinox manages state functionally, i.e.:

    class Model(eqx.Module):
        norm: eqx.nn.BatchNorm
        linear: eqx.nn.Linear

        def __init__(self, key):
            self.norm = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
            self.linear = eqx.nn.Linear(in_features=32, out_features=32, key=key)

        def __call__(self, x, state):
            x, state = self.norm(x, state)
            x = self.linear(x)
            return x, state

    model, state = eqx.nn.make_with_state(Model)(key=rng_key)
    ...
    for _ in range(steps):
        # Full-batch gradient descent in this simple example.
        model, state, opt_state = make_step(model, state, opt_state, xs, ys)

This is a bit different than the other JAX neural net libraries. I think there are two options:
Feedback appreciated!
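To make the "state is managed functionally" point concrete, here is a minimal pure-Python sketch of the same threading pattern. `ToyNorm` and `RunningMeanState` are hypothetical stand-ins invented for illustration, not Equinox classes; the only thing they share with the snippet above is the `(output, new_state)` calling convention.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class RunningMeanState:
    """Hypothetical state container: everything that changes lives here."""
    count: int = 0
    mean: float = 0.0


@dataclass(frozen=True)
class ToyNorm:
    """Hypothetical stand-in for a stateful layer (e.g. a batch norm)."""
    scale: float = 1.0

    def __call__(self, x: float, state: RunningMeanState) -> Tuple[float, RunningMeanState]:
        # The layer is frozen, so it cannot mutate itself; it returns a new state instead.
        count = state.count + 1
        mean = state.mean + (x - state.mean) / count
        return self.scale * (x - mean), RunningMeanState(count=count, mean=mean)


layer = ToyNorm()
state = RunningMeanState()
outputs = []
for x in [1.0, 2.0, 3.0]:
    # The caller is responsible for carrying the state from one call to the next.
    out, state = layer(x, state)
    outputs.append(out)
```

The caller, not the layer, owns the state between calls, which is why a wrapper that hides state has to decide where that state lives.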
    ) from e

    rng_key = numpyro.prng_key()
    nn_module, state = eqx.nn.make_with_state(nn_module)(key=rng_key, *args, **kwargs)  # noqa: E1111
I would suggest letting users eagerly initialize the nn_module, because no init_fn/apply_fn pattern happens here. We do the same in nnx_module.
For state, Equinox handles it explicitly, so we don't know how the user's call function is implemented; see https://docs.kidger.site/equinox/examples/stateful/ . I'm not sure what a good API for it would be. Maybe users need to do something explicit like:
    nn_module, eager_state = make_with_state(...)
    mutable_holder = numpyro.mutable("nn_state")
    if mutable_holder is None:
        mutable_holder = numpyro.mutable("nn_state", {"state": eager_state})
    nn = equinox_module("nn", nn_module)
    out, new_state = nn(x, mutable_holder["state"])  # assume the last output is the new state, up to users
    mutable_holder["state"] = new_state
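The control flow of that suggestion can be sketched without any libraries. Everything below is a hypothetical stand-in: `ToyStore.mutable` only imitates the "return None until a site has been registered" behaviour assumed in the snippet above and is not the NumPyro API, and `toy_layer` is an invented stateful function.

```python
class ToyStore:
    """Hypothetical site registry; NOT numpyro.mutable, just the same shape of call."""

    def __init__(self):
        self._sites = {}

    def mutable(self, name, init_value=None):
        # Register the site on first use if an initial value is given.
        if name not in self._sites and init_value is not None:
            self._sites[name] = init_value
        return self._sites.get(name)


def toy_layer(x, state):
    # Hypothetical stateful call: doubles x and counts how often it was called.
    return x * 2.0, state + 1


def model(store, x, eager_state=0):
    holder = store.mutable("nn_state")
    if holder is None:
        # First run: seed the holder with the eagerly created state.
        holder = store.mutable("nn_state", {"state": eager_state})
    out, new_state = toy_layer(x, holder["state"])
    # Write the new state back so the next run sees it.
    holder["state"] = new_state
    return out


store = ToyStore()
first = model(store, 1.0)
second = model(store, 3.0)
```

The user code reads the state, passes it to the module, and writes the result back; nothing about the module's return convention is assumed by the store.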
I added a quick commit that gets the current approach working (I think), just to have it as an option. In this approach state is handled entirely under the hood, which is nice for end users because it matches the other modules, but it's not great for Equinox users who are used to managing their own state.
I thought of two other possible options as alternatives to what you suggested (to try to reduce overhead for end users).
Approach 1: pass mutable_holder to users, so some of it is handled under the hood
    nn, mutable_holder = equinox_module("nn", nn_module, *args, **kwargs)
    out, mutable_holder["state"] = nn(x, mutable_holder["state"])  # assume the last output is the new state, up to users
Approach 2: Eager initialization, user handles state
I think this approach stays most true to how equinox works. Some caveats:
- There will probably have to be an apply function that updates state under the hood (not a big deal assuming it works)
- I think end users will have to manage logic for passing through states for fitting vs. prediction (I'll have to see how difficult that is)
    nn_module, eager_state = make_with_state(...)(**init_kwargs)
    nn, state = equinox_module("nn", nn_module, init_state=eager_state)  # state is optional
    out, state = nn(x, state)  # assume the last output is the new state, up to users
It would be nice to reduce some code, but I guess Equinox users will want something explicit. We could hide some code but still expose the state (and the way we update it, e.g. via mutable_holder); in that case, why not expose the full logic and give Equinox users the flexibility they want? I also don't like making equinox_module sometimes return state and sometimes not.
Yeah, that's fair enough, thanks for the feedback! I'll work on adjusting the implementation towards the example you provided (repeated below to confirm we're on the same page).
Pattern with state:
    nn_module, eager_state = make_with_state(model)(...)
    mutable_holder = numpyro.mutable("nn_state")
    if mutable_holder is None:
        mutable_holder = numpyro.mutable("nn_state", {"state": eager_state})
    nn = equinox_module("nn", nn_module)
    out, new_state = nn(x, mutable_holder["state"])  # assume the last output is the new state, up to users
    mutable_holder["state"] = new_state
Pattern when the model isn't stateful:
    nn_module = model(...)
    nn = equinox_module("nn", nn_module)
    out = nn(x)
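Both patterns rely on the wrapper being agnostic about state: it forwards whatever arguments it receives and returns whatever the module returns. A minimal sketch of that idea follows; `toy_module` is a hypothetical stand-in and says nothing about how `equinox_module` is actually implemented in this PR.

```python
from typing import Any, Callable


def toy_module(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical wrapper: forwards every argument to `fn` unchanged."""

    def apply(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)

    apply.site_name = name  # illustrative bookkeeping only
    return apply


def stateless_fn(x):
    return x + 1.0


def stateful_fn(x, state):
    # Returns (output, new_state), mirroring the convention discussed above.
    return x + 1.0, state + 1


# Stateless usage: one input, one output.
nn_plain = toy_module("nn", stateless_fn)
out_plain = nn_plain(1.0)

# Stateful usage: the caller passes the state in and unpacks it on the way out.
nn_stateful = toy_module("nn", stateful_fn)
out_stateful, new_state = nn_stateful(1.0, 0)
```

With this shape the wrapper's own return type never changes; only the wrapped module decides whether a state comes back.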
The usage looks good to me. I guess you can also do
    mutable_holder = numpyro.mutable("nn_state", {"state": eager_state})
    nn = equinox_module("nn", nn_module)
without the if statement.
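Dropping the `if` works as long as the initial value is only used when the site does not exist yet. The sketch below illustrates that assumption with a hypothetical `ToyStore`; it is not the NumPyro implementation, and the counter-style state is invented for the example.

```python
class ToyStore:
    """Hypothetical site registry; assumes init values are ignored once a site exists."""

    def __init__(self):
        self._sites = {}

    def mutable(self, name, init_value=None):
        if init_value is not None:
            # setdefault keeps the stored value when the site already exists.
            return self._sites.setdefault(name, init_value)
        return self._sites.get(name)


def model(store, x, eager_state=0):
    # No `if holder is None` branch: the init value is passed on every call.
    holder = store.mutable("nn_state", {"state": eager_state})
    holder["state"] += 1
    return x * holder["state"]


store = ToyStore()
a = model(store, 2.0)  # site created from the init value
b = model(store, 2.0)  # init value ignored, stored state reused
```

The second call receives a fresh `{"state": 0}` dictionary as its init value, but the stored holder wins, so the state keeps advancing.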
Closes #1709.
This adds Equinox support to NumPyro's contrib module.