-
Notifications
You must be signed in to change notification settings - Fork 269
Closed
Labels
bug: Something isn't working
Description
Bug Description
When replicating the "NNX and NumPyro Integration" example with Equinox as a drop-in replacement for NNX, the model can be fit and behaves as expected for inference. However, it cannot be visualized with `render_model`: `get_model_relations` raises a TypeError while tracing the model with `jax.eval_shape` (see the traceback below).
Steps to Reproduce
import jax
import jax.numpy as jnp
from jax import random
import jax.tree_util as jtu
import equinox as eqx
import numpyro
from numpyro.contrib.module import eqx_module, random_eqx_module
import numpyro.distributions as dist
# Synthetic inputs: n evenly spaced points on [1, pi], reshaped to a column.
rng_key = random.PRNGKey(seed=42)
n = 32 * 10
# rng_subkey is not used anywhere in this reproduction; the split only
# advances rng_key before it is split again for the network keys.
rng_key, rng_subkey = random.split(rng_key)
x = jnp.linspace(start=1, stop=jnp.pi, num=n)
x_train = jnp.expand_dims(x, axis=-1)
class LocMLP(eqx.Module):
    """Mean network: a three-layer perceptron with sigmoid activations.

    Maps a ``din``-dimensional input through two ``dmid``-wide hidden layers
    to a ``dout``-dimensional output. The final layer has no activation, so
    the output is unconstrained.
    """

    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    linear3: eqx.nn.Linear

    def __init__(self, din: int, dmid: int, dout: int, *, key):
        # Derive one independent PRNG key per layer from the caller's key.
        k_in, k_hidden, k_out = random.split(key, 3)
        self.linear1 = eqx.nn.Linear(din, dmid, key=k_in)
        self.linear2 = eqx.nn.Linear(dmid, dmid, key=k_hidden)
        self.linear3 = eqx.nn.Linear(dmid, dout, key=k_out)

    def __call__(self, x):
        # Applied to one sample at a time; `model` batches calls with jax.vmap.
        hidden = jax.nn.sigmoid(self.linear1(x))
        hidden = jax.nn.sigmoid(self.linear2(hidden))
        return self.linear3(hidden)
class ScaleMLP(eqx.Module):
    """Scale network: one 1 -> 1 linear layer followed by softplus.

    The softplus keeps the output positive so it can serve as the standard
    deviation of the Normal likelihood in `model`.
    """

    linear: eqx.nn.Linear

    def __init__(self, *, key) -> None:
        self.linear = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, x):
        return jax.nn.softplus(self.linear(x))
# Split off one key per network and build the (not yet registered) modules.
rng_key, mu_key, sigma_key = random.split(rng_key, 3)
mu_nn_module = LocMLP(din=1, dmid=8, dout=1, key=mu_key)
sigma_nn_module = ScaleMLP(key=sigma_key)

# Leaf paths of the scale network, with the leading "." produced by keystr
# stripped. random_eqx_module in `model` keys its priors by exactly these
# names ("linear.weight", "linear.bias"), so fail fast if they ever drift
# (e.g. after an Equinox upgrade renames a field). Previously this was a
# REPL output pasted in as a bare, unchecked list literal.
sigma_leaf_names = [
    jtu.keystr(path)[1:] for path, _ in jtu.tree_leaves_with_path(sigma_nn_module)
]
assert sigma_leaf_names == ["linear.weight", "linear.bias"], sigma_leaf_names
def model(x, y=None):
    """Heteroscedastic regression model built from two Equinox networks.

    The mean network is registered as ``"mu_nn"`` via ``eqx_module``, which
    places no priors on it. The scale network is registered as
    ``"sigma_nn"`` via ``random_eqx_module``, with priors on its leaves. The
    prior keys are the leaf paths of ``sigma_nn_module``.

    Args:
        x: Model inputs. Each row is fed to the networks through ``jax.vmap``
            and ``x.shape[0]`` sets the size of the ``"data"`` plate.
            Presumably shaped ``(n, 1)`` like ``x_train``; TODO confirm for
            other callers.
        y: Optional observed targets, one per row of ``x``. When given, they
            are passed as ``obs`` to the ``"likelihood"`` site so the model
            conditions on them. When ``None`` (the default, and the only
            behavior before this parameter existed), the likelihood site is
            sampled instead.
    """
    mu_nn = eqx_module("mu_nn", mu_nn_module)
    sigma_nn = random_eqx_module(
        "sigma_nn",
        sigma_nn_module,
        prior={
            "linear.weight": dist.HalfNormal(scale=1),
            "linear.bias": dist.Normal(loc=0, scale=1),
        },
    )
    # The networks act on single rows; vmap over the leading axis, then drop
    # the trailing size-1 output dimension.
    mu = numpyro.deterministic("mu", jax.vmap(mu_nn)(x).squeeze())
    sigma = numpyro.deterministic("sigma", jax.vmap(sigma_nn)(x).squeeze())
    with numpyro.plate("data", x.shape[0]):
        numpyro.sample("likelihood", dist.Normal(loc=mu, scale=sigma), obs=y)
# Everything up to here works as expected; the call below does not.
# Inside render_model, get_model_relations traces the model with
# jax.eval_shape and fails with "TypeError: function get_trace ... returned
# a value of type <class 'function'>, which is not a valid JAX type".
numpyro.render_model(
    model,
    model_args=(x_train,),
    render_params=True,
    render_distributions=True,
)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[15], [line 1](vscode-notebook-cell:?execution_count=15&line=1)
----> [1](vscode-notebook-cell:?execution_count=15&line=1) numpyro.render_model(
2 model=model,
3 model_args=(x_train,),
4 render_distributions=True,
5 render_params=True,
6 )
File numpyro\infer\inspect.py:626, in render_model(model, model_args, model_kwargs, filename, render_distributions, render_params)
603 def render_model(
604 model,
605 model_args=None,
(...) 609 render_params=False,
610 ):
611 """
612 Wrap all functions needed to automatically render a model.
613
(...) 624 :param bool render_params: Whether to show params in the plot.
625 """
--> [626](file:numpyro/infer/inspect.py:626) relations = get_model_relations(
627 model,
628 model_args=model_args,
629 model_kwargs=model_kwargs,
630 )
631 graph_spec = generate_graph_specification(relations, render_params=render_params)
632 graph = render_graph(graph_spec, render_distributions=render_distributions)
File numpyro\infer\inspect.py:326, in get_model_relations(model, model_args, model_kwargs)
323 return PytreeTrace(trace)
325 # We use eval_shape to avoid any array computation.
--> [326](numpyro/infer/inspect.py:326) trace = jax.eval_shape(get_trace).trace
327 obs_sites = [
328 name
329 for name, site in trace.items()
330 if site["type"] == "sample" and site["is_observed"]
331 ]
332 sample_dist = {
333 name: site["fn_name"]
334 for name, site in trace.items()
335 if site["type"] in ["sample", "deterministic"]
336 }
[... skipping hidden 11 frame]
File jax\_src\interpreters\partial_eval.py:[2407](file:jax/_src/interpreters/partial_eval.py:2407), in _check_returned_jaxtypes(dbg, out_tracers)
2405 else:
2406 extra = ''
-> 2407 raise TypeError(
2408 f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a "
2409 f"value of type {type(x)}{extra}, which is not a valid JAX type") from None
TypeError: function get_trace at numpyro\infer\inspect.py:307 traced for jit returned a value of type <class 'function'>, which is not a valid JAX type
Expected Behavior
The rendered graph should be identical to the one in the NNX example. The only difference should be the parameter name in sigma_nn: 'linear.weight' instead of 'linear.kernel'.
Metadata
Assignees
Labels
bug: Something isn't working