Resolving dependency issues #246
With Colab Pro, the default TPU lib (and JAX) is now at 0.3.25. I jumped through these hoops as well and have run with the pip install shown below. Your mileage may vary.

Johnny
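The full command from this comment (the leading `!` is Colab notebook syntax):

```bash
!pip install mesh-transformer-jax/ jax==0.3.15 tensorflow==2.8.2 chex==0.1.4 jaxlib==0.3.15
```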
@rinapch Worked perfectly for me. Thank you very much.
That said, just as a note: it worked perfectly for fine-tuning but not for inference on Colab (it caused an optax error).
Thank you so much for this post, it helped me resolve all of my dependency issues. I have never worked with poetry before, but I was able to get a model training in a conda environment using only install commands. If anybody is interested, I wrote out the steps I took from scratch that are currently working based on my test run (a sketch of the corresponding commands follows this list):

-- First, install conda on the TPU VM
-- Update PATH to include conda
-- Create an env with mamba and python == 3.8
-- Close and reopen the terminal, re-ssh
-- Leave base
-- Enter the env
-- Install the requirements available through conda first
-- Install the remaining requirements not available through conda with pip
-- NOTE: You will see a typing-extensions error pop up about tensorflow 2.6.0 not being compatible with typing-extensions 4.2.0. This is fine; ignore it.
-- NOTE: jax 0.2.12 does NOT WORK with TPUs anymore, but we can use 0.2.18 or 0.2.20
-- If you have issues with protobuf (they may originate from the import wandb call), run the protobuf fix shown below
-- Finally, run the training command and fine-tune your model
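The exact commands for each step did not survive in this thread; below is a minimal sketch under my own assumptions (the Miniconda installer, the env name `gptj`, the conda/pip package split, the protobuf pin, and the fine-tune invocation are all guesses, not from the original comment):

```bash
# Install conda (Miniconda) on the TPU VM -- installer choice is an assumption
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p "$HOME/miniconda3"

# Update PATH to include conda, and enable "conda activate" in this shell
echo 'export PATH="$HOME/miniconda3/bin:$PATH"' >> ~/.bashrc
source "$HOME/miniconda3/etc/profile.d/conda.sh"

# Create an env with mamba and python 3.8 (env name "gptj" is arbitrary)
conda install -y -n base -c conda-forge mamba
mamba create -y -n gptj python=3.8

# Close and reopen the terminal / re-ssh, then leave base and enter the env
conda deactivate
conda activate gptj

# Install what is available through conda first (this split is a guess)
mamba install -y -c conda-forge numpy tqdm requests

# Install the remaining requirements with pip, then override the repo's
# jax==0.2.12 pin with a TPU-compatible version (0.2.18 or 0.2.20)
pip install -r mesh-transformer-jax/requirements.txt
pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Protobuf fix if "import wandb" breaks (the pin is an assumption)
pip install "protobuf<4"

# Fine-tune; config and model path are placeholders
python3 device_train.py --config configs/your_config.json \
  --tune-model-path gs://your-bucket/step_383500/
```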
@JohnnyOpcode How did you run inference with JAX 0.3.15? I think it only runs with 0.2.12.
I was using Colab Pro (paid) and I experimented with different versions of the libraries and with pip. The key takeaway is compatibility with the TPUv2 ASIC. I'll try to find some time to go through those motions again and come up with a newer working requirements.txt for everybody. Python sucks btw. Just like JS and TS. Too many brittle dependencies, but it does create lots of BS positions and salaries.
There have been a number of issues regarding different version conflicts and how to fix them. I've spent some time trying to make this code run, so maybe these instructions will spare someone else the effort :)

First of all, as per this issue in the jax repo (jax-ml/jax#13321), TPU VMs no longer work with jax older than 0.2.16. This repo requires jax==0.2.12, but I found that the code still works with jax versions 0.2.18 and 0.2.20.
Additionally, since a number of dependencies in the requirements file do not state the needed versions, I pinned all of them to the latest versions as of January 2022 and used poetry to resolve conflicts. Here is the `pyproject.toml` file I ended up with:
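The actual `pyproject.toml` contents did not make it into this thread; here is a hypothetical reconstruction, written to the VM via a heredoc (only the python 3.8 and jax 0.2.18 pins come from this thread; the package list and every other version are assumptions chosen to resemble January 2022 releases):

```bash
# Hypothetical pyproject.toml -- only the python/jax pins are from this post;
# the package list and remaining versions are illustrative assumptions.
cat > pyproject.toml <<'EOF'
[tool.poetry]
name = "mesh-transformer-jax-env"
version = "0.1.0"
description = "Pinned environment for mesh-transformer-jax"
authors = ["Your Name <you@example.com>"]

[tool.poetry.dependencies]
python = "~3.8"
jax = "0.2.18"
numpy = "1.21.5"
optax = "0.1.0"
dm-haiku = "0.0.5"
transformers = "4.15.0"
wandb = "0.12.9"
tensorflow-cpu = "2.6.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
EOF

poetry install
```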
After installing all of this with poetry, install `jax[tpu]` with pip, so that it gets the right libtpu nightly build. When starting training you may also find training getting stuck at validation; as @versae suggested in issue #218, it helps to change the TPU runtime version to an alpha build.
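Both commands from this post, verbatim:

```bash
# Install jax[tpu] via pip so it pulls the matching libtpu nightly build
pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Create the TPU VM on an alpha runtime if training hangs at validation
gcloud alpha compute tpus tpu-vm create gptj --accelerator-type v3-8 --version v2-alpha
```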