diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py
index a911dd52..39646fc3 100644
--- a/optax/contrib/_muon.py
+++ b/optax/contrib/_muon.py
@@ -6,7 +6,7 @@
 """
 
-from typing import Any, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, NamedTuple, Optional, Tuple, Union
 
 import chex
 import jax
@@ -76,13 +76,14 @@ def scale_by_muon(
     )
   if muon_coeffs.ndim > 2 or muon_coeffs.shape[-1] != 3:
     raise ValueError(
-        f"newton_schulz_coeffs must have shape (3,) or (n, 3), got {muon_coeffs.shape}"
+        "newton_schulz_coeffs must have shape (3,) or (n, 3),"
+        f" got {muon_coeffs.shape}"
     )
 
   def muon_iterator(x: jnp.ndarray, coeffs: jnp.ndarray):
-    A = x @ x.T
-    B = coeffs[1] * A + coeffs[2] * A @ A
-    return coeffs[0] * x + B @ x, None
+    a = x @ x.T
+    b = coeffs[1] * a + coeffs[2] * a @ a
+    return coeffs[0] * x + b @ x
 
   def init_fn(params):
     mu = otu.tree_zeros_like(params, dtype=mu_dtype)  # First moment
@@ -105,10 +106,15 @@ def update_fn(updates, state, params=None):
     # Ensure that the spectral norm of the updates is at most 1.
     updates = jax.tree.map(lambda x: x / (jnp.linalg.norm(x) + eps), mu_hat)
     # Apply Newton-schulz orthogonalization.
-    updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs)
+    updates, _ = jax.lax.scan(
+        lambda x, coeffs: (muon_iterator(x, coeffs), None), updates, muon_coeffs
+    )
     if adaptive:
-      # Scale the orthogonalized updates by the dual norm of the original updates.
-      updates = jax.tree.map(lambda x, y: jnp.linalg.trace(x.T @ y) * y, mu_hat, updates)
+      # Scale the orthogonalized updates by the dual norm of the original
+      # updates. See https://arxiv.org/abs/2409.20325 for the derivation.
+      updates = jax.tree.map(
+          lambda x, y: jnp.linalg.trace(x.T @ y) * y, mu_hat, updates
+      )
     mu = otu.tree_cast(mu, mu_dtype)
     return updates, MuonState(count=count_inc, mu=mu)
   return base.GradientTransformation(init_fn, update_fn)