Skip to content

Commit

Permalink
Migrate from jax.core to jax.extend.core for several deprecated symbols
Browse files Browse the repository at this point in the history
A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core.

PiperOrigin-RevId: 706771637
  • Loading branch information
Jake VanderPlas authored and The diffren Authors committed Dec 16, 2024
1 parent 7f1dae7 commit b490b8c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion diffren/jax/internal/kernels/rasterize_triangles_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from . import rasterize_triangles_gpu
from diffren.jax.internal.kernels import descriptors_pb2
import jax
import jax.extend as jex
from jax.interpreters import batching
from jax.interpreters import mlir
import jax.numpy as jnp
Expand Down Expand Up @@ -129,7 +130,7 @@ def rasterize_triangles_abstract_eval(vertices, triangles, *, image_width,
(num_layers, image_height, image_width, 3), dtype=np.float32))


rasterize_triangles_p = jax.core.Primitive('rasterize_triangles')
rasterize_triangles_p = jex.core.Primitive('rasterize_triangles')
rasterize_triangles_p.multiple_results = True
rasterize_triangles_p.def_impl(functools.partial(
jax.interpreters.xla.apply_primitive, rasterize_triangles_p))
Expand Down

0 comments on commit b490b8c

Please sign in to comment.