How should I do segment_argmax in JAX? #14784

Answered by jakevdp
ameya98 asked this question in Q&A

Here's a basic implementation you might use. It builds an intermediate array of shape (num_segments, len(data)), so it's not particularly efficient if you have many segments:

import jax
import numpy as np
import jax.numpy as jnp

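def segment_argmax(data, segment_ids, num_segments=None):
  # If not given, infer the segment count from concrete (non-traced) ids.
  if num_segments is None:
    num_segments = np.max(segment_ids) + 1
  num_segments = int(num_segments)
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  # For each segment i, replace entries from other segments with -inf,
  # then take the argmax along the data axis. Indices refer to `data`.
  return jax.vmap(lambda i: jnp.where(i == segment_ids, data, -jnp.inf))(
      jnp.arange(num_segments)).argmax(1)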

data = jnp.array(np.random.randint(0, 100, 10))
segments = jnp.array([0, 1, 2, 0, 2, 2, 1, 2, 0, 1])

print("data:", data)
print("segments:", segments)
print(jax.ops.segment_max(data, s…
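
If the number of segments is large, one possible alternative avoids the (num_segments, len(data)) intermediate by combining `jax.ops.segment_max` with `jax.ops.segment_min`. This is only a sketch and not part of the original answer: the name `segment_argmax_scatter` is made up for illustration, and it assumes 1-D data with no NaNs and no empty segments (an empty segment would come back as a large sentinel value rather than 0).

```python
import jax
import jax.numpy as jnp

def segment_argmax_scatter(data, segment_ids, num_segments):
  # Hypothetical helper, not from the original answer.
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  n = data.shape[0]
  # Per-segment maximum, shape (num_segments,).
  seg_max = jax.ops.segment_max(data, segment_ids, num_segments=num_segments)
  # True where an element equals the maximum of its own segment.
  is_max = data == seg_max[segment_ids]
  # Keep the index of each maximal element; use n as an out-of-range filler.
  candidates = jnp.where(is_max, jnp.arange(n), n)
  # The smallest surviving index per segment is the first argmax.
  return jax.ops.segment_min(candidates, segment_ids, num_segments=num_segments)

data = jnp.array([5, 1, 9, 7, 3, 9, 8, 2, 0, 4])
segments = jnp.array([0, 1, 2, 0, 2, 2, 1, 2, 0, 1])
print(segment_argmax_scatter(data, segments, num_segments=3))  # [3 6 2]
```

Like the vmap version, ties resolve to the first occurrence within a segment. Time and memory scale with len(data) rather than num_segments * len(data), at the cost of two scatter operations.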

Answer selected by ameya98