From b490b8c48824e0c808edbbc15ed556ca0e05bdca Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 Dec 2024 11:15:08 -0800 Subject: [PATCH] Migrate from jax.core to jax.extend.core for several deprecated symbols A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core. PiperOrigin-RevId: 706771637 --- diffren/jax/internal/kernels/rasterize_triangles_xla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/diffren/jax/internal/kernels/rasterize_triangles_xla.py b/diffren/jax/internal/kernels/rasterize_triangles_xla.py index c83ff8e..590eae7 100644 --- a/diffren/jax/internal/kernels/rasterize_triangles_xla.py +++ b/diffren/jax/internal/kernels/rasterize_triangles_xla.py @@ -20,6 +20,7 @@ from . import rasterize_triangles_gpu from diffren.jax.internal.kernels import descriptors_pb2 import jax +import jax.extend as jex from jax.interpreters import batching from jax.interpreters import mlir import jax.numpy as jnp @@ -129,7 +130,7 @@ def rasterize_triangles_abstract_eval(vertices, triangles, *, image_width, (num_layers, image_height, image_width, 3), dtype=np.float32)) -rasterize_triangles_p = jax.core.Primitive('rasterize_triangles') +rasterize_triangles_p = jex.core.Primitive('rasterize_triangles') rasterize_triangles_p.multiple_results = True rasterize_triangles_p.def_impl(functools.partial( jax.interpreters.xla.apply_primitive, rasterize_triangles_p))