
AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft' #233

Open
umm-maybe opened this issue Jul 16, 2022 · 4 comments

Comments

@umm-maybe

Hello, I have followed the (very much appreciated) howto_finetune.md guide and, upon attempting to run the magic python device_train.py command, received the error noted above. The only Google search result that seems to mention something similar is this: https://bytemeta.vip/repo/deepmind/alphafold/issues/515

The answer to that question seems to imply it has to do with a version incompatibility between jax and jaxlib, but the solution they link to doesn't work here. Any tips or advice for working around this would be greatly appreciated!
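Before changing versions, it can help to confirm the failing condition directly. The sketch below uses a hypothetical helper, has_attribute. It imports the module named in the traceback and checks for the attribute the error says is missing. It assumes nothing about jaxlib beyond what the error message states, and it does not fix anything. It only tells you whether the installed jaxlib still lacks that attribute, for example after a reinstall.

```python
import importlib


def has_attribute(module_name: str, attr: str) -> bool:
    """Return True if module_name imports cleanly and defines attr.

    Returns False when the module itself cannot be imported.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # The attribute named in the traceback: jaxlib.pocketfft.pocketfft
    present = has_attribute("jaxlib.pocketfft", "pocketfft")
    print("jaxlib.pocketfft.pocketfft present:", present)
```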

@Ontopic

Ontopic commented Aug 3, 2022

Off the top of my head: pip install jax==0.2.12 jaxlib==0.1.67
I can't try it right now, but that version combination should work on a TPU VM.

Edit: I think it also depends on which Python version you're running (3.7 on TPU v2 and Colab, 3.8 on v3) and on the TPU version / accelerator type. I think I've seen jaxlib==0.1.68 in v2 setups, so that's also worth a shot.
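To check whether an environment actually ended up on one of the combinations suggested above, the following sketch reads the installed versions from package metadata and compares them with the pins. The CANDIDATE_PINS list is an assumption assembled from this comment. Pairing jaxlib 0.1.68 with jax 0.2.12 is a guess, not something confirmed in the thread. importlib.metadata needs Python 3.8+. On Python 3.7, which the comment mentions for TPU v2 and Colab, the importlib_metadata backport from PyPI provides the same functions.

```python
import sys
from typing import Dict, List, Optional

try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python 3.7: pip install importlib_metadata
    from importlib_metadata import PackageNotFoundError, version

# Assumed candidates: the first pair is the one suggested above; pairing
# jaxlib 0.1.68 with jax 0.2.12 is a guess based on the "v2 setups" remark.
CANDIDATE_PINS: List[Dict[str, str]] = [
    {"jax": "0.2.12", "jaxlib": "0.1.67"},
    {"jax": "0.2.12", "jaxlib": "0.1.68"},
]


def installed_version(dist: str) -> Optional[str]:
    """Version of an installed distribution, or None if it isn't installed."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def matching_pin(installed: Dict[str, Optional[str]],
                 candidates: List[Dict[str, str]] = CANDIDATE_PINS
                 ) -> Optional[Dict[str, str]]:
    """First candidate whose versions all match, ignoring local tags like '+cuda111'."""
    for pins in candidates:
        if all((installed.get(name) or "").split("+")[0] == want
               for name, want in pins.items()):
            return pins
    return None


if __name__ == "__main__":
    installed = {name: installed_version(name) for name in ("jax", "jaxlib")}
    print("Python", ".".join(str(part) for part in sys.version_info[:3]))
    print("Installed:", installed)
    print("Matches a suggested pin:", matching_pin(installed))
```

Printing the Python version alongside the package versions makes it easier to compare reports from TPU v2, v3 and Colab environments in a thread like this one.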

@dunstantom

I'm also using a TPU v2 setup and ran into this problem. I used the JAX TPU install instructions from their README and it worked for me.
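After reinstalling, you can count devices by platform to check that JAX actually sees the TPU rather than falling back to CPU. This sketch relies on jax.devices() and on each device's platform attribute. count_platforms is a hypothetical helper name. Errors are printed rather than raised, so the check still works as a diagnostic on a broken install.

```python
from collections import Counter
from typing import Iterable


def count_platforms(devices: Iterable) -> Counter:
    """Count devices by their platform attribute ('tpu', 'gpu', 'cpu', ...)."""
    return Counter(getattr(d, "platform", "unknown") for d in devices)


if __name__ == "__main__":
    try:
        import jax
        counts = count_platforms(jax.devices())
    except Exception as exc:  # broad on purpose: this is a diagnostic
        print("JAX check failed:", type(exc).__name__, exc)
    else:
        print("Devices by platform:", dict(counts))
        if counts["tpu"] == 0:  # Counter returns 0 for missing keys
            print("No TPU devices visible; JAX is using another backend.")
```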

@sxiii

sxiii commented Aug 18, 2022

Now I'm also getting either "AttributeError: module 'jax' has no attribute 'version'" or "AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'". I've tried a couple of different Colab notebooks, and none of them work.
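Both messages have the same shape. In each case a module imports fine but lacks a name that the calling code expects. That is consistent with the calling code having been written against a different jax/jaxlib release than the one installed. When you compare reports across notebooks, a small parser for CPython's "module ... has no attribute ..." message makes it easy to log which module and attribute are involved. missing_module_attribute is a hypothetical helper, not part of jax or this project.

```python
import re
from typing import Optional, Tuple

# CPython's wording for a missing module attribute, e.g.
#   module 'jax' has no attribute 'version'
# search() tolerates extra text before or after the message.
_MISSING_ATTR = re.compile(r"module '([^']+)' has no attribute '([^']+)'")


def missing_module_attribute(message: str) -> Optional[Tuple[str, str]]:
    """Return (module, attribute) from an AttributeError message, or None."""
    match = _MISSING_ATTR.search(message)
    return (match.group(1), match.group(2)) if match else None


if __name__ == "__main__":
    for msg in (
        "AttributeError: module 'jax' has no attribute 'version'",
        "AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'",
    ):
        print(missing_module_attribute(msg))
    # ('jax', 'version')
    # ('jaxlib.pocketfft', 'pocketfft')
```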

@musabgultekin

musabgultekin commented Oct 11, 2022

I fixed it by running this right after the "install dependencies" section:

!pip install jaxlib==0.1.67

Then restart the runtime if prompted.

It feels fragile, though, and I don't know why it works.
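The restart is probably part of why this feels fragile. pip replaces files on disk, but a module that was already imported stays loaded in the running interpreter until the process restarts, because it is cached in sys.modules. The notebook therefore keeps using the old jaxlib. The rough heuristic below, under the hypothetical name restart_needed, compares an already-imported module's __version__ with the installed package metadata. It assumes the module exposes __version__, which is not guaranteed for every package or every jaxlib release. When it can't tell, it returns None. It needs Python 3.8+ or the importlib_metadata backport.

```python
import sys
from typing import Optional

try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python 3.7: pip install importlib_metadata
    from importlib_metadata import PackageNotFoundError, version


def restart_needed(module_name: str, dist_name: str) -> Optional[bool]:
    """Heuristic check after a pip install inside a running notebook.

    False: not imported in this process yet (the next import reads the new
           files), or the loaded __version__ already matches the installed one.
    True:  the loaded module's __version__ differs from the installed package.
    None:  can't tell (no __version__ attribute, or package not installed).
    """
    module = sys.modules.get(module_name)
    if module is None:
        return False
    loaded = getattr(module, "__version__", None)
    try:
        installed = version(dist_name)
    except PackageNotFoundError:
        return None
    if loaded is None:
        return None
    return loaded != installed


if __name__ == "__main__":
    # Intended to run in a notebook cell after "!pip install ...".
    for name in ("jax", "jaxlib"):
        print(name, "restart needed:", restart_needed(name, name))
```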


5 participants