Skip to content

Commit

Permalink
calc trace(G.T @ X) instead of norm(G.T @ X)
Browse files (browse the repository at this point in the history)
  • Loading branch information
leloykun committed Jan 1, 2025
1 parent 4ffe450 commit 7f9efd7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax/contrib/_muon.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -108,7 +108,7 @@ def update_fn(updates, state, params=None):
updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs)
if adaptive:
# Scale the orthogonalized updates by the dual norm of the original updates.
- updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates)
+ updates = jax.tree.map(lambda x, y: jnp.linalg.trace(x.T @ y) * y, mu_hat, updates)
mu = otu.tree_cast(mu, mu_dtype)
return updates, MuonState(count=count_inc, mu=mu)
return base.GradientTransformation(init_fn, update_fn)
Expand Down

0 comments on commit 7f9efd7

Please sign in to comment.