Skip to content

get_dependencies does not behave as expected with eqx_module #2077

@rayisaacalan

Description

@rayisaacalan

Bug Description

When replicating the "NNX and NumPyro Integration" example using Equinox as a drop-in replacement for NNX, the model can be fit and behaves as expected for inference, but it cannot be visualized with render_model: the call raises a TypeError (attributed here to get_dependencies, although the traceback below shows it surfacing in get_model_relations).

Steps to Reproduce

import jax
import jax.numpy as jnp
from jax import random
import jax.tree_util as jtu
import equinox as eqx

import numpyro
from numpyro.contrib.module import eqx_module, random_eqx_module
import numpyro.distributions as dist

# Fixed-seed JAX PRNG key; every later key in this script is split from it.
rng_key = random.PRNGKey(seed=42)

# Number of input points.
n = 32 * 10
# NOTE(review): rng_subkey is not used anywhere in this snippet, but the split
# still advances rng_key, so removing it would change the keys derived below.
rng_key, rng_subkey = random.split(rng_key)
# n evenly spaced inputs on [1, pi].
x = jnp.linspace(1, jnp.pi, n)
# Add a trailing axis so that each row of x_train is a single 1-feature input.
x_train = x[..., None]

class LocMLP(eqx.Module):
    """Three-layer perceptron used for the mean.

    Two linear layers, each followed by a sigmoid, and a final linear
    output layer with no activation.
    """

    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    linear3: eqx.nn.Linear

    def __init__(self, din: int, dmid: int, dout: int, *, key):
        """Build the three layers.

        :param din: input size of the first layer.
        :param dmid: width shared by the two hidden layers.
        :param dout: output size of the last layer.
        :param key: PRNG key; split into one sub-key per layer.
        """
        k_first, k_second, k_third = random.split(key, 3)
        self.linear1 = eqx.nn.Linear(din, dmid, key=k_first)
        self.linear2 = eqx.nn.Linear(dmid, dmid, key=k_second)
        self.linear3 = eqx.nn.Linear(dmid, dout, key=k_third)

    def __call__(self, x):
        """Apply linear1 -> sigmoid -> linear2 -> sigmoid -> linear3 to ``x``."""
        hidden = jax.nn.sigmoid(self.linear1(x))
        hidden = jax.nn.sigmoid(self.linear2(hidden))
        return self.linear3(hidden)


class ScaleMLP(eqx.Module):
    """One-layer network used for the standard deviation.

    A 1-input / 1-output linear layer whose result is passed through
    softplus.
    """

    linear: eqx.nn.Linear

    def __init__(self, *, key) -> None:
        """Create the single linear layer, initialised from ``key``."""
        self.linear = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, x):
        """Return ``softplus(linear(x))``."""
        return jax.nn.softplus(self.linear(x))

# Derive one independent key per network from the running rng_key.
rng_key, mu_key, sigma_key = random.split(rng_key, 3)
# Mean network: 1 input feature, hidden width 8, 1 output.
mu_nn_module = LocMLP(din=1, dmid=8, dout=1, key=mu_key)
sigma_nn_module = ScaleMLP(key=sigma_key)
# Notebook expression: list the leaf paths of sigma_nn_module as strings, with
# the first character of each key string dropped (presumably the leading ".").
# These are the names used as keys of the `prior` dict in the model.
[jtu.keystr(path)[1:] for path, _ in jtu.tree_leaves_with_path(sigma_nn_module)]

['linear.weight', 'linear.bias']

def model(x):
    """NumPyro model whose Normal likelihood has network-driven loc and scale.

    ``mu_nn_module`` is wrapped with ``eqx_module`` and ``sigma_nn_module``
    with ``random_eqx_module``, the latter with priors on ``linear.weight``
    and ``linear.bias``. Both networks are vmapped over the leading axis of
    ``x``; their squeezed outputs are recorded as the deterministic sites
    ``"mu"`` and ``"sigma"``.

    :param x: inputs; the ``"data"`` plate has size ``x.shape[0]``.
    """
    loc_net = eqx_module("mu_nn", mu_nn_module)

    # Priors keyed by the leaf paths of sigma_nn_module.
    sigma_priors = {
        "linear.weight": dist.HalfNormal(scale=1),
        "linear.bias": dist.Normal(loc=0, scale=1),
    }
    scale_net = random_eqx_module("sigma_nn", sigma_nn_module, prior=sigma_priors)

    loc = numpyro.deterministic("mu", jax.vmap(loc_net)(x).squeeze())
    scale = numpyro.deterministic("sigma", jax.vmap(scale_net)(x).squeeze())

    n_points = x.shape[0]
    with numpyro.plate("data", n_points):
        numpyro.sample("likelihood", dist.Normal(loc=loc, scale=scale))

## Everything up to here works as expected; the following call raises the TypeError shown below:

numpyro.render_model(
    model=model,
    model_args=(x_train,),
    render_distributions=True,
    render_params=True,
)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[15], [line 1](vscode-notebook-cell:?execution_count=15&line=1)
----> [1](vscode-notebook-cell:?execution_count=15&line=1) numpyro.render_model(
      2     model=model,
      3     model_args=(x_train,),
      4     render_distributions=True,
      5     render_params=True,
      6 )

File numpyro\infer\inspect.py:626, in render_model(model, model_args, model_kwargs, filename, render_distributions, render_params)
    603 def render_model(
    604     model,
    605     model_args=None,
   (...)    609     render_params=False,
    610 ):
    611     """
    612     Wrap all functions needed to automatically render a model.
    613 
   (...)    624     :param bool render_params: Whether to show params in the plot.
    625     """
--> [626](file:numpyro/infer/inspect.py:626)     relations = get_model_relations(
    627         model,
    628         model_args=model_args,
    629         model_kwargs=model_kwargs,
    630     )
    631     graph_spec = generate_graph_specification(relations, render_params=render_params)
    632     graph = render_graph(graph_spec, render_distributions=render_distributions)

File numpyro\infer\inspect.py:326, in get_model_relations(model, model_args, model_kwargs)
    323     return PytreeTrace(trace)
    325 # We use eval_shape to avoid any array computation.
--> [326](numpyro/infer/inspect.py:326) trace = jax.eval_shape(get_trace).trace
    327 obs_sites = [
    328     name
    329     for name, site in trace.items()
    330     if site["type"] == "sample" and site["is_observed"]
    331 ]
    332 sample_dist = {
    333     name: site["fn_name"]
    334     for name, site in trace.items()
    335     if site["type"] in ["sample", "deterministic"]
    336 }

    [... skipping hidden 11 frame]

File jax\_src\interpreters\partial_eval.py:[2407](file:jax/_src/interpreters/partial_eval.py:2407), in _check_returned_jaxtypes(dbg, out_tracers)
   2405 else:
   2406   extra = ''
-> 2407 raise TypeError(
   2408 f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a "
   2409 f"value of type {type(x)}{extra}, which is not a valid JAX type") from None

TypeError: function get_trace at numpyro\infer\inspect.py:307 traced for jit returned a value of type <class 'function'>, which is not a valid JAX type

Expected Behavior

The rendered model should be identical to the one in the NNX example, the only difference being that the parameters of sigma_nn are named 'linear.weight' instead of 'linear.kernel'.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions