Skip to content

Commit 2fa0623

Browse files
lukebaumanncopybara-github
authored andcommitted
Fix the JAX version required for jaxlib._pathways.
PiperOrigin-RevId: 811518044
1 parent 03c26b5 commit 2fa0623

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pathwaysutils/jax/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,29 +62,29 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable
6262
del util
6363

6464
try:
65-
# jax>0.7.0
65+
# jax>=0.7.1
6666
from jax.extend import backend # pylint: disable=g-import-not-at-top
6767

6868
ifrt_proxy = backend.ifrt_proxy
6969
del backend
7070
except AttributeError:
71-
# jax<=0.7.0
71+
# jax<0.7.1
7272
from jax.lib import xla_extension # pylint: disable=g-import-not-at-top
7373

7474
ifrt_proxy = xla_extension.ifrt_proxy
7575
del xla_extension
7676

7777

7878
try:
79-
# jax>=0.7.2
79+
# jax>=0.8.0
8080
from jax.jaxlib import _pathways # pylint: disable=g-import-not-at-top
8181

8282
jaxlib_pathways = _pathways
8383
del _pathways
84-
except (ModuleNotFoundError, AttributeError):
85-
# jax<0.7.2
84+
except ModuleNotFoundError:
85+
# jax<0.8.0
8686

87-
jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.7.2")
87+
jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0")
8888

8989

9090
del _FakeJaxModule

0 commit comments

Comments
 (0)