From 0836732847b0f3a7a92752f278f2510a7cefb49c Mon Sep 17 00:00:00 2001 From: Ed Schmerling Date: Tue, 15 Oct 2024 11:43:30 -0700 Subject: [PATCH] Switch to `jax.tree` namespace for tree utilities --- .github/workflows/ci.yml | 2 +- hj_reachability/artificial_dissipation.py | 2 +- hj_reachability/finite_differences/upwind_first_test.py | 6 +++--- hj_reachability/sets_test.py | 2 +- hj_reachability/utils.py | 7 +++---- requirements.txt | 4 ++-- 6 files changed, 11 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4608970..9d8f4fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] os: [ubuntu-latest] steps: diff --git a/hj_reachability/artificial_dissipation.py b/hj_reachability/artificial_dissipation.py index c7aa188..aba95b5 100644 --- a/hj_reachability/artificial_dissipation.py +++ b/hj_reachability/artificial_dissipation.py @@ -23,7 +23,7 @@ def local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad jnp.maximum(jnp.max(left_grad_values, grid_axes), jnp.max(right_grad_values, grid_axes))) local_local_grad_value_boxes = sets.Box(jnp.minimum(left_grad_values, right_grad_values), jnp.maximum(left_grad_values, right_grad_values)) - local_grad_value_boxes = jax.tree_map( + local_grad_value_boxes = jax.tree.map( lambda global_grad_value, local_local_grad_values: (jnp.broadcast_to(global_grad_value, values.shape + (values.ndim,) * 2).at[..., grid_axes, grid_axes].set(local_local_grad_values)), diff --git a/hj_reachability/finite_differences/upwind_first_test.py b/hj_reachability/finite_differences/upwind_first_test.py index 3860841..0c5adb2 100644 --- a/hj_reachability/finite_differences/upwind_first_test.py +++ b/hj_reachability/finite_differences/upwind_first_test.py @@ -37,7 +37,7 @@ def compute_weno(v): values = np.random.rand(1000) spacing = 0.1 - jax.tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), + jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), upwind_first.WENO5(values, spacing, boundary_conditions.periodic), _WENO5(values, spacing, boundary_conditions.periodic)) @@ -79,7 +79,7 @@ def _divided_difference(x, i, spacing=1): values = np.random.rand(1000) spacing = 0.1 for order in range(1, 5): - jax.tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), + jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), upwind_first.essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic), _brute_force_essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic)) @@ -87,7 +87,7 @@ def test_weighted_essentially_non_oscillatory_vectorized(self): values = np.random.rand(1000) spacing = 0.1 for eno_order in range(1, 5): - jax.tree_map( + jax.tree.map( lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), upwind_first.weighted_essentially_non_oscillatory(eno_order, values, spacing, boundary_conditions.periodic), diff --git a/hj_reachability/sets_test.py b/hj_reachability/sets_test.py index ef6be9a..b979d5d 100644 --- a/hj_reachability/sets_test.py +++ b/hj_reachability/sets_test.py @@ -22,7 +22,7 @@ def test_ball(self): ball = sets.Ball(np.ones(3), np.sqrt(3)) np.testing.assert_allclose(ball.extreme_point(np.array([1, -1, 1])), np.array([2, 0, 2]), atol=1e-6) self.assertTrue(np.all(np.isfinite(ball.extreme_point(np.zeros(3))))) - jax.tree_map(np.testing.assert_allclose, ball.bounding_box, + jax.tree.map(np.testing.assert_allclose, ball.bounding_box, sets.Box((1 - np.sqrt(3)) * np.ones(3), (1 + np.sqrt(3)) * np.ones(3))) np.testing.assert_allclose(ball.max_magnitudes, (1 + np.sqrt(3)) * np.ones(3)) self.assertEqual(ball.ndim, 3) diff --git a/hj_reachability/utils.py b/hj_reachability/utils.py index 11631dc..02a02db 100644 --- a/hj_reachability/utils.py +++ b/hj_reachability/utils.py @@ -52,10 +52,9 @@ def get_axis_sequence(axis_array: np.ndarray) -> List: return axis_list multivmap_kwargs = {"in_axes": in_axes, "out_axes": in_axes if out_axes is None else out_axes} - axis_sequence_structure = jax.tree_util.tree_structure( - next(a for a in jax.tree_util.tree_leaves(in_axes) if a is not None).tolist()) - vmap_kwargs = jax.tree_util.tree_transpose(jax.tree_util.tree_structure(multivmap_kwargs), axis_sequence_structure, - jax.tree_map(get_axis_sequence, multivmap_kwargs)) + axis_sequence_structure = jax.tree.structure(next(a for a in jax.tree.leaves(in_axes) if a is not None).tolist()) + vmap_kwargs = jax.tree.transpose(jax.tree.structure(multivmap_kwargs), axis_sequence_structure, + jax.tree.map(get_axis_sequence, multivmap_kwargs)) return functools.reduce(lambda f, kwargs: jax.vmap(f, **kwargs), vmap_kwargs, fun) diff --git a/requirements.txt b/requirements.txt index 3834ede..e239762 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ flax>=0.6.6 -jax>=0.4.2 -numpy>=1.18.0 +jax>=0.4.25 +numpy>=1.22