diff --git a/docs/source/api/resources/schedulers.rst b/docs/source/api/resources/schedulers.rst index e844afdb..52e5644d 100644 --- a/docs/source/api/resources/schedulers.rst +++ b/docs/source/api/resources/schedulers.rst @@ -28,7 +28,7 @@ Learning rate schedulers are techniques that adjust the learning rate over time - **PyTorch**: The implemented schedulers inherit from the PyTorch :literal:`_LRScheduler` class. Visit `How to adjust learning rate <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_ in the PyTorch documentation for more details. -- **JAX**: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit `Schedules <https://optax.readthedocs.io/en/latest/api.html#schedules>`_ in the Optax documentation for more details. +- **JAX**: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit `Optimizer Schedules <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html>`_ in the Optax documentation for more details. .. raw:: html diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index ddf6027a..f9c7e607 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -482,7 +482,7 @@ def update_parameters(self, model: flax.linen.Module, polyak: float = 1) -> None self.state_dict = self.state_dict.replace(params=model.state_dict.params) # soft update else: - # HACK: Does it make sense to use https://optax.readthedocs.io/en/latest/api.html?#optax.incremental_update + # HACK: Does it make sense to use https://optax.readthedocs.io/en/latest/api/apply_updates.html#optax.incremental_update params = jax.tree_util.tree_map( lambda params, model_params: polyak * model_params + (1 - polyak) * params, self.state_dict.params, diff --git a/skrl/resources/optimizers/jax/adam.py b/skrl/resources/optimizers/jax/adam.py index f8b405d1..0d877be7 100644 --- a/skrl/resources/optimizers/jax/adam.py +++ b/skrl/resources/optimizers/jax/adam.py @@ -24,7 +24,7 @@ def _step_with_scale(transformation, grad, state, state_dict, scale): # optax transform params, optimizer_state = transformation.update(grad, state, state_dict.params) # 
custom scale - # https://optax.readthedocs.io/en/latest/api.html?#optax.scale + # https://optax.readthedocs.io/en/latest/api/transformations.html#optax.scale params = jax.tree_util.tree_map(lambda params: scale * params, params) # apply transformation params = optax.apply_updates(state_dict.params, params) @@ -35,7 +35,7 @@ class Adam: def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True) -> "Optimizer": """Adam optimizer - Adapted from `Optax's Adam <https://optax.readthedocs.io/en/latest/api.html?#adam>`_ + Adapted from `Optax's Adam <https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam>`_ to support custom scale (learning rate) :param model: Model diff --git a/skrl/resources/schedulers/jax/kl_adaptive.py b/skrl/resources/schedulers/jax/kl_adaptive.py index 149bc922..8a8a634b 100644 --- a/skrl/resources/schedulers/jax/kl_adaptive.py +++ b/skrl/resources/schedulers/jax/kl_adaptive.py @@ -86,5 +86,5 @@ def step(self, kl: Optional[Union[np.ndarray, float]] = None) -> None: # Alias to maintain naming compatibility with Optax schedulers -# https://optax.readthedocs.io/en/latest/api.html#schedules +# https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html kl_adaptive = KLAdaptiveLR