Leaked tracers when jitting MultiVariateNormalDiag distribution #230

Open
keraJLi opened this issue Mar 11, 2023 · 0 comments

keraJLi commented Mar 11, 2023

I would like to return a `MultivariateNormalDiag` distribution from a jitted function, but I'm getting a leaked-tracer error. I've created the following minimal example:

```python
import jax
import distrax
import jax.numpy as jnp

jax.config.update("jax_check_tracer_leaks", True)

jitted_cat = jax.jit(lambda: distrax.Categorical(logits=jnp.zeros(1)))
jitted_mvn = jax.jit(lambda: distrax.MultivariateNormalDiag(loc=jnp.zeros(1)))

print("Creating jitted categorical distribution")
jitted_cat()
print("Creating jitted multivariate normal distribution")
jitted_mvn()
```

which outputs

```
Creating jitted categorical distribution
Creating jitted multivariate normal distribution
Traceback (most recent call last):
  File "test_mvn.py", line 13, in <module>
    jitted_mvn()
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/api.py", line 440, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 513, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 965, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 923, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_jaxpr_dynamic
    with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
  File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/core.py", line 1083, in new_main
    if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/me/test_mvn.py", line 13, in <module>
    jitted_mvn()
  File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
```

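As the output shows, returning a `Categorical` works fine, but returning a `MultivariateNormalDiag` leaks tracers. The leaked tracers are held in the `_scale`, `_inv_scale` and `_log_scale` attributes of a `ScalarAffine` bijector, which is wrapped in a `Block` whose bound method is still referenced after tracing finishes. This looks like a bug. I'm using: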

- distrax==0.1.3
- jax==0.4.5
- jaxlib==0.4.4+cuda11.cudnn82