Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add kl divergence for forward_pass_logit_checker #832

Merged
merged 1 commit into from
Aug 20, 2024
Merged

Conversation

ZhaoyueCheng
Copy link
Collaborator

@ZhaoyueCheng ZhaoyueCheng commented Aug 19, 2024

Description

  • add kl divergence option for forward_pass_logit_checker so that it compares the kl divergence of the output probability distribution with the golden output probability distribution
  • by default it's still using the original option to check numerical differences so it don't break other tests running

Test

  • http://shortn/_XlXElFUgnv added the logs of 4 tests (kl passed, kl failed, numerical passed, numerical failed) with updated logging message here

@rdyro
Copy link
Collaborator

rdyro commented Aug 19, 2024

Great, thanks for implementing this! This looks good to me

Copy link
Collaborator

@gagika gagika left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the change, a few comments to improve logging, naming.

MaxText/tests/forward_pass_logit_checker.py Outdated Show resolved Hide resolved
@@ -125,11 +138,12 @@ def main(config, test_args):
parser.add_argument("--atol", type=float, required=False, default=0.1)
parser.add_argument("--rtol", type=float, required=False, default=0.1)
parser.add_argument("--token_size", type=int, required=False)
parser.add_argument("--kl", type=float, required=False, default=None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you comment that if this is set then only KL divergence will be checked?

MaxText/tests/forward_pass_logit_checker.py Outdated Show resolved Hide resolved
MaxText/tests/forward_pass_logit_checker.py Outdated Show resolved Hide resolved
@gobbleturk gobbleturk removed their assignment Aug 20, 2024
@copybara-service copybara-service bot merged commit 15966fa into main Aug 20, 2024
13 checks passed
@copybara-service copybara-service bot deleted the kl-div branch August 20, 2024 18:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants