I'm toying around with a probabilistic language built on top of JAX. Let's imagine I create a new primitive, `trace`, …

```python
def model(rng_key, v):
    sub_key, v = trace(stats.beta, rng_key, v, 1.0)
    return sub_key, v
```

Here, … Now, for PPL interfaces, I actually want to be able to construct a … If I try to do this, I'll run into an issue:

```python
jaxpr = jax.make_jaxpr(model)(key, 3.0)
```

because … That's okay! So what I want is the ability to construct an interpreter which is similar to …:

```python
v, tree = simulate(model)(key, 3.0)
```

Roughly, what …

What's the right level to construct this functionality? Is the correct thing to make my own …
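For readers unfamiliar with the pattern the question is after, here is a minimal sketch (not from the thread) of an interpreter that re-evaluates a jaxpr equation by equation, in the style of JAX's "Writing custom Jaxpr interpreters" guide. `eval_jaxpr_logged` is a hypothetical name, and the import fallback assumes a JAX version that exposes `Literal` in `jax.extend.core` or, failing that, in `jax.core`.

```python
import jax
import jax.numpy as jnp

try:  # recent JAX releases
    from jax.extend.core import Literal
except ImportError:  # older releases
    from jax.core import Literal


def eval_jaxpr_logged(closed_jaxpr, *args):
    """Re-evaluate a ClosedJaxpr eqn by eqn, also returning the primitive names seen."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(var):
        # Literals carry their value inline; other variables live in `env`.
        return var.val if isinstance(var, Literal) else env[var]

    env.update(zip(jaxpr.constvars, closed_jaxpr.consts))
    env.update(zip(jaxpr.invars, args))

    log = []
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        # A custom interpreter (e.g. for a PPL) would special-case eqn.primitive here.
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        log.append(eqn.primitive.name)
        env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars], log


if __name__ == "__main__":
    closed = jax.make_jaxpr(lambda x: jax.lax.sin(x) * 2.0)(jnp.array(1.0))
    outs, log = eval_jaxpr_logged(closed, jnp.array(1.0))
    print(outs, log)
```

The loop is the whole idea: every equation exposes its primitive and params, so an interpreter can choose to do something other than plain evaluation at chosen primitives.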
To answer your question about primitives: if `trace` is a primitive, it can only take array-valued inputs as main arguments, and any metadata has to be passed in as params. You could pass in `stats.beta` as a param like this:

```python
import jax

trace_p = jax.core.Primitive("trace")  # newer JAX: jax.extend.core.Primitive

def trace(dist, key, *args):
    # `dist` goes in as a param, not an operand, so it is recorded in the jaxpr
    return trace_p.bind(key, *args, dist=dist)
```

This will bake `stats.beta` into the jaxpr, and a custom interpreter will be able to use that information.

I think Jaxpr interpreters using `make_jaxpr` are a viable approach for building a PPL on top of JAX. You could look at some existing JAX PPLs as examples: NumPyro implements its own effect-handling system in Python, whereas Oryx uses Jaxpr interpreters.

Specifically, in Oryx you could look at the …
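The suggestion above can be pushed into a small end-to-end sketch. This is an illustration under several assumptions, not code from the thread or from Oryx: `BetaDist` is a hypothetical stand-in for `stats.beta` built on `jax.random.beta` and `jax.scipy.stats.beta.logpdf`, the abstract-eval rule assumes `a` and `b` share a shape, and the hypothetical `simulate` returns the flat outputs plus a list of recorded sites rather than a real trace tree. The key point is that the interpreter reads `dist` back out of `eqn.params`.

```python
import jax
from jax.scipy import stats

try:  # recent JAX releases
    from jax.extend.core import Literal, Primitive
except ImportError:  # older releases
    from jax.core import Literal, Primitive


class BetaDist:
    """Hypothetical stand-in for `stats.beta`: a sampler plus a log-density."""

    @staticmethod
    def sample(key, a, b):
        return jax.random.beta(key, a, b)

    @staticmethod
    def logpdf(x, a, b):
        return stats.beta.logpdf(x, a, b)


trace_p = Primitive("trace")
trace_p.multiple_results = True  # returns (next_key, value)


def trace(dist, key, *args):
    # `dist` is a param, so it is stored in eqn.params instead of being traced.
    return trace_p.bind(key, *args, dist=dist)


@trace_p.def_impl
def _trace_impl(key, a, b, *, dist):
    # Eager rule: split the key and draw one sample from `dist`.
    key, sub_key = jax.random.split(key)
    return [key, dist.sample(sub_key, a, b)]


@trace_p.def_abstract_eval
def _trace_abstract_eval(key_aval, a_aval, b_aval, *, dist):
    # Assumption: `a` and `b` share a shape, so the sample reuses `a`'s aval.
    return [key_aval, a_aval]


def model(rng_key, v):
    sub_key, v = trace(BetaDist, rng_key, v, 1.0)
    return sub_key, v


def simulate(fn):
    """Hypothetical `simulate`: evaluate fn's jaxpr, recording each trace site."""
    def wrapped(*args):
        closed = jax.make_jaxpr(fn)(*args)
        jaxpr, env, record = closed.jaxpr, {}, []

        def read(var):
            return var.val if isinstance(var, Literal) else env[var]

        env.update(zip(jaxpr.constvars, closed.consts))
        env.update(zip(jaxpr.invars, args))
        for eqn in jaxpr.eqns:
            invals = [read(v) for v in eqn.invars]
            if eqn.primitive is trace_p:
                dist = eqn.params["dist"]  # recovered from the jaxpr
                outvals = _trace_impl(*invals, dist=dist)
                value = outvals[1]
                record.append({"dist": dist, "value": value,
                               "logpdf": dist.logpdf(value, *invals[1:])})
            else:
                outvals = eqn.primitive.bind(*invals, **eqn.params)
                if not eqn.primitive.multiple_results:
                    outvals = [outvals]
            env.update(zip(eqn.outvars, outvals))
        return [read(v) for v in jaxpr.outvars], record
    return wrapped
```

Usage would look like `(sub_key, v), record = simulate(model)(jax.random.PRNGKey(0), 3.0)`. Every equation other than `trace` is evaluated normally, so the same loop is where other handlers (conditioning, log-density accumulation) could be added.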