How to return 'early' (possibly with an error code) from jitted functions? #15265

Answered by jakevdp
wiep asked this question in Q&A

There's no way to return early from a JIT-compiled function conditioned on a dynamic value. The best available pattern is probably to define parts of your function in blocks, and use lax.cond to string them together; i.e. something like this:

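import jax

# Each stage returns (error_code, value); an error_code of 0 means success.
def f(x):
  return 0, x + 1

def g(x):
  return 0, x + 1

def h(x):
  return 0, x + 1

# Fallback branch: skips the stage and passes x through unchanged.
def identity(x):
  return 0, x

@jax.jit
def do_work(x):
  error_code, x = f(x)
  # Run each later stage only if the previous one succeeded.
  error_code, x = jax.lax.cond(error_code == 0, g, identity, x)
  error_code, x = jax.lax.cond(error_code == 0, h, identity, x)
  return x

print(do_work(2))
# 5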

If you're interested in tracking and returning error codes from your functions, you can certainly do that …
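As a rough illustration of returning the code alongside the result, here is a minimal sketch. The stage functions (`safe_sqrt`, `add_one`), the `skip` branch, and the convention that code 1 means "negative input" are all made up for this example and are not from the answer above. One detail to watch: the skip branch has to pass the existing error code through, rather than returning 0 as `identity` does in the snippet above, otherwise the code is lost by the time the function returns.

```python
import jax
import jax.numpy as jnp

# Hypothetical stages for illustration only: each returns (error_code, value),
# where 0 means success and 1 (an assumed convention) means "negative input".
def safe_sqrt(x):
  is_bad = x < 0
  err = jnp.where(is_bad, 1, 0).astype(jnp.int32)
  # On error, pass x through unchanged; abs() keeps sqrt from producing NaN.
  return err, jnp.where(is_bad, x, jnp.sqrt(jnp.abs(x)))

def add_one(err, x):
  return jnp.int32(0), x + 1

# Skip branch: forwards the existing error code and value untouched.
def skip(err, x):
  return err, x

@jax.jit
def do_work_with_code(x):
  err, x = safe_sqrt(x)
  # Only run the next stage if the previous one succeeded.
  err, x = jax.lax.cond(err == 0, add_one, skip, err, x)
  return err, x

ok_err, ok_val = do_work_with_code(jnp.float32(4.0))
bad_err, bad_val = do_work_with_code(jnp.float32(-4.0))
```

The caller can then check the returned code outside of `jit`, e.g. `if int(bad_err) != 0: ...`, since by that point it is a concrete value.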

Answer selected by wiep