Bypass reward model usage when reward_model_multiplier is 0 (#461)

Signed-off-by: SumanthRH <[email protected]>
Co-authored-by: Costa Huang <[email protected]>
SumanthRH and vwxyzjn authored Dec 3, 2024

1 parent e363290 commit 9b39f55
Showing 1 changed file with 26 additions and 23 deletions: open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -690,24 +690,27 @@ def from_pretrained(
         self.ref_policy.eval()
 
         # reward model
-        self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
-            args.reward_model_path,
-            revision=args.reward_model_revision,
-            num_labels=1,
-            torch_dtype=torch.bfloat16,
-            attn_implementation="flash_attention_2",
-            use_cache=False,
-        )
-        disable_dropout_in_model(self.reward_model)
-        ds_config = get_eval_ds_config(
-            offload=False,
-            stage=args.deepspeed_stage,
-            bf16=True,
-        )
-        ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
-        ds_config["train_batch_size"] = args.mini_batch_size
-        self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config)
-        self.reward_model.eval()
+        if args.reward_model_multiplier:
+            self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
+                args.reward_model_path,
+                revision=args.reward_model_revision,
+                num_labels=1,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2",
+                use_cache=False,
+            )
+            disable_dropout_in_model(self.reward_model)
+            ds_config = get_eval_ds_config(
+                offload=False,
+                stage=args.deepspeed_stage,
+                bf16=True,
+            )
+            ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
+            ds_config["train_batch_size"] = args.mini_batch_size
+            self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config)
+            self.reward_model.eval()
+
+        assert args.reward_model_multiplier or args.apply_verifiable_reward, "Either `reward_model_multiplier` must be non-zero or `apply_verifiable_reward` must be True."
 
     def get_vocab_size(self):
         return self.policy.config.vocab_size
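
With this hunk, the reward model is only loaded and wrapped by DeepSpeed when `reward_model_multiplier` is non-zero, and the new assert guarantees that at least one reward source remains active. A minimal runnable sketch of the gating pattern; `load_reward_model` and `Args` are hypothetical stand-ins for the `from_pretrained` + `deepspeed.initialize` setup and the training config above:

from dataclasses import dataclass


@dataclass
class Args:
    reward_model_multiplier: float = 0.0
    apply_verifiable_reward: bool = True


def load_reward_model():
    # Hypothetical stand-in for AutoModelForSequenceClassification.from_pretrained
    # followed by deepspeed.initialize; returns an opaque model handle here.
    return object()


def setup_reward_model(args: Args):
    reward_model = None
    # Skip the expensive model load when the multiplier is 0: every score it
    # produced would be discarded anyway.
    if args.reward_model_multiplier:
        reward_model = load_reward_model()
    # Training still needs some reward signal, so at least one source must be on.
    assert args.reward_model_multiplier or args.apply_verifiable_reward, (
        "Either `reward_model_multiplier` must be non-zero or "
        "`apply_verifiable_reward` must be True."
    )
    return reward_model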
@@ -1089,12 +1092,12 @@ def vllm_generate(
                 # Response Processing 2. run reward model on the truncated responses
                 postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                 sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1
-                _, score, _ = get_reward(
-                    self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length
-                )
-                if args.reward_model_multiplier != 1.0:
-                    score *= args.reward_model_multiplier
+                score = torch.zeros(query.shape[0], device=query.device)
+                if args.reward_model_multiplier:
+                    _, score, _ = get_reward(
+                        self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length
+                    )
                 # also apply verifiable reward
                 if args.apply_verifiable_reward:
                     # we need to batch the gt to match query.
                     ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size]
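
The scoring path applies the same guard: `score` now starts as a zero tensor, so downstream code that adds the verifiable reward always has a tensor of the right shape, and the reward-model forward pass runs only when the multiplier is non-zero. A toy version of that control flow; `fake_get_reward` is a hypothetical stand-in for the real `get_reward` call:

import torch


def fake_get_reward(batch_size: int) -> torch.Tensor:
    # Hypothetical stand-in for get_reward(self.reward_model, ...); returns
    # one scalar score per sequence in the batch.
    return torch.randn(batch_size)


def compute_score(reward_model_multiplier: float, batch_size: int = 4) -> torch.Tensor:
    # Default to zeros: when the reward model is bypassed, the verifiable
    # reward is still added onto a correctly shaped all-zero baseline.
    score = torch.zeros(batch_size)
    if reward_model_multiplier:
        score = fake_get_reward(batch_size)
    return score


print(compute_score(0.0))  # all zeros; the reward model is never queried
print(compute_score(0.5))  # reward-model scores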
