-
-
Notifications
You must be signed in to change notification settings - Fork 945
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TRL upgrade #2307
base: main
Are you sure you want to change the base?
TRL upgrade #2307
Conversation
from pydantic import BaseModel, Field | ||
|
||
|
||
class TrlConfig(BaseModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great : )
clean path and add mounts handle mounting
texts, | ||
return_tensors="pt", | ||
padding=True, | ||
padding_side="right", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why there padding_side mismatched between reward_inputs
and prompt_inputs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was a copy/paste from upstream trl trainer as I needed this PR to land to simplify our end. https://github.com/huggingface/trl/pull/2817/files I've removed this method from this class now that the referenced PR is merged.
@@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg): | |||
def setup_trainer( | |||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps | |||
): | |||
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"): | |||
if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can just check for if cfg.rl
here right?
wip towards adding support for GRPO