I would like to return a MultivariateNormalDiag distribution from a jitted function. However, I'm getting a leaked tracer. I've created a minimal example, which outputs:
Creating jitted categorical distribution
Creating jitted multivariate normal distribution
Traceback (most recent call last):
File "test_mvn.py", line 13, in <module>
jitted_mvn()
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 235, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/api.py", line 440, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 513, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 965, in _pjit_jaxpr
jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 923, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_jaxpr_dynamic
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
next(self.gen)
File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/core.py", line 1083, in new_main
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/me/test_mvn.py", line 13, in <module>
jitted_mvn()
File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
As seen from the output, returning a Categorical works fine, but returning a MultivariateNormalDiag results in leaked tracers. This seems like a bug. I'm using
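One possible workaround, sketched here under assumptions rather than confirmed as the library's intended usage: the reference chain in the traceback shows the tracers being held by attributes of the returned object (`ScalarAffine._scale`, `_inv_scale`, `_log_scale`), so a jitted function could return only the distribution parameters as a pytree of arrays and leave the construction of the distribution object to code outside `jit`. The names `DiagNormalParams`, `make_params`, and `log_prob` below are hypothetical, and the log-density is written out by hand so the sketch depends on JAX alone.

```python
import math
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DiagNormalParams(NamedTuple):
    """Hypothetical container holding only arrays, so it is a plain pytree."""
    loc: jnp.ndarray
    scale_diag: jnp.ndarray


@jax.jit
def make_params(loc, log_scale):
    # Only arrays cross the jit boundary; no distribution or bijector
    # object is created while tracing.
    return DiagNormalParams(loc=loc, scale_diag=jnp.exp(log_scale))


def log_prob(params, x):
    # Log-density of a diagonal-covariance normal, summed over the event axis.
    z = (x - params.loc) / params.scale_diag
    per_dim = -0.5 * z**2 - jnp.log(params.scale_diag) - 0.5 * math.log(2 * math.pi)
    return jnp.sum(per_dim, axis=-1)


params = make_params(jnp.zeros(2), jnp.zeros(2))
value = log_prob(params, jnp.zeros(2))
```

Outside the jitted function, `params.loc` and `params.scale_diag` are concrete arrays, so they could then be passed to a distribution constructor, assuming the library's `MultivariateNormalDiag` accepts `loc` and `scale_diag` arguments. This does not explain why returning a Categorical works while MultivariateNormalDiag does not; it only avoids returning the object that holds the tracers.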