Implementing a DSL which desugars to Jaxpr #11778

Answered by sharadmv
femtomc asked this question in General

To answer your question about primitives: if trace is a primitive, its positional arguments must be array-valued, and any non-array metadata is passed as params (keyword arguments to bind). You could pass stats.beta as a param like this:

import jax

trace_p = jax.core.Primitive("trace")

def trace(dist, key, *args):
  # key and *args are array-valued operands; dist is carried as a param.
  return trace_p.bind(key, *args, dist=dist)

This will bake stats.beta into the jaxpr and a custom interpreter will be able to use that information.
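As a rough illustration of how a custom interpreter could read that param back out, here is a minimal sketch. It is not how NumPyro or Oryx do it. Several pieces are assumptions for the sake of a runnable example: the abstract evaluation rule (the output is assumed to have the shape and dtype of the first distribution argument), the hypothetical helper name sample_interpreter, and the use of jax.random.beta as a stand-in for dist so the interpreter has something callable to sample from. The import fallback is there because newer JAX releases expose Primitive and Literal under jax.extend.core.

```python
import jax
import jax.numpy as jnp

try:  # newer JAX releases
    from jax.extend.core import Literal, Primitive
except ImportError:  # older releases, matching the snippet above
    from jax.core import Literal, Primitive

trace_p = Primitive("trace")


def trace(dist, key, *args):
    # key and *args are array operands; dist is stored as an equation param.
    return trace_p.bind(key, *args, dist=dist)


def _trace_abstract_eval(key, *args, dist):
    # Assumption: the sample has the shape/dtype of the first argument.
    return jax.core.ShapedArray(args[0].shape, args[0].dtype)


trace_p.def_abstract_eval(_trace_abstract_eval)


def sample_interpreter(fn):
    """Hypothetical interpreter: replaces each trace equation with a sample."""

    def wrapped(*args):
        closed = jax.make_jaxpr(fn)(*args)
        jaxpr = closed.jaxpr
        env = {}

        def read(var):
            return var.val if isinstance(var, Literal) else env[var]

        for var, val in zip(jaxpr.constvars, closed.consts):
            env[var] = val
        for var, val in zip(jaxpr.invars, args):
            env[var] = val

        for eqn in jaxpr.eqns:
            invals = [read(v) for v in eqn.invars]
            if eqn.primitive is trace_p:
                key, *dist_args = invals
                # The metadata baked into the jaxpr is available here.
                outvals = eqn.params["dist"](key, *dist_args)
            else:
                # Every other primitive is evaluated as usual.
                outvals = eqn.primitive.bind(*invals, **eqn.params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            for var, val in zip(eqn.outvars, outvals):
                env[var] = val

        outs = [read(v) for v in jaxpr.outvars]
        return outs[0] if len(outs) == 1 else outs

    return wrapped


def model(key, a, b):
    x = trace(jax.random.beta, key, a, b)
    return 2.0 * x


key = jax.random.PRNGKey(0)
a, b = jnp.float32(2.0), jnp.float32(3.0)
result = sample_interpreter(model)(key, a, b)
```

The interpreter only special-cases trace_p. A real PPL would more likely record the sample site, score it, or substitute an observed value at that point rather than sampling directly.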

I think Jaxpr interpreters using make_jaxpr are a viable approach for building a PPL on top of JAX. You could look at some existing JAX PPLs as examples: NumPyro implements its own effect-handling system in Python, whereas Oryx uses Jaxpr interpreters.

Specifica…

sharadmv Aug 8, 2022
Collaborator

Answer selected by femtomc