Grpo loss #553
Conversation
test/chunked_loss/test_grpo_loss.py (Outdated)
attention_mask.view(-1)[mask_indices] = 0

# Create flat rewards with shape [B * num_generations] (one reward per sampled completion)
rewards = torch.randn(B * num_generations, device=device, dtype=dtype)
If it is not too much work, would it be possible to test a scenario where the rewards are all the same, i.e., all 1s?
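(For context: with identical rewards the group standard deviation is zero, so the group-normalized advantages should collapse to zero rather than blow up. Below is a minimal sketch of that edge case, assuming the usual GRPO group-wise normalization; the variable names and the 1e-4 epsilon are illustrative, not taken from this PR.)

```python
import torch

B, num_generations = 2, 4

# All-ones rewards: every completion in a group gets the same score.
rewards = torch.ones(B * num_generations)

# Group-wise normalization (mean/std per prompt), as in GRPO.
grouped = rewards.view(B, num_generations)
mean = grouped.mean(dim=1, keepdim=True)
std = grouped.std(dim=1, keepdim=True)
advantages = (grouped - mean) / (std + 1e-4)  # eps guards the zero-std case

# With identical rewards the advantages are exactly zero, so the policy
# term of the loss vanishes and only the KL penalty (if any) remains.
assert torch.allclose(advantages, torch.zeros_like(advantages))
```

An all-equal-rewards test therefore mainly exercises the epsilon guard and checks that the loss degenerates to the pure KL term.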
LGTM, nice test.
Great work! Great use of comments.
There was also a small change to the loss upstream (huggingface/trl#2881), which I will integrate here too.
Nice work!!
Summary
Adds the GRPO chunked loss.
Fixes issue #548.
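For orientation, a naive (non-chunked) reference of the GRPO objective might look roughly like the sketch below. This is a sketch under the usual formulation (group-normalized advantages, a per-token policy term with a detached ratio, and a k3-style KL penalty against a reference policy), not the fused implementation added in this PR; the function name, signature, and default values are illustrative.

```python
import torch

def naive_grpo_loss(log_probs, ref_log_probs, rewards, completion_mask,
                    num_generations, beta=0.04, eps=1e-4):
    """Unfused GRPO reference; per-token log-probs have shape [B * G, T]."""
    # Group-normalized advantages: one scalar per sampled completion.
    grouped = rewards.view(-1, num_generations)
    advantages = (grouped - grouped.mean(dim=1, keepdim=True)) / (
        grouped.std(dim=1, keepdim=True) + eps
    )
    advantages = advantages.view(-1, 1)  # [B * G, 1], broadcast over tokens

    # Policy term: exp(logp - logp.detach()) equals 1 in value but carries
    # the gradient of logp, so the gradient is -advantage * d(logp).
    ratio = torch.exp(log_probs - log_probs.detach())
    per_token_loss = -ratio * advantages

    # KL penalty against the reference policy (k3 estimator).
    kl = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1
    per_token_loss = per_token_loss + beta * kl

    # Mask padding and average over valid tokens, then over sequences.
    mask = completion_mask.to(per_token_loss.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1.0)
    per_sequence = (per_token_loss * mask).sum(dim=1) / token_counts
    return per_sequence.mean()
```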
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence