Skip to content

Commit

Permalink
Explicitly convert jax.numpy.meshgrid outputs to list
Browse files Browse the repository at this point in the history
The return type of several jax.numpy APIs will change from list to tuple in an upcoming JAX version, following a similar change in NumPy 2.0.

PiperOrigin-RevId: 602334680
  • Loading branch information
Jake VanderPlas authored and PIXDev committed Jan 29, 2024
1 parent cc4aa55 commit 8071089
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions dm_pix/_src/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,9 @@ def elastic_deformation(
sigma=sigma,
kernel_size=kernel_size) * alpha

meshgrid = jnp.meshgrid(*[jnp.arange(size) for size in single_channel_shape],
indexing="ij")
meshgrid = list(
jnp.meshgrid(
*[jnp.arange(size) for size in single_channel_shape], indexing="ij"))
meshgrid[0] += shift_map_i
meshgrid[1] += shift_map_j

Expand Down

0 comments on commit 8071089

Please sign in to comment.