Most JAX-based implementations use `jit` wherever possible. Unless there is a reason not to use it, people will try to take advantage of its speedup.
I noticed that `@chex.variants(with_jit=True, without_jit=True)` is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.
In the following example, I would expect to see "Tracing fn" four times in total: three times for the non-jitted variants and once for the initial `jit` compilation. In reality, `test_variant_pre_jitted()` is executed twice with the jitted `fn`, resulting in only two tracer outputs.
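A minimal sketch of the underlying effect, using plain JAX without chex: a Python side effect inside a function runs on every call of the plain function, but only while tracing once `jit` is applied. The names `trace_log`, `jitted_fn`, `eager_count` and `jit_count` are illustrative and not taken from the example referred to above.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry each time the Python body of fn actually runs


def fn(x):
    # Python side effect: runs on every eager call, but only during tracing under jit.
    trace_log.append("Tracing fn")
    return x * 2.0


x = jnp.ones(3)

# Non-jitted: the body runs on each of the three calls.
for _ in range(3):
    fn(x)
eager_count = len(trace_log)

# Jitted: the body runs once while tracing; later calls with the same
# shape and dtype reuse the compiled computation.
trace_log.clear()
jitted_fn = jax.jit(fn)
for _ in range(3):
    jitted_fn(x)
jit_count = len(trace_log)

print(eager_count, jit_count)
```

If a test is handed `jitted_fn` instead of `fn`, its "without jit" path still runs compiled code, so the per-call tracing it was meant to exercise never happens.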
Of course, omitting `@jit` leads to the expected behavior. However, when more complex implementations already make use of `jit`, variants sadly no longer make sense.
My case is the latter, and the only option I see is implementing a model-wide `use_jit` flag so that I can derive variants from non-jitted code. However, this makes the whole idea of variants rather obsolete.
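To make the flag-based workaround concrete, here is a rough sketch, assuming a hypothetical factory `make_apply(use_jit)` that returns either the plain function or its jitted form. `make_apply` and `apply` are made-up names for illustration and are not part of chex or JAX.

```python
import jax
import jax.numpy as jnp


def make_apply(use_jit: bool = True):
    """Hypothetical factory: returns the model function, jitted only on request."""

    def apply(x):
        return jnp.tanh(x) + 1.0

    return jax.jit(apply) if use_jit else apply


x = jnp.linspace(-1.0, 1.0, 5)
plain = make_apply(use_jit=False)  # non-jitted, so a variant can wrap it itself
fast = make_apply(use_jit=True)    # what the rest of the code base would call

# Both paths should agree numerically.
same = bool(jnp.allclose(plain(x), fast(x)))
print(same)
```

The drawback is the one described above: the flag has to be threaded through every place that currently applies `jit`, which duplicates what variants are supposed to handle.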
I'm aware this could well be a limitation of JAX and `jit` itself rather than chex. In that case, I think raising an error when jitted code is passed to `variant()` would make this more transparent.