diff --git a/xarray_jax/tests/test_types.py b/xarray_jax/tests/test_types.py index b88847c..ae17bb4 100644 --- a/xarray_jax/tests/test_types.py +++ b/xarray_jax/tests/test_types.py @@ -11,7 +11,6 @@ ) import equinox as eqx from xarray_jax.register_pytrees import var_change_on_unflatten -import jax.numpy as jnp jax.config.update("jax_enable_x64", True)