I am wondering whether there is a best practice or some syntax sugar to return early from a jit-compatible function. The best I could come up with is to nest `jax.lax.cond` calls.

Consider the following code as an example. Here, the functions `f`, `g`, and `h` each return an error code along with their result, and both `do_work_*` variants return early as soon as one of them reports a non-zero error code.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class Return(NamedTuple):
    payload: jax.Array
    error_code: jax.Array = jnp.int32(0)

    @staticmethod
    def error(error_code: int, payload_desc: jax.Array):
        return Return(jnp.empty_like(payload_desc), jnp.asarray(error_code, jnp.int32))


def f(x):
    return 0, x + 1

def g(x):
    return 0, x + 1

def h(x):
    return 0, x + 1


# Function returns early if any of the functions called returns a non-zero
# error code.
def do_work_not_jitable(x: jax.Array) -> Return:
    payload_desc = x
    error_code, fx = f(x)
    if error_code != 0:
        return Return.error(error_code, payload_desc)
    error_code, gx = g(fx)
    if error_code != 0:
        return Return.error(error_code, payload_desc)
    error_code, hx = h(gx)
    if error_code != 0:
        return Return.error(error_code, payload_desc)
    return Return(hx)


# Function returns early if any of the functions called returns a non-zero
# error code but uses jax.lax.cond to avoid Python control flow.
def do_work_jitable(x: jax.Array) -> Return:
    payload_desc = x

    def step_h(gx):
        error_code, hx = h(gx)
        return jax.lax.cond(
            error_code != 0,
            lambda _: Return.error(error_code, payload_desc),
            lambda hx: Return(hx),
            hx,
        )

    def step_g(fx):
        error_code, gx = g(fx)
        return jax.lax.cond(
            error_code != 0,
            lambda _: Return.error(error_code, payload_desc),
            step_h,
            gx,
        )

    # step_f
    error_code, fx = f(x)
    return jax.lax.cond(
        error_code != 0,
        lambda _: Return.error(error_code, payload_desc),
        step_g,
        fx,
    )


if __name__ == "__main__":
    x = jnp.array([1, 2, 3])
    print(do_work_not_jitable(x))
    print(do_work_jitable(x))
    do_work_jitable = jax.jit(do_work_jitable)
    print(do_work_jitable(x))
```
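Why the Python-`if` version cannot be jitted in general: under `jax.jit` the error code becomes a tracer, and an `if` statement needs a concrete Python bool. Note that with the stub `f`, `g`, `h` above, which return the Python constant `0`, `do_work_not_jitable` would actually trace fine, because the `if` only ever sees a Python int. The problem appears once the error code is computed from the traced input. A small sketch; `f_stub`, `f_checked`, `make_early_return` and `jit_fails` are hypothetical helpers, not from the original post:

```python
import jax
import jax.numpy as jnp


def f_stub(x):
    # Like the question's f: the error code is a Python constant.
    return 0, x + 1


def f_checked(x):
    # Hypothetical step: the error code depends on the input values.
    error_code = jnp.where(jnp.any(x < 0), 1, 0)
    return error_code, x + 1


def make_early_return(step):
    # Python-level early return, in the style of do_work_not_jitable.
    def work(x):
        error_code, fx = step(x)
        if error_code != 0:  # calls bool() on the comparison result
            return jnp.zeros_like(x)
        return fx
    return work


def jit_fails(fn, x):
    # True if tracing fn under jax.jit raises a concretization error
    # (TracerBoolConversionError is a subclass of ConcretizationTypeError).
    try:
        jax.jit(fn)(x)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


x = jnp.array([1, 2, 3])
print(jit_fails(make_early_return(f_stub), x))     # False: the `if` sees a Python int
print(jit_fails(make_early_return(f_checked), x))  # True: the `if` sees a tracer
print(make_early_return(f_checked)(x))             # eager (un-jitted) call still works
```

This is why a jit-compatible version has to route the decision through `jax.lax.cond` (or some form of masking) instead of a Python `if`.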
There's no way to return early from a JIT-compiled function conditioned on a dynamic value. The best available pattern is probably to define parts of your function in blocks, and use `lax.cond` to string them together.

If you're interested in tracking and returning error codes from your functions, you can certainly do that manually (by adding an extra return value) or you can do so more automatically using the experimental …
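A minimal sketch of the blocks-plus-`lax.cond` pattern, assuming every step has the question's `(error_code, payload)` signature and preserves the payload's shape and dtype. `run_steps` and `STEPS` are hypothetical names, not a JAX API: the helper builds the same nested `jax.lax.cond` structure as `do_work_jitable`, but by recursing over a Python list of steps at trace time, so the nesting does not have to be written out by hand. On failure it returns `jnp.zeros_like(x)` as the payload, an arbitrary placeholder convention similar to `Return.error`.

```python
import jax
import jax.numpy as jnp


def run_steps(steps, x):
    """Hypothetical helper: run `steps` in order, stopping at the first
    non-zero error code. Each step maps a payload to (error_code, payload)
    and must preserve the payload's shape and dtype.

    Returns (error_code, payload); on failure the payload is zeros_like(x).
    """
    def run_from(i, payload):
        if i == len(steps):  # static Python recursion over the step list
            return jnp.int32(0), payload
        error_code, out = steps[i](payload)
        error_code = jnp.asarray(error_code, jnp.int32)
        return jax.lax.cond(
            error_code != 0,
            lambda _: (error_code, jnp.zeros_like(x)),  # stop here
            lambda o: run_from(i + 1, o),               # continue with the next step
            out,
        )

    return run_from(0, x)


# Example steps (hypothetical): the second one fails if any value exceeds 3.
STEPS = [
    lambda x: (0, x + 1),
    lambda x: (jnp.where(jnp.any(x > 3), 2, 0), x * 2),
    lambda x: (0, x - 1),
]

do_work = jax.jit(lambda x: run_steps(STEPS, x))
print(do_work(jnp.array([1, 2, 3])))  # stops at the second step with error code 2
print(do_work(jnp.array([0, 0, 0])))  # all steps succeed, error code 0
```

Because the recursion over `steps` happens in Python while tracing, the list of steps is static and only the error codes are dynamic. If the function is mapped with `jax.vmap` so that the predicate becomes batched, `lax.cond` is converted to a select: every step then runs for every element, and the "early return" only masks the results.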