Skip to content

Commit

Permalink
Switch to jax.tree namespace for tree utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
schmrlng committed Oct 15, 2024
1 parent 2b3bc16 commit 0836732
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:

strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest]

steps:
Expand Down
2 changes: 1 addition & 1 deletion hj_reachability/artificial_dissipation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad
jnp.maximum(jnp.max(left_grad_values, grid_axes), jnp.max(right_grad_values, grid_axes)))
local_local_grad_value_boxes = sets.Box(jnp.minimum(left_grad_values, right_grad_values),
jnp.maximum(left_grad_values, right_grad_values))
local_grad_value_boxes = jax.tree_map(
local_grad_value_boxes = jax.tree.map(
lambda global_grad_value, local_local_grad_values:
(jnp.broadcast_to(global_grad_value, values.shape +
(values.ndim,) * 2).at[..., grid_axes, grid_axes].set(local_local_grad_values)),
Expand Down
6 changes: 3 additions & 3 deletions hj_reachability/finite_differences/upwind_first_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def compute_weno(v):

values = np.random.rand(1000)
spacing = 0.1
jax.tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),
jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),
upwind_first.WENO5(values, spacing, boundary_conditions.periodic),
_WENO5(values, spacing, boundary_conditions.periodic))

Expand Down Expand Up @@ -79,15 +79,15 @@ def _divided_difference(x, i, spacing=1):
values = np.random.rand(1000)
spacing = 0.1
for order in range(1, 5):
jax.tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),
jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),
upwind_first.essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic),
_brute_force_essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic))

def test_weighted_essentially_non_oscillatory_vectorized(self):
values = np.random.rand(1000)
spacing = 0.1
for eno_order in range(1, 5):
jax.tree_map(
jax.tree.map(
lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),
upwind_first.weighted_essentially_non_oscillatory(eno_order, values, spacing,
boundary_conditions.periodic),
Expand Down
2 changes: 1 addition & 1 deletion hj_reachability/sets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_ball(self):
ball = sets.Ball(np.ones(3), np.sqrt(3))
np.testing.assert_allclose(ball.extreme_point(np.array([1, -1, 1])), np.array([2, 0, 2]), atol=1e-6)
self.assertTrue(np.all(np.isfinite(ball.extreme_point(np.zeros(3)))))
jax.tree_map(np.testing.assert_allclose, ball.bounding_box,
jax.tree.map(np.testing.assert_allclose, ball.bounding_box,
sets.Box((1 - np.sqrt(3)) * np.ones(3), (1 + np.sqrt(3)) * np.ones(3)))
np.testing.assert_allclose(ball.max_magnitudes, (1 + np.sqrt(3)) * np.ones(3))
self.assertEqual(ball.ndim, 3)
Expand Down
7 changes: 3 additions & 4 deletions hj_reachability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ def get_axis_sequence(axis_array: np.ndarray) -> List:
return axis_list

multivmap_kwargs = {"in_axes": in_axes, "out_axes": in_axes if out_axes is None else out_axes}
axis_sequence_structure = jax.tree_util.tree_structure(
next(a for a in jax.tree_util.tree_leaves(in_axes) if a is not None).tolist())
vmap_kwargs = jax.tree_util.tree_transpose(jax.tree_util.tree_structure(multivmap_kwargs), axis_sequence_structure,
jax.tree_map(get_axis_sequence, multivmap_kwargs))
axis_sequence_structure = jax.tree.structure(next(a for a in jax.tree.leaves(in_axes) if a is not None).tolist())
vmap_kwargs = jax.tree.transpose(jax.tree.structure(multivmap_kwargs), axis_sequence_structure,
jax.tree.map(get_axis_sequence, multivmap_kwargs))
return functools.reduce(lambda f, kwargs: jax.vmap(f, **kwargs), vmap_kwargs, fun)


Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
flax>=0.6.6
jax>=0.4.2
numpy>=1.18.0
jax>=0.4.25
numpy>=1.22

0 comments on commit 0836732

Please sign in to comment.