Skip to content

Commit

Permalink
Add KL divergence check and more logging to forward_pass_logit_checker
Browse files (browse the repository at this point in the history)
  • Loading branch information
ZhaoyueCheng committed Aug 20, 2024
1 parent 14379df commit 3c49ac7
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions MaxText/tests/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,25 @@ def main(config, test_args):
max_logging.log(f"{golden_logits[0]=}")
max_logging.log(f"{full_train_logits[0, 0, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
max_logging.log(f"Max Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}")
assert jax.numpy.allclose(
max_logging.log(f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}")

model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)

max_logging.log(f"{golden_probabilities[0]=}")
max_logging.log(f"{model_probabilities[0]=}")

kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")

if test_args.max_kl_div is not None:
max_logging.log("Checking KL Divergence between train distribution and golden distribution")
assert jax.numpy.all(kl_div < test_args.max_kl_div), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits")
assert jax.numpy.allclose(
full_train_logits[0, :token_size, :], golden_logits[:token_size, :], rtol=float(test_args.rtol), atol=float(test_args.atol), equal_nan=False
)
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."



Expand All @@ -125,11 +140,12 @@ def main(config, test_args):
parser.add_argument("--atol", type=float, required=False, default=0.1)
parser.add_argument("--rtol", type=float, required=False, default=0.1)
parser.add_argument("--token_size", type=int, required=False)
parser.add_argument("--max_kl_div", type=float, required=False, default=None)
test_args, _ = parser.parse_known_args()

# Remove args defined in this test file to avoid error from pyconfig
model_args = sys.argv
to_remove_args = ["--atol", "--rtol", "--token_size"]
to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div"]
for arg in to_remove_args:
model_args = [s for s in model_args if not s.startswith(arg)]

Expand Down

0 comments on commit 3c49ac7

Please sign in to comment.