I was trying to use ferminet but got the following error:
```
  File "/*/lib/python3.11/site-packages/ferminet/pretrain.py", line 245, in loss_fn
    return constants.pmean(result)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 58, in pmean_if_pmap
    return lax.pmean(obj, axis_name) if in_pmap(axis_name) else obj
                                        ^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 38, in in_pmap
    return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/jax/_src/deprecations.py", line 55, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'unsafe_get_axis_names_DO_NOT_USE'
```
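The last frames show the mechanism: `pmean_if_pmap` only calls `lax.pmean` when `in_pmap` reports that the axis name is bound, and `in_pmap` decides that by reading the private `jax.core.unsafe_get_axis_names_DO_NOT_USE()`, which the installed JAX no longer provides. Below is a minimal sketch of the same check using only public API. It is a hypothetical local re-implementation, not kfac-jax's actual code or fix, and it assumes that `lax.axis_index` raises `NameError` when the axis name is not bound by an enclosing transformation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def in_pmap(axis_name):
    """Hypothetical stand-in for kfac_jax's helper of the same name.

    Returns True if `axis_name` is bound by an enclosing pmap/vmap. Instead of
    reading the removed private axis-name list, it probes the public
    `lax.axis_index`, assumed to raise NameError for an unbound axis name.
    """
    if axis_name is None:
        return False
    try:
        lax.axis_index(axis_name)
    except NameError:
        return False
    return True


def pmean_if_pmap(obj, axis_name):
    """Average `obj` (any pytree) over `axis_name` only when that axis is bound."""
    return lax.pmean(obj, axis_name) if in_pmap(axis_name) else obj


if __name__ == "__main__":
    x = jnp.arange(4.0)
    # Outside any named axis, the input is returned unchanged.
    print(pmean_if_pmap(x, "batch"))
    # Under a named vmap axis, every element becomes the mean over that axis.
    print(jax.vmap(lambda v: pmean_if_pmap(v, "batch"), axis_name="batch")(x))
```

The usage uses a named `vmap` axis rather than `pmap` so it runs on a single device; collectives such as `lax.pmean` accept either kind of named axis.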
It appears that this method was deprecated and has since been removed from `jax.core`. Is kfac-jax only compatible with JAX up to a certain version? If so, that upper bound should probably be specified in the package config.
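Pinning down a supported range starts from the exact versions in use. A minimal standard-library sketch for collecting them is shown here; the distribution names in `PACKAGES` are assumptions and may need adjusting for how each package was installed.

```python
from importlib import metadata

# Assumed distribution names to inspect; adjust if a package was installed
# under a different name.
PACKAGES = ("jax", "jaxlib", "kfac-jax", "ferminet")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```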