Issue with using your code in Colab because of JAX! #4

AmirhosseinnnKhademi opened this issue Nov 13, 2022 · 1 comment

@AmirhosseinnnKhademi

Hi Juan,
Thanks for your helpful video. I subscribed!
However, I have run into a problem: the line "from jax.experimental import optimizers" does not work for me. It fails with "cannot import name 'optimizers' from 'jax.experimental'".
To get it working, I have to switch to the CPU runtime and install "!pip install jax[cpu]==0.2.27".
It is confusing to me, and from what I found online, other people have hit the same issue.
Could you please let me know how you used the GPU in your video? Training on my CPU takes forever!

Thank you

@jdtoscano94
Owner

Hi, thanks for reporting this issue.
I used the Google Colab GPU. However, it looks like something changed in a recent update, so it no longer works.
We only use "jax.experimental" to get our optimizer. I would recommend either pinning an older JAX version in Colab or finding another way to import the jax.experimental optimizers (see the sketch below). Alternatively, you can try optax, although you may need to make some minor changes downstream as well.
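One way to do the latter, assuming a recent JAX release: the old optimizers module was renamed, so jax.experimental.optimizers now lives at jax.example_libraries.optimizers. A minimal sketch of the legacy API (the step_size value is just a placeholder):

# The old optimizers module was moved in newer JAX releases.
try:
    from jax.example_libraries import optimizers  # newer JAX
except ImportError:
    from jax.experimental import optimizers  # older JAX

# The legacy API returns a triple of functions rather than an optimizer object.
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)  # placeholder step size
opt_state = opt_init(params)  # params: your model's parameter pytree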
To install and import optax, use the following commands:
!pip install optax
import optax
Then you can create an optimizer using the following lines:
optimizer = optax.adam(optax.exponential_decay(lr0, decay_step, decay_rate))
opt_state = optimizer.init(params)
To update the parameters, you need to compute the gradients of your loss. You can do something like this:
from jax import value_and_grad

def loss_fn(params):
    loss = ...  # your loss function here
    return loss

grad_fn = value_and_grad(loss_fn)
loss, grads = grad_fn(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
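For reference, here is a minimal end-to-end sketch of the optax loop above; the quadratic loss, parameter values, and hyperparameters are placeholders for illustration only:

import jax.numpy as jnp
from jax import value_and_grad
import optax

# Placeholder loss: replace with your own loss function.
def loss_fn(params):
    return jnp.sum(params ** 2)

params = jnp.array([1.0, -2.0, 3.0])  # placeholder parameters

lr0, decay_step, decay_rate = 1e-3, 1000, 0.9  # placeholder hyperparameters
optimizer = optax.adam(optax.exponential_decay(lr0, decay_step, decay_rate))
opt_state = optimizer.init(params)

for step in range(1000):
    loss, grads = value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)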

Thanks for letting me know and I am sorry I could not help much this time.
