diff --git a/acme/setup.py b/acme/setup.py index 2f170f1f..291ee31c 100755 --- a/acme/setup.py +++ b/acme/setup.py @@ -58,9 +58,9 @@ jax_requirements = [ 'chex', - 'jax==0.3.6', # Update when TF2.9 is release. - 'jaxlib==0.3.15', # Update when TF2.9 is release. - 'dm-haiku', + 'jax==0.4.1', + 'jaxlib==0.4.1', + 'dm-haiku==0.0.10', 'flax', 'optax', 'rlax',