
Typo in 'to_hf_weights.py' #231

Open
AmoArt opened this issue Jul 8, 2022 · 1 comment

Comments


AmoArt commented Jul 8, 2022

On line 461, 'with maps.mesh(devices, ("dp", "mp")):' should be written as 'with maps.Mesh(devices, ("dp", "mp")):'. Otherwise it raises an error saying that jax.experimental.maps has no attribute called mesh.

vfbd (Contributor) commented Jul 8, 2022

jax.experimental.maps does have "mesh" as long as you have jax<=0.3.7:

❯ python3.9
Python 3.9.13 (main, Jun  8 2022, 09:45:57) 
[GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jaxlib
>>> jax.__version__
'0.2.12'
>>> jaxlib.__version__
'0.1.68'
>>> jax.experimental.maps.mesh
<function mesh at 0x7f6346581940>
>>> from jax.experimental import maps
>>> maps.mesh
<function mesh at 0x7f6346581940>
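
For a script that has to run on both sides of that version boundary, one option is to look the attribute up by name instead of hard-coding either spelling. The sketch below is only an illustration: `resolve_mesh` is a hypothetical helper, not part of JAX or of this repository, and it assumes (based on the two reports in this thread) that older releases expose `mesh` while newer ones expose only `Mesh`. It uses stand-in objects so that it runs without JAX installed.

```python
from types import SimpleNamespace


def resolve_mesh(maps_module):
    """Return whichever mesh factory the given maps-like module provides.

    Hypothetical helper. The lowercase name is tried first on the assumption
    that, where it exists, it is the one meant to be used in a `with` block.
    """
    for name in ("mesh", "Mesh"):
        factory = getattr(maps_module, name, None)
        if factory is not None:
            return factory
    raise AttributeError("module provides neither 'mesh' nor 'Mesh'")


# Stand-ins for jax.experimental.maps on an older and a newer install.
old_maps = SimpleNamespace(mesh=lambda devices, axes: ("mesh", axes))
new_maps = SimpleNamespace(Mesh=lambda devices, axes: ("Mesh", axes))

for fake_maps in (old_maps, new_maps):
    factory = resolve_mesh(fake_maps)
    print(factory([], ("dp", "mp")))
```

In a real script the call site would then read `with resolve_mesh(maps)(devices, ("dp", "mp")):`, with `maps` imported from `jax.experimental` as in the session above.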
