diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py
index f4c7fecdd..80d71fdcb 100644
--- a/MaxText/tests/forward_pass_logit_checker.py
+++ b/MaxText/tests/forward_pass_logit_checker.py
@@ -110,10 +110,25 @@ def main(config, test_args):
   max_logging.log(f"{golden_logits[0]=}")
   max_logging.log(f"{full_train_logits[0, 0, :]=}")
   token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
-  max_logging.log(f"Max Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}")
-  assert jax.numpy.allclose(
+  max_logging.log(f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}")
+
+  model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
+  golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)
+
+  max_logging.log(f"{golden_probabilities[0]=}")
+  max_logging.log(f"{model_probabilities[0]=}")
+
+  kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
+  max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")
+
+  if test_args.max_kl_div is not None:
+    max_logging.log("Checking KL Divergence between train distribution and golden distribution")
+    assert jax.numpy.all(kl_div < test_args.max_kl_div), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"
+  else:
+    max_logging.log("Checking Numerical Differences between train logits and golden logits")
+    assert jax.numpy.allclose(
       full_train_logits[0, :token_size, :], golden_logits[:token_size, :], rtol=float(test_args.rtol), atol=float(test_args.atol), equal_nan=False
-  )
+    ), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."
@@ -125,11 +140,12 @@ def main(config, test_args):
   parser.add_argument("--atol", type=float, required=False, default=0.1)
   parser.add_argument("--rtol", type=float, required=False, default=0.1)
   parser.add_argument("--token_size", type=int, required=False)
+  parser.add_argument("--max_kl_div", type=float, required=False, default=None)
   test_args, _ = parser.parse_known_args()
 
   # Remove args defined in this test file to avoid error from pyconfig
   model_args = sys.argv
-  to_remove_args = ["--atol", "--rtol", "--token_size"]
+  to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div"]
   for arg in to_remove_args:
     model_args = [s for s in model_args if not s.startswith(arg)]
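For reference, the per-token KL-divergence check added above can be exercised in isolation. The sketch below is not part of the diff: the toy logit arrays and the 0.05 threshold are illustrative placeholders standing in for the real model/golden logits and the --max_kl_div flag. It mirrors the diff's approach: softmax both logit sets over the vocabulary axis, feed them to jax.scipy.special.kl_div (which returns elementwise contributions), and sum over the vocabulary axis to get one KL(golden || model) value per token position.

# Minimal standalone sketch of the KL-divergence check; values are hypothetical.
import jax
import jax.numpy as jnp

golden_logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])  # shape (tokens, vocab)
model_logits = jnp.array([[1.9, 0.6, -1.1], [0.2, 1.4, 0.4]])

golden_probabilities = jax.nn.softmax(golden_logits, axis=-1)
model_probabilities = jax.nn.softmax(model_logits, axis=-1)

# kl_div returns elementwise terms p*log(p/q) - p + q; summing over the
# vocabulary axis yields one KL(golden || model) value per token position.
kl_div = jnp.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)

max_kl_div = 0.05  # hypothetical threshold, analogous to the --max_kl_div argument
assert jnp.all(kl_div < max_kl_div), f"Max KL divergence {jnp.max(kl_div)} exceeds {max_kl_div}"
print("per-token KL divergence:", kl_div)

A distribution-level check like this is more tolerant of benign logit offsets (e.g. a constant shift across the vocabulary, which softmax cancels out) than an elementwise allclose on raw logits, which is why the diff keeps the rtol/atol comparison only as the fallback when --max_kl_div is not supplied.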