Hi Juan,
Thanks for your helpful video. I subscribed!
However, I have a problem here!
The `from jax.experimental import optimizers` line does not work for me! It says "cannot import name 'optimizers' from 'jax.experimental'".
To get it working, I had to switch to the CPU runtime and install `!pip install jax[cpu]==0.2.27`.
This is confusing for me, and judging from my searches online, for other people too.
Could you please let me know how you used the GPU in your video? I have to use my CPU, and training takes like a year!!!
Thank you
Hi, thanks for reporting this issue.
I used the Google Colab GPU runtime. However, it looks like something changed in a recent update, and the notebook no longer works as-is.
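A quick way to check whether JAX is actually using the Colab GPU is to ask it which backend and devices it detects. This is a minimal sketch using `jax.default_backend()` and `jax.devices()`; it only reports what the installed JAX build can see, so a CPU-only install such as `jax[cpu]` will report `cpu` even on a GPU runtime.

```python
import jax

# Backend JAX selected by default, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print("backend:", backend)

# Every device JAX can place computations on.
for device in jax.devices():
    print(device)
```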
We only use `jax.experimental` to get our optimizer. I would recommend finding a way to use a previous version of Colab, or otherwise getting the `jax.experimental` import to work. Alternatively, you can try optax, but you may need to make some minor changes downstream too.
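One possible way to keep the original optimizer code: in newer JAX releases the old `jax.experimental.optimizers` module lives at `jax.example_libraries.optimizers`. The sketch below assumes that module keeps the same `(opt_init, opt_update, get_params)` interface; the values for `lr0`, `decay_step`, and `decay_rate` and the toy quadratic loss are hypothetical placeholders, not the notebook's.

```python
import jax.numpy as jnp
from jax import grad

try:
    # Newer JAX releases: the former experimental optimizers moved here.
    from jax.example_libraries import optimizers
except ImportError:
    # Older JAX releases still expose the original location.
    from jax.experimental import optimizers

# Hypothetical values standing in for the notebook's hyperparameters.
lr0, decay_step, decay_rate = 1e-2, 1000, 0.9
opt_init, opt_update, get_params = optimizers.adam(
    optimizers.exponential_decay(lr0, decay_step, decay_rate)
)

def loss_fn(params):
    # Toy quadratic loss with its minimum at params == 3.0.
    return jnp.sum((params - 3.0) ** 2)

init_params = jnp.zeros(2)
opt_state = opt_init(init_params)
for i in range(200):
    g = grad(loss_fn)(get_params(opt_state))
    opt_state = opt_update(i, g, opt_state)

params = get_params(opt_state)
print(loss_fn(init_params), loss_fn(params))
```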
To install and import optax, run the following in a Colab cell:
```python
!pip install optax
import optax
```
Then you can create an optimizer using the following lines:
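```python
# lr0, decay_step, decay_rate, and params are assumed to be defined earlier in your notebook.
optimizer = optax.adam(optax.exponential_decay(lr0, decay_step, decay_rate))
opt_state = optimizer.init(params)
```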
To update the parameters, you need to compute the gradients of your loss. You can do something like this:
```python
from jax import value_and_grad
import optax

def loss_fn(params):
    # Your loss function here; it must return a scalar.
    loss = ...
    return loss

grad_fn = value_and_grad(loss_fn)
loss, grads = grad_fn(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```
Thanks for letting me know and I am sorry I could not help much this time.