Update Optax links in source code and docs
Toni-SM committed Jan 1, 2025
1 parent f39aadc commit aee399f
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/api/resources/schedulers.rst
@@ -28,7 +28,7 @@ Learning rate schedulers are techniques that adjust the learning rate over time

 - **PyTorch**: The implemented schedulers inherit from the PyTorch :literal:`_LRScheduler` class. Visit `How to adjust learning rate <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_ in the PyTorch documentation for more details.

-- **JAX**: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit `Schedules <https://optax.readthedocs.io/en/latest/api.html#schedules>`_ in the Optax documentation for more details.
+- **JAX**: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit `Optimizer Schedules <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html>`_ in the Optax documentation for more details.

 .. raw:: html
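For reference (not part of this commit), the schedule contract described in the JAX bullet is just a callable that maps an integer step count to a value. A minimal sketch using Optax's built-in linear_schedule alongside a hand-rolled schedule of the same shape (the decay numbers are illustrative):

import optax

# Built-in Optax schedule: linearly interpolate the value over the first 10,000 steps.
schedule = optax.linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=10_000)

# A custom schedule follows the same contract: parameterize, then return a
# function that maps a step count to a value.
def exponential_decay(init_value, decay_rate):
    def schedule_fn(step):
        return init_value * (decay_rate ** step)
    return schedule_fn

print(schedule(0), schedule(10_000))        # 0.001 ... 0.0001
print(exponential_decay(1e-3, 0.999)(100))  # ~0.000905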
2 changes: 1 addition & 1 deletion skrl/models/jax/base.py
@@ -482,7 +482,7 @@ def update_parameters(self, model: flax.linen.Module, polyak: float = 1) -> None
             self.state_dict = self.state_dict.replace(params=model.state_dict.params)
         # soft update
         else:
-            # HACK: Does it make sense to use https://optax.readthedocs.io/en/latest/api.html?#optax.incremental_update
+            # HACK: Does it make sense to use https://optax.readthedocs.io/en/latest/api/apply_updates.html#optax.incremental_update
             params = jax.tree_util.tree_map(
                 lambda params, model_params: polyak * model_params + (1 - polyak) * params,
                 self.state_dict.params,
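The HACK comment above asks whether optax.incremental_update could replace the manual tree_map. For reference (not part of this commit), a minimal sketch showing that the two compute the same Polyak interpolation, using toy pytrees in place of the model's state_dict params:

import jax
import jax.numpy as jnp
import optax

# Toy parameter pytrees standing in for self.state_dict.params and model.state_dict.params.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
model_params = {"w": jnp.full((2, 2), 3.0), "b": jnp.full(2, 2.0)}
polyak = 0.005

# Manual soft (Polyak) update, as in the snippet above.
manual = jax.tree_util.tree_map(
    lambda p, mp: polyak * mp + (1 - polyak) * p, params, model_params
)

# optax.incremental_update implements the same interpolation:
# step_size * new + (1 - step_size) * old.
via_optax = optax.incremental_update(model_params, params, step_size=polyak)

assert jax.tree_util.tree_all(jax.tree_util.tree_map(jnp.allclose, manual, via_optax))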
4 changes: 2 additions & 2 deletions skrl/resources/optimizers/jax/adam.py
@@ -24,7 +24,7 @@ def _step_with_scale(transformation, grad, state, state_dict, scale):
     # optax transform
     params, optimizer_state = transformation.update(grad, state, state_dict.params)
     # custom scale
-    # https://optax.readthedocs.io/en/latest/api.html?#optax.scale
+    # https://optax.readthedocs.io/en/latest/api/transformations.html#optax.scale
     params = jax.tree_util.tree_map(lambda params: scale * params, params)
     # apply transformation
     params = optax.apply_updates(state_dict.params, params)
@@ -35,7 +35,7 @@ class Adam:
     def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True) -> "Optimizer":
         """Adam optimizer
-        Adapted from `Optax's Adam <https://optax.readthedocs.io/en/latest/api.html?#adam>`_
+        Adapted from `Optax's Adam <https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam>`_
         to support custom scale (learning rate)
         :param model: Model
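For context (not part of this commit), the optax.scale transform linked above is usually composed with optax.scale_by_adam via optax.chain. A minimal sketch of that general Optax pattern (this is the standard recipe, not the skrl Adam wrapper itself):

import jax.numpy as jnp
import optax

# "Adam with a custom scale" in plain Optax: compose the Adam preconditioner
# with a scaling transform (negative sign for gradient descent).
lr = 1e-3
optimizer = optax.chain(optax.scale_by_adam(), optax.scale(-lr))

params = {"w": jnp.ones(3)}
opt_state = optimizer.init(params)

grads = {"w": jnp.array([0.1, -0.2, 0.3])}
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)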
2 changes: 1 addition & 1 deletion skrl/resources/schedulers/jax/kl_adaptive.py
@@ -86,5 +86,5 @@ def step(self, kl: Optional[Union[np.ndarray, float]] = None) -> None:


 # Alias to maintain naming compatibility with Optax schedulers
-# https://optax.readthedocs.io/en/latest/api.html#schedules
+# https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html
 kl_adaptive = KLAdaptiveLR
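For illustration (not part of this commit), the lowercase alias simply exposes the class under an Optax-style name. Assuming both names are re-exported by the package (a hypothetical import path based on the file location above), they refer to the same object:

# Hypothetical usage sketch: both names point to the same scheduler class.
from skrl.resources.schedulers.jax import KLAdaptiveLR, kl_adaptive

assert kl_adaptive is KLAdaptiveLR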
