
How to import jax.numpy at jax.debug.breakpoint() #15117

Answered by sharadmv
HHalva asked this question in Q&A

Try adding the line import jax.numpy as jnp right before you run the breakpoint. I'd also suggest using numpy not jax.numpy if you're on an accelerator, since the breakpoint uses NumPy arrays, not JAX arrays.
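Below is a minimal sketch of that suggestion. It is not taken from the thread. The function f, its debug flag, and the input values are all hypothetical, and debug is made a static argument only so the function can also run without stopping. Inside the debug branch, jax.numpy is imported as jnp immediately before jax.debug.breakpoint(), so the name is bound in the frame the debugger stops in. Whether a name imported this way is visible at the debugger prompt may depend on your JAX version and debugger backend, so treat that as an assumption to check.

```python
import jax
import numpy as np


def f(x, debug=False):
    # Use the fully qualified jax.numpy here. A bare `jnp` would be a local
    # name (because of the import below) and would fail before that import.
    y = jax.numpy.sin(x) * 2.0
    if debug:
        # Import right before the breakpoint, as suggested in the answer.
        import jax.numpy as jnp  # noqa: F401  (intended for use at the prompt)
        # Execution pauses here. At the prompt you might try e.g. jnp.mean(y),
        # or np.mean(y), using host-side NumPy as the answer recommends.
        jax.debug.breakpoint()
    return y


# `debug` is static, so debug=False traces no breakpoint at all.
f_jit = jax.jit(f, static_argnames="debug")

out = f_jit(np.array([0.0, 1.0], dtype=np.float32))
# To stop in the debugger instead: f_jit(np.array([0.0, 1.0]), debug=True)
```

Making debug a static argument means it is evaluated at trace time. With debug=False, the breakpoint is never traced into the compiled function, so no debugger prompt can appear.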

1 reply
@HHalva
Answer selected by HHalva
Category
Q&A
2 participants