Print per-token reward over an RM #9

Merged: natolambert merged 3 commits into main from per_token on Feb 7, 2024

Conversation

natolambert (Collaborator) commented:
Brief documentation is in analysis/README.md. We can add visualization soon :)

E.g.
Reward: -0.544 | Substring: I
Reward: -0.556 | Substring: I love
Reward: -0.566 | Substring: I love to
Reward: 0.099 | Substring: I love to walk
Reward: 0.096 | Substring: I love to walk the
Reward: 0.092 | Substring: I love to walk the dog
Reward: 0.09 | Substring: I love to walk the dog,
Reward: 0.087 | Substring: I love to walk the dog, what
Reward: 0.085 | Substring: I love to walk the dog, what do
Reward: 0.089 | Substring: I love to walk the dog, what do you
Reward: 0.09 | Substring: I love to walk the dog, what do you like
Reward: 0.093 | Substring: I love to walk the dog, what do you like?
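
As a minimal sketch of what this cumulative-substring scoring looks like, assuming a Hugging Face sequence-classification reward model; the model name below is illustrative, not necessarily what the script uses:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# illustrative reward model; swap in whichever RM you are inspecting
model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

text = "I love to walk the dog, what do you like?"
token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]

# The "per-token" reward at position i is the RM score of the first
# i + 1 tokens decoded back into text.
for i in range(len(token_ids)):
    substring = tokenizer.decode(token_ids[: i + 1])
    inputs = tokenizer(substring, return_tensors="pt")
    with torch.no_grad():
        reward = model(**inputs).logits[0][0].item()
    print(f"Reward: {round(reward, 3)} | Substring: {substring}")
```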

@ljvmiranda921 (Member) commented:

Will review later today!

Comment on lines +58 to +95
```python
# AutoModelForSequenceClassification, pipeline, and T5ForConditionalGeneration
# are imported from transformers earlier in the file.
args = get_args()
quantized = True  # only Starling isn't quantized for now
custom_dialogue = False
# some models need custom code to be run: pick the matching model constructor
# and pipeline class based on the model name or chat template
if "oasst" in args.model or "oasst" in args.chat_template:
    from herm.models import openassistant  # noqa

    model_builder = AutoModelForSequenceClassification.from_pretrained
    pipeline_builder = pipeline
elif "Starling" in args.model or "Starling" in args.chat_template:
    from herm.models.starling import StarlingPipeline, build_starling_rm

    model_builder = build_starling_rm
    pipeline_builder = StarlingPipeline
    quantized = False
elif "openbmb" in args.model or "openbmb" in args.chat_template:
    from herm.models.openbmb import LlamaRewardModel, OpenBMBPipeline

    model_builder = LlamaRewardModel.from_pretrained
    pipeline_builder = OpenBMBPipeline
elif "PairRM" in args.model or "PairRM" in args.chat_template:
    from herm.models.pairrm import DebertaV2PairRM, PairRMPipeline

    custom_dialogue = True
    model_builder = DebertaV2PairRM.from_pretrained
    pipeline_builder = PairRMPipeline
elif "SHP" in args.model or "SHP" in args.chat_template:
    from herm.models.shp import SHPPipeline

    custom_dialogue = True
    model_builder = T5ForConditionalGeneration.from_pretrained
    pipeline_builder = SHPPipeline
else:
    model_builder = AutoModelForSequenceClassification.from_pretrained
    pipeline_builder = pipeline

# PairRM and SHP need custom dialogue formatting, which this script
# does not handle yet
if custom_dialogue:
    raise ValueError("Custom dialogue formatting not yet supported in this script")
```
@ljvmiranda921 (Member) commented on this block:
In case we're going to reuse this code block in the future, we should factor this logic out (so that we can reuse it in run_rm.py), but imo for v1 it's fine for now 🤔
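
One way this could look once factored out, as a rough sketch; the get_builders name and its return signature are hypothetical, not part of the PR:

```python
from transformers import AutoModelForSequenceClassification, pipeline


def get_builders(model_name: str, chat_template: str):
    """Return (model_builder, pipeline_builder, quantized, custom_dialogue)."""
    if "Starling" in model_name or "Starling" in chat_template:
        from herm.models.starling import StarlingPipeline, build_starling_rm

        return build_starling_rm, StarlingPipeline, False, False
    if "openbmb" in model_name or "openbmb" in chat_template:
        from herm.models.openbmb import LlamaRewardModel, OpenBMBPipeline

        return LlamaRewardModel.from_pretrained, OpenBMBPipeline, True, False
    # ...remaining branches mirror the block above...
    return AutoModelForSequenceClassification.from_pretrained, pipeline, True, False
```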

@natolambert (Collaborator, Author) replied:
Yeah, I agree @ljvmiranda921, and maybe add a test case.
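
Such a test case might look like the following sketch, assuming the hypothetical get_builders helper above; it is illustrative only:

```python
from transformers import AutoModelForSequenceClassification, pipeline

# get_builders is the hypothetical helper sketched in the review thread above


def test_get_builders_default_fallback():
    # unknown model names should fall back to the generic HF classifier path
    model_builder, pipeline_builder, quantized, custom_dialogue = get_builders(
        "my-org/generic-rm", "tulu"
    )
    assert model_builder is AutoModelForSequenceClassification.from_pretrained
    assert pipeline_builder is pipeline
    assert quantized is True
    assert custom_dialogue is False
```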

@ljvmiranda921 (Member) left a comment:

LGTM!

natolambert merged commit ed1bffa into main on Feb 7, 2024
3 checks passed
ljvmiranda921 deleted the per_token branch on February 9, 2024