Jax dependency Issue #579
Not sure how we can fix this. As far as I understand, we should pin only jax and not jaxlib, but that doesn't prevent pip from installing a newer jaxlib. Installing through poetry instead of pip should solve most issues, since it installs the versions specified in the lock file.
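Because a pin on jax alone does not constrain jaxlib, it can help to check the installed pair directly. The sketch below reads both versions with `importlib.metadata` and flags the case where jaxlib is newer than jax. Treating that as a mismatch is an assumption based on the comment above, not jax's exact compatibility rule, and the helper names (`parse_version`, `installed_version`, `jaxlib_newer_than_jax`) are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. "0.4.25" -> (0, 4, 25).

    Simplified: suffixes such as "rc1" or ".dev2024" are ignored, and short
    versions are padded with zeros so "0.4" compares equal to "0.4.0".
    """
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    return parts + (0,) * (3 - len(parts))


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def jaxlib_newer_than_jax(jax_v: str, jaxlib_v: str) -> bool:
    """Assumed rule: a jaxlib release newer than the installed jax is a mismatch."""
    return parse_version(jaxlib_v) > parse_version(jax_v)


if __name__ == "__main__":
    jax_v = installed_version("jax")
    jaxlib_v = installed_version("jaxlib")
    if jax_v is None or jaxlib_v is None:
        print("jax and/or jaxlib is not installed")
    elif jaxlib_newer_than_jax(jax_v, jaxlib_v):
        print(f"jaxlib {jaxlib_v} is newer than jax {jax_v}; pin both versions")
    else:
        print(f"jax {jax_v} / jaxlib {jaxlib_v} look consistent")
```

Installing from a lock file sidesteps the problem entirely, because both packages are then fixed to known versions; the check above is only a diagnostic for environments such as Colab where that is not possible.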
Never mind, this was already merged. However, we pin dm-haiku to 0.0.10 because I had some problem with 0.0.11, though I don't remember what it was.
I pinned dm-haiku 0.0.12; this should hopefully avoid any issues once Google decides to upgrade jax within Colab. Thanks!
I think we can update jax to the latest version and change the code to use linear_util from the jax.extend module (see here). I can make a PR for this, and maybe fix the jax installs on Colab as well.
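One way to follow this suggestion while still supporting older jax releases is to try the new module location first and fall back to the old one. This is a minimal, hypothetical helper (`import_first` is not part of jax or haiku); the module paths `jax.extend.linear_util` and `jax.linear_util` are assumptions taken from this thread, so check them against the jax version you actually install.

```python
import importlib
from types import ModuleType


def import_first(*candidates: str) -> ModuleType:
    """Return the first module in `candidates` that imports successfully.

    Raises ImportError listing every failure if none of them can be imported.
    """
    errors = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # also covers ModuleNotFoundError
            errors.append(f"{name}: {exc}")
    raise ImportError("no candidate module could be imported:\n" + "\n".join(errors))


# Hypothetical usage for this issue (assumed module paths, see above):
# lu = import_first("jax.extend.linear_util", "jax.linear_util")

# Self-contained demo with stdlib modules: the first name does not exist,
# so the helper falls back to `json`.
mod = import_first("no_such_module_for_demo", "json")
print(mod.__name__)
```

Listing the new location first means the fallback is only reached on older jax releases, so the code keeps working on both sides of the removal.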
dm-haiku 0.0.12 already does this. It should be fine now after updating the dependency.
I think this is the same error, or a related one. I returned to a ColabFold v1.5.5: AlphaFold2 w/ MMseqs2 BATCH session that was working yesterday, and now I am getting the error below. Is there a fix for this? I may just not understand what pinning dm-haiku 0.0.12 means.
Are you using this notebook? The error message suggests that you are using an old version of it.
I tried running it from that notebook, with no change. I did notice that I only seem to get the bug when using a TPU; GPU seems fine.
I pushed a fix for TPU; it should work again.
Thanks! Works great.
Since this has been fixed by the upgrade to dm-haiku 0.0.12, can I close this issue?
I still need to make a new pip release; I updated the conda package a few days ago.
JAX removed linear_util as of v0.4.25 (the latest release), so importing haiku with the latest jax crashes. Because ColabFold's pyproject.toml accepts any jax version from 0.4.20 upward, this breaks several ColabFold methods.
Please let me know if this makes sense or if you need more information.
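The version ranges described in this report can be checked with a small sketch. The two thresholds (linear_util removed in 0.4.25, ColabFold accepting jax >= 0.4.20) come from this issue; the helper names are hypothetical, and "old haiku" here means the dm-haiku releases before 0.0.12 that, per this thread, still import linear_util from its old location.

```python
import re
from typing import Tuple

# Thresholds as reported in this issue (not taken from the jax changelog).
LINEAR_UTIL_REMOVED_IN = (0, 4, 25)
COLABFOLD_MIN_JAX = (0, 4, 20)


def parse_version(v: str) -> Tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. "0.4.25" -> (0, 4, 25).

    Simplified: pre-release/dev suffixes are ignored, short versions are
    padded with zeros.
    """
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    return parts + (0,) * (3 - len(parts))


def allowed_by_colabfold(jax_version: str) -> bool:
    """True if the open-ended constraint jax >= 0.4.20 accepts this version."""
    return parse_version(jax_version) >= COLABFOLD_MIN_JAX


def breaks_old_haiku(jax_version: str) -> bool:
    """True if linear_util is gone in this jax version, so old haiku fails to import."""
    return parse_version(jax_version) >= LINEAR_UTIL_REMOVED_IN


for v in ["0.4.20", "0.4.23", "0.4.25"]:
    print(v, allowed_by_colabfold(v), breaks_old_haiku(v))
```

Any version for which both functions return True is accepted by the open-ended constraint yet breaks old haiku, which is why either capping jax or upgrading dm-haiku closes the gap.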