From 7f9efd76351aa3b01c2f2387ddd37f6b0d840692 Mon Sep 17 00:00:00 2001
From: Franz Louis Cesista
Date: Wed, 1 Jan 2025 12:45:46 +0800
Subject: [PATCH] calc trace(G.T @ X) instead of norm(G.T @ X)

---
 optax/contrib/_muon.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py
index c753f5d9..a911dd52 100644
--- a/optax/contrib/_muon.py
+++ b/optax/contrib/_muon.py
@@ -108,7 +108,7 @@ def update_fn(updates, state, params=None):
     updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs)
     if adaptive:
       # Scale the orthogonalized updates by the dual norm of the original updates.
-      updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates)
+      updates = jax.tree.map(lambda x, y: jnp.linalg.trace(x.T @ y) * y, mu_hat, updates)
     mu = otu.tree_cast(mu, mu_dtype)
     return updates, MuonState(count=count_inc, mu=mu)
   return base.GradientTransformation(init_fn, update_fn)