
use of core.unsafe_get_axis_names_DO_NOT_USE which no longer exists #285

Open
svandenhaute opened this issue Oct 31, 2024 · 1 comment

Comments

@svandenhaute

Using:

jax                    0.4.31.dev20241012+eafb38ca8
jax-rocm60-pjrt        0.4.31
jax-rocm60-plugin      0.4.31
jaxlib                 0.4.31
jaxtyping              0.2.34
kfac-jax               0.0.6

and while trying to use ferminet, I got the following error:

  File "/*/lib/python3.11/site-packages/ferminet/pretrain.py", line 245, in loss_fn
    return constants.pmean(result)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 58, in pmean_if_pmap
    return lax.pmean(obj, axis_name) if in_pmap(axis_name) else obj
                                        ^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 38, in in_pmap
    return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/jax/_src/deprecations.py", line 55, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'unsafe_get_axis_names_DO_NOT_USE'

It appears that this function was deprecated and has since been removed from jax.core. Is kfac-jax only compatible with JAX up to a certain version? If so, that upper bound should probably be declared in the package's dependency metadata.
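For context, the traceback shows that kfac-jax's `in_pmap` only needs to know whether `axis_name` is currently bound, so that `pmean_if_pmap` can call `lax.pmean` or return the value unchanged. A minimal sketch of one possible workaround that avoids the removed private helper is below. It is an assumption, not the maintainers' fix: it relies on `lax.axis_index` raising `NameError` for an unbound axis name. The function names mirror those in the traceback, but these bodies are hypothetical reimplementations, and patching them into an installed kfac-jax is not shown.

```python
from typing import Any, Optional

import jax
from jax import lax


def in_pmap(axis_name: Optional[str]) -> bool:
    """Return True if `axis_name` is bound by an enclosing mapped transform.

    Hypothetical replacement: instead of listing axis names via the removed
    private helper, ask JAX for the axis index and treat the NameError raised
    for an unbound axis name as "not mapped". Note that a named `jax.vmap`
    axis also counts as bound here, not only a `jax.pmap` axis.
    """
    if axis_name is None:
        return False
    try:
        lax.axis_index(axis_name)  # result is discarded; only the error matters
        return True
    except NameError:
        return False


def pmean_if_pmap(obj: Any, axis_name: Optional[str]) -> Any:
    """Average `obj` over `axis_name` if that axis is bound, else return it."""
    return lax.pmean(obj, axis_name) if in_pmap(axis_name) else obj


if __name__ == "__main__":
    import jax.numpy as jnp

    x = jnp.arange(4.0)
    # Outside any mapped transform the input comes back unchanged.
    print(pmean_if_pmap(x, "i"))
    # Under a named vmap axis the values are averaged: every element is 1.5.
    print(jax.vmap(lambda v: pmean_if_pmap(v, "i"), axis_name="i")(x))
```

Catching `NameError` keeps the check independent of JAX internals, at the cost of briefly tracing an `axis_index` call whose result is thrown away.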

@Eureka10shen

Eureka10shen commented Nov 5, 2024

I ran into the same problem. Has anyone solved it?

Versions of my packages:

# Name                    Version                   Build  Channel
jax                       0.4.26                   pypi_0    pypi
jax-metal                 0.0.7                    pypi_0    pypi
jaxlib                    0.4.26                   pypi_0    pypi
jaxtyping                 0.2.28                   pypi_0    pypi
kfac-jax                  0.0.6                    pypi_0    pypi
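To compare environments like the two listed in this thread, a small diagnostic sketch can report the installed versions and whether the installed `jax.core` still exposes the helper that kfac-jax calls. The function name `report_axis_name_helper` is hypothetical; it only uses `importlib.metadata` and `hasattr`, and it does not assume which JAX release added or removed the helper.

```python
import importlib.metadata

import jax.core


def report_axis_name_helper() -> dict:
    """Hypothetical diagnostic: installed versions plus helper availability."""
    info = {}
    for pkg in ("jax", "jaxlib", "kfac-jax"):
        try:
            info[pkg] = importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            info[pkg] = None  # package not installed in this environment
    # hasattr() returns False when the lookup raises AttributeError, which is
    # what the traceback in this issue shows for the ROCm JAX build.
    info["has_unsafe_get_axis_names"] = hasattr(
        jax.core, "unsafe_get_axis_names_DO_NOT_USE")
    return info


if __name__ == "__main__":
    for key, value in report_axis_name_helper().items():
        print(f"{key:28} {value}")
```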
