Implement softcapping in fused jsd #403

Open: wants to merge 7 commits into base: main


wheynelau

Summary

Implements logit softcapping in the fused linear JSD loss so that it can be used with Gemma 2 models.

Details

Assumes the same softcap value for both the teacher and the student model.
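For readers unfamiliar with softcapping, here is a minimal pure-Python sketch of the idea, not the fused kernel from this PR. It assumes the common `cap * tanh(x / cap)` form of softcapping and an unweighted (0.5/0.5) Jensen-Shannon divergence; the function names `softcap` and `jsd_with_softcap` are hypothetical and chosen only for illustration. The same cap is applied to the student and teacher logits, as stated above.

```python
import math


def softcap(logits, cap):
    """Squash each logit into [-cap, cap] using cap * tanh(x / cap)."""
    return [cap * math.tanh(x / cap) for x in logits]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q); terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd_with_softcap(student_logits, teacher_logits, cap=None):
    """Unweighted JSD between student and teacher distributions.

    Assumption: one shared softcap value for both models.
    """
    if cap is not None:
        student_logits = softcap(student_logits, cap)
        teacher_logits = softcap(teacher_logits, cap)
    p = softmax(student_logits)
    q = softmax(teacher_logits)
    mix = [0.5 * (a + b) for a, b in zip(p, q)]
    return 0.5 * kl(p, mix) + 0.5 * kl(q, mix)


# Usage: a softcap of 50.0 matches the value discussed in this thread.
loss = jsd_with_softcap([2.0, -1.0, 0.5], [1.5, 0.0, -0.5], cap=50.0)
```

Small logits pass through almost unchanged, while very large ones saturate near the cap, which is why a reference implementation and a fused kernel have to apply the cap at the same point (before the softmax) to agree.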

Testing Done

  • added tests for softcapping in test_fused_linear_jsd.py

  • Hardware Type: L40S

  • run make test to ensure correctness

  • run make checkstyle to ensure code style

  • run make test-convergence to ensure convergence

yundai424
yundai424 previously approved these changes Nov 21, 2024
@yundai424
Collaborator

@wheynelau I just fixed some conflicts caused by the branch being out of sync 😃 FYI, test_correctness_functional is failing for me on an A100 when softcap=50.0. Raw output here: https://gist.github.com/yundai424/7bbfa78f05667749ce189cd458cabf90

@wheynelau
Author

@yundai424 Okay thanks! Let me take a look at this

@wheynelau
Author

@yundai424 Have updated it!


2 participants