I need to find the index of the largest value in each segment of an array; basically, an argmax version of jax.ops.segment_max. How would one implement this?
Answered by jakevdp, Mar 8, 2023
Very good question. I don't think there's a good answer at the moment.
Here's a basic implementation you might use, though if you have many segments it's not particularly efficient:
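One possible approach, sketched here and not necessarily the implementation referred to above, avoids looping over segments: compute each segment's maximum with jax.ops.segment_max, mark the elements that equal their own segment's maximum, then take jax.ops.segment_min over the indices of those marked elements. The function name segment_argmax is hypothetical (not a JAX API). The sketch assumes 1-D data without NaNs, breaks ties toward the smallest index (as jnp.argmax does), and returns the sentinel len(data) for empty segments.

```python
import jax
import jax.numpy as jnp


def segment_argmax(data, segment_ids, num_segments):
    """Hypothetical helper (not a JAX API): per-segment argmax for 1-D data.

    Ties resolve to the smallest index; empty segments return data.shape[0].
    Assumes data contains no NaNs.
    """
    n = data.shape[0]
    # Maximum of each segment, shape (num_segments,).
    seg_max = jax.ops.segment_max(data, segment_ids, num_segments=num_segments)
    # True where an element equals the maximum of its own segment.
    is_max = data == seg_max[segment_ids]
    # Keep the positions of maximal elements; everything else gets sentinel n.
    candidates = jnp.where(is_max, jnp.arange(n), n)
    # Smallest such position per segment. Empty segments receive the integer
    # identity of min (the dtype's max value), so clamp them to the sentinel n.
    idx = jax.ops.segment_min(candidates, segment_ids, num_segments=num_segments)
    return jnp.minimum(idx, n)


data = jnp.array([1.0, 5.0, 3.0, 5.0, 2.0, 7.0, 0.0])
segment_ids = jnp.array([0, 0, 0, 1, 1, 2, 2])
print(segment_argmax(data, segment_ids, num_segments=3))  # [1 3 5]
```

The cost is two segment reductions plus a gather, independent of the number of segments. Under jax.jit, num_segments has to be static, e.g. jax.jit(segment_argmax, static_argnums=2).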