Bug: treescope fails to render typed PRNG key arrays #52

Open
amifalk opened this issue Jan 3, 2025 · 0 comments

With jax 0.4.37 and treescope 0.1.7:

import jax
import treescope as ts

ts.basic_interactive_setup()

jax.random.key(0)
<TypeError during deferred rendering
Traceback (most recent call last):
  File ".../.venv/lib/python3.10/site-packages/treescope/lowering.py", line 358, in _render_to_html_as_root_streaming
    replacement_part = deferred.thunk(layout_decision)
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 606, in _thunk
    summarized = adapter.get_array_summary(node, fast=False)
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 525, in get_array_summary
    output_parts.append(summarize_array_data(array))
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 479, in summarize_array_data
    output_parts.extend(_summarize_array_data_unconditionally(array))
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 433, in _summarize_array_data_unconditionally
    stat = compute_summary(array, is_floating, is_integer, is_bool)
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 391, in _compute_summary
    x = xnp.array(x)
  File ".../.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 283, in __array__
    raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array."
TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.
>
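The error message itself points at `jax.random.key_data`, so a possible workaround until typed keys render correctly is to unwrap the key before handing it to anything that converts to NumPy. The sketch below is only an illustration of that idea, not treescope's actual fix: `to_displayable` is a hypothetical helper name, and it assumes that checking the dtype against `jax.dtypes.prng_key` is enough to detect typed key arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_displayable(x):
    """Hypothetical helper: unwrap typed PRNG keys so NumPy can convert them."""
    dtype = getattr(x, "dtype", None)
    # Typed key arrays carry an extended dtype that is a subdtype of prng_key.
    if dtype is not None and jnp.issubdtype(dtype, jax.dtypes.prng_key):
        # key_data returns the underlying unsigned-integer array.
        return jax.random.key_data(x)
    return x


key = jax.random.key(0)
raw = np.asarray(to_displayable(key))  # no TypeError, unlike np.asarray(key)
print(raw.dtype, raw.shape)
```

Displaying `to_displayable(key)` instead of `key` shows the raw key data, at the cost of losing the information that the value is a typed key.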