Skip to content

Commit

Permalink
make linter happy
Browse files Browse the repository at this point in the history
  • Loading branch information
leloykun committed Jan 1, 2025
1 parent 7f9efd7 commit 40aa7b4
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""


from typing import Any, List, NamedTuple, Optional, Tuple, Union
from typing import Any, NamedTuple, Optional, Tuple, Union

import chex
import jax
Expand Down Expand Up @@ -76,13 +76,14 @@ def scale_by_muon(
)
if muon_coeffs.ndim > 2 or muon_coeffs.shape[-1] != 3:
raise ValueError(
f"newton_schulz_coeffs must have shape (3,) or (n, 3), got {muon_coeffs.shape}"
"newton_schulz_coeffs must have shape (3,) or (n, 3), "
f"got {muon_coeffs.shape}"
)

def muon_iterator(x: jnp.ndarray, coeffs: jnp.ndarray):
A = x @ x.T
B = coeffs[1] * A + coeffs[2] * A @ A
return coeffs[0] * x + B @ x, None
a = x @ x.T
b = coeffs[1] * a + coeffs[2] * a @ a
return coeffs[0] * x + b @ x

def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
Expand All @@ -105,10 +106,15 @@ def update_fn(updates, state, params=None):
# Ensure that the spectral norm of the updates is at most 1.
updates = jax.tree.map(lambda x: x / (jnp.linalg.norm(x) + eps), mu_hat)
# Apply Newton-Schulz orthogonalization.
updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs)
updates, _ = jax.lax.scan(
lambda x, coeffs: (muon_iterator(x, coeffs), None), updates, muon_coeffs
)
if adaptive:
# Scale the orthogonalized updates by the dual norm of the original updates.
updates = jax.tree.map(lambda x, y: jnp.linalg.trace(x.T @ y) * y, mu_hat, updates)
# Scale the orthogonalized updates by the dual norm of the original
# updates. See https://arxiv.org/abs/2409.20325 for the derivation.
updates = jax.tree.map(
lambda x, y: jnp.linalg.trace(x.T @ y) * y, mu_hat, updates
)
mu = otu.tree_cast(mu, mu_dtype)
return updates, MuonState(count=count_inc, mu=mu)
return base.GradientTransformation(init_fn, update_fn)
Expand Down

0 comments on commit 40aa7b4

Please sign in to comment.