Replies: 1 comment
-
Can you try changing

```python
new_prefix, _ = lax.scan(
    update_prefix,
    prefix,
    jnp.arange(n),
)
```

to

```python
new_prefix = lax.fori_loop(
    lower=0,
    upper=n,
    body_fun=lambda offset, pref: update_prefix(pref, offset),
    init_val=prefix,
)
```

This may give some improvement. You also mentioned using …
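For concreteness, here is a minimal, self-contained sketch of that change, with a made-up `update_prefix` standing in for the real DP step. Note that a `fori_loop` body returns just the new carry, whereas a `scan` body returns a `(carry, output)` pair.

```python
import jax.numpy as jnp
from jax import lax

n = 512
prefix = jnp.zeros(n)

def update_prefix(prefix, offset):
    # Placeholder step: the real update comes from the DP being discussed.
    return prefix.at[offset].set(prefix[offset] + 1.0)

# scan formulation, as in the original snippet (here the per-step output is just None).
new_prefix_scan, _ = lax.scan(
    lambda pref, offset: (update_prefix(pref, offset), None),
    prefix,
    jnp.arange(n),
)

# fori_loop formulation: only the carry is threaded through the loop.
new_prefix_loop = lax.fori_loop(
    lower=0,
    upper=n,
    body_fun=lambda offset, pref: update_prefix(pref, offset),
    init_val=prefix,
)

assert jnp.allclose(new_prefix_scan, new_prefix_loop)
```

One caveat, as far as I know: if the trip count is not a static Python integer, `lax.fori_loop` lowers to `lax.while_loop`, which does not support reverse-mode differentiation, so where gradients have to flow through this loop the `scan` form may still be needed.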
-
Part of my machine-learning code that uses JAX contains some dynamic programming where large array updates occur along the diagonal of an N ✕ N matrix, by use of a `lax.scan`. The code which I want to 'mimic' was originally written in C++. However, because I wanted the result of this array to be differentiable, and I wanted to take advantage of hardware accelerators, I implemented the original code as a loop in JAX. The original code ran in $\mathcal{O}(N^3)$ time; I feel that in JAX, with sufficient amounts of parallelism, this could be done in $\mathcal{O}(N)$ time.

However, the `lax.scan` part is somewhat slower than expected: around 10x slower than the C++. With `str`'s, the gap becomes even larger, around 50x, despite using `vmap` in JAX.

Could the people here provide any insight into how to improve the performance of the following code? (The following code is somewhat simplified from the original.)
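To make the shape of the problem concrete, here is a rough, self-contained sketch of a diagonal-update scan of this general kind. It is illustrative only and not the code from the post: the recurrence, `forward`, `fill_diagonal`, and `scores` are all invented for the example.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp

NEG_INF = -1e30  # stand-in for log(0)

def forward(scores):
    """Fill an N x N DP table one anti-diagonal per scan step.

    Illustrative recurrence (not the original one):
        M[i, j] = scores[i, j] + logsumexp(M[i-1, j], M[i, j-1], M[i-1, j-1])
    """
    n = scores.shape[0]
    i_idx = jnp.arange(n)

    def fill_diagonal(M, k):
        # Cells on anti-diagonal k are (i, k - i); positions outside the matrix are masked.
        j_idx = k - i_idx
        valid = (j_idx >= 0) & (j_idx < n)
        j_safe = jnp.clip(j_idx, 0, n - 1)

        def predecessor(di, dj):
            ii, jj = i_idx + di, j_safe + dj
            ok = (ii >= 0) & (jj >= 0)
            return jnp.where(ok, M[jnp.clip(ii, 0, n - 1), jnp.clip(jj, 0, n - 1)], NEG_INF)

        prev = jnp.stack([predecessor(-1, 0), predecessor(0, -1), predecessor(-1, -1)])
        new_vals = scores[i_idx, j_safe] + logsumexp(prev, axis=0)
        new_vals = jnp.where((i_idx == 0) & (j_safe == 0), scores[0, 0], new_vals)  # base case
        # Write the whole diagonal at once; invalid lanes just rewrite their old value.
        M = M.at[i_idx, j_safe].set(jnp.where(valid, new_vals, M[i_idx, j_safe]))
        return M, None

    M0 = jnp.full((n, n), NEG_INF)
    M, _ = lax.scan(fill_diagonal, M0, jnp.arange(2 * n - 1))
    return M[n - 1, n - 1]
```

Even in this vectorized form there are 2N - 1 sequential scan steps, one per anti-diagonal; each step is data-parallel across the diagonal, which is where the hoped-for $\mathcal{O}(N)$ depth comes from.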
Here were some of my initial thoughts:

- Is `jsp.logsumexp` very slow?

However, I am not exactly sure how to profile and check these thoughts either, as I would need to have a control which does not implement the above ideas to test this, and making such a control is non-trivial. Any help would be appreciated!
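One way to check a hypothesis like the `logsumexp` one without building a full control implementation is to time two jitted variants that differ only in the reduction: the `logsumexp` version and a cheap stand-in such as `max` over the same shapes. The sketch below uses a toy scan as a stand-in for the real per-step work; only the timing pattern (warm-up, then `block_until_ready` inside the loop) is the point.

```python
import time
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp

def make_loop(combine):
    # N sequential scan steps, each doing one row-wise reduction; only `combine`
    # differs between the two variants, so any timing gap is attributable to it.
    @jax.jit
    def run(x):
        def step(carry, row):
            return combine(jnp.stack([carry, row]), axis=0), None
        out, _ = lax.scan(step, x[0], x)
        return out
    return run

def timed(fn, *args, repeats=20):
    fn(*args).block_until_ready()              # warm-up run: compilation happens here
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()           # block so device execution is measured
    return (time.perf_counter() - t0) / repeats

x = jax.random.normal(jax.random.PRNGKey(0), (2048, 2048))

lse_loop = make_loop(logsumexp)
max_loop = make_loop(lambda a, axis: jnp.max(a, axis=axis))   # cheap control

print("logsumexp version:", timed(lse_loop, x))
print("max control:      ", timed(max_loop, x))
```

For a per-op breakdown rather than end-to-end timing, wrapping a call in `jax.profiler.trace(...)` and inspecting the trace in TensorBoard or Perfetto also works.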