without_jit=True for already jitted functions #22

Open
fabiannagel opened this issue Feb 22, 2021 · 0 comments

@fabiannagel

In most JAX-based implementations, jit is used almost everywhere: unless there is a reason not to, people apply it to benefit from the speedup.

I noticed that @chex.variants(with_jit=True, without_jit=True) is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.
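Roughly speaking, the two variants run the same function through two execution paths so that a test can compare them. The sketch below is a simplified, hypothetical stand-in for that idea using only JAX (it is not chex's implementation): the `VARIANTS` dict and `check_same_behavior` helper are illustrative names, the `with_jit` path wraps the function in `jax.jit`, and the `without_jit` path calls it unchanged.

```python
import jax
import jax.numpy as jnp

# Hypothetical, simplified stand-ins for chex's with_jit / without_jit variants.
VARIANTS = {
    "with_jit": jax.jit,         # compile the function with XLA
    "without_jit": lambda f: f,  # call the Python function eagerly
}


def add(x, y):
  return x + y


def check_same_behavior(fn, args_list):
  """Calls fn on every argument tuple under each variant and compares results."""
  results = {}
  for name, wrap in VARIANTS.items():
    var_fn = wrap(fn)
    results[name] = [var_fn(*args) for args in args_list]
  # Both execution paths should produce the same values.
  for a, b in zip(results["with_jit"], results["without_jit"]):
    assert jnp.array_equal(a, b)
  return results


results = check_same_behavior(add, [(1, 2), (3, 4), (5, 6)])
```

This only works as a behavioral check when `fn` itself is undecorated; otherwise both paths end up executing compiled code.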

In the following example, I would expect to see "Tracing fn" four times in total: three times for the without_jit variant (once per call) and once for the initial jit compilation in the with_jit variant. In reality, test_variant_pre_jitted() runs with the already-jitted fn under both variants, so fn is traced only once per run and the message appears only twice.

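```python
from absl.testing import absltest
import chex
from jax import jit


class VariantsTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test_variant_pre_jitted(self):
    @jit  # fn is already jitted before it reaches self.variant
    def fn(x, y):
      print("Tracing fn")  # runs only while JAX traces fn
      return x + y

    var_fn = self.variant(fn)
    self.assertEqual(var_fn(1, 2), 3)
    self.assertEqual(var_fn(3, 4), 7)
    self.assertEqual(var_fn(5, 6), 11)


if __name__ == "__main__":
  absltest.main()
```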
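The tracing counts described above can be reproduced without chex. The sketch below replaces `print` with a counter and uses hypothetical stand-ins for the two variants (`VARIANTS`, `make_fn`, and `count_body_runs` are illustrative names, not chex's code). A jitted function only runs its Python body while JAX traces it, and later calls with the same argument shapes and dtypes reuse the cached trace, so a pre-jitted `fn` runs its body once per test run regardless of the variant.

```python
import jax

# Hypothetical stand-ins for the two chex variants.
VARIANTS = {"with_jit": jax.jit, "without_jit": lambda f: f}


def make_fn(pre_jitted):
  """Builds a fresh fn (as each test run does) plus a counter of body runs."""
  counter = [0]

  def fn(x, y):
    counter[0] += 1  # stands in for print("Tracing fn")
    return x + y

  return (jax.jit(fn) if pre_jitted else fn), counter


def count_body_runs(pre_jitted):
  """Total number of times fn's Python body ran across both variants."""
  total = 0
  for wrap in VARIANTS.values():
    fn, counter = make_fn(pre_jitted)
    var_fn = wrap(fn)
    for x, y in [(1, 2), (3, 4), (5, 6)]:
      assert var_fn(x, y) == x + y
    total += counter[0]
  return total


undecorated_runs = count_body_runs(pre_jitted=False)  # 1 trace + 3 eager calls
pre_jitted_runs = count_body_runs(pre_jitted=True)    # 1 trace per variant
```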

Omitting @jit does, of course, produce the expected behavior. Sadly, variants stop being meaningful once a more complex implementation already applies jit internally.

My code falls into the latter case, and the only option I see is a model-wide use_jit flag so that I can derive variants from non-jitted code. However, that largely defeats the purpose of variants.

I'm aware this may be a limitation of JAX's jit rather than of chex. If so, raising an error when an already-jitted function is passed to variant() would make the behavior more transparent.
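A check of that kind could be approximated on the user side. The helper below, `assert_not_jitted`, is hypothetical and relies on a heuristic assumption: functions returned by `jax.jit` in recent JAX versions expose a `lower` method (part of the ahead-of-time compilation API), whereas plain Python functions do not. It is a sketch, not a description of anything chex currently does.

```python
import jax


def assert_not_jitted(fn):
  """Hypothetical guard: raises if fn looks like the output of jax.jit.

  Heuristic: jitted functions expose a `lower` method for ahead-of-time
  compilation; plain Python functions do not.
  """
  if callable(getattr(fn, "lower", None)):
    raise ValueError(
        f"{getattr(fn, '__name__', fn)!r} appears to be jitted already; "
        "derive variants from the undecorated function instead.")
  return fn


def add(x, y):
  return x + y


assert_not_jitted(add)  # passes: plain Python function
```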
