Implementing a neural network manually in JAX #16762
HosseinKhodavirdi asked this question in Q&A
Hi, it sounds like you're curious about what JAX tracers are and how to think about the execution model of JAX programs. I'd start with these docs:

To answer your questions more specifically: I wouldn't worry about the level number in "Level 9 tracer". When you see a tracer, it just means that your code is being transformed by a JAX transformation. JAX functions should work correctly with traced arrays, which is why further computations work as expected even though the inputs are tracers.

Hope that helps!
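A minimal sketch of the behavior described above. `jax.jit` and `jax.vmap` are used here only as example transformations. Printing inside the function shows a tracer instead of a concrete array while the function is being transformed, yet the results match the untransformed call. The exact text printed for a tracer varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Under a transformation, x is a tracer standing in for an array;
    # called directly, x is a concrete array.
    print("inside f:", x)
    return jnp.sin(x) * 2.0

x = jnp.arange(3.0)

eager = f(x)               # prints a concrete array
jitted = jax.jit(f)(x)     # prints a tracer (during tracing)
batched = jax.vmap(f)(x)   # prints a tracer for a single (unbatched) element

# The same operations run on tracers, so the results agree.
print(eager, jitted, batched)
```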
Hi, I am new to JAX and I am trying to solve a deep learning problem by implementing the network manually. For instance, this part computes the forward pass for each example individually (np is jax.numpy here):
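A per-example forward pass like this is commonly written for a single input and then batched with `jax.vmap`, which is also where tracers show up. A minimal sketch, assuming a small MLP with `tanh` hidden layers and a linear output; `init_params` and `predict` are hypothetical names, not the poster's code:

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    # Hypothetical helper: one (W, b) pair per layer, W scaled by 1/sqrt(fan_in).
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        W = jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

def predict(params, x):
    # Forward pass for a single example x of shape (n_in,).
    h = x
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W, b = params[-1]
    return W @ h + b  # linear output layer

# Map over the leading axis of the inputs only; params are shared (None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

key = jax.random.PRNGKey(0)
params = init_params(key, [4, 8, 3])
xs = jnp.ones((5, 4))          # a batch of 5 examples with 4 features each
ys = batched_predict(params, xs)
```

Inside `batched_predict`, `x` is a tracer of shape `(4,)`, so `predict` can be written as if it only ever saw one example.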
I have multiple questions and would appreciate any help:
and then to:
Can someone explain what all these things mean?
My last question is about using the grad operator in JAX. I have a function whose output may be complex in some epochs. I tried using grad to take the derivative of my function, but it raised an error saying that it does not work for complex values. Setting holomorphic=True does not work either, since the function does not always have a complex type. Does anyone have any suggestions?
Thanks much in advance,
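`jax.grad` expects a real-valued scalar output. One common workaround, assuming the quantity you actually want to minimize is real, is to differentiate a real scalar built from the possibly complex output, such as its squared magnitude (or `jnp.real(...)`). Complex intermediates are fine as long as the final output is real. A minimal sketch; `model` and `loss` are hypothetical:

```python
import jax
import jax.numpy as jnp

def model(w, x):
    # Hypothetical model whose output is complex-valued.
    return jnp.exp(1j * w * x)

def loss(w, x):
    # Reduce the complex output to a real scalar so jax.grad accepts it.
    # jnp.abs(...) ** 2 is real whether or not `out` is complex.
    out = model(w, x)
    return jnp.sum(jnp.abs(out - 1.0) ** 2)

w = 0.5
x = jnp.array([1.0, 2.0])
g = jax.grad(loss)(w, x)  # real-valued gradient with respect to the real w
```

Here `|exp(i*w*x) - 1|**2 = 2 - 2*cos(w*x)`, so the gradient is `sum(2*x*sin(w*x))`. `holomorphic=True` is intended for functions with complex inputs and complex outputs that are genuinely holomorphic, which is a different situation.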