Commit

style updates
Nicolinho committed Oct 9, 2024
1 parent 6f7eb90 commit 290a868
Showing 2 changed files with 70 additions and 51 deletions.
2 changes: 1 addition & 1 deletion rewardbench/models/__init__.py
@@ -32,7 +32,7 @@
from .openbmb import LlamaRewardModel, OpenBMBPipeline
from .pairrm import DebertaV2PairRM, PairRMPipeline
from .pipeline import RewardBenchPipeline
-from .qrm import LlamaForRewardModelWithGating31, LlamaForRewardModelWithGating3
+from .qrm import LlamaForRewardModelWithGating3, LlamaForRewardModelWithGating31
from .shp import SHPPipeline
from .slicpairpm import SlicPairPMPipeline
from .starling import (
119 changes: 69 additions & 50 deletions rewardbench/models/qrm.py
@@ -1,19 +1,30 @@
from dataclasses import dataclass
-from typing import Optional, List
+from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
-from transformers import LlamaModel, LlamaPreTrainedModel, AutoModelForSequenceClassification
+from transformers import (
+    AutoModelForSequenceClassification,
+    LlamaModel,
+    LlamaPreTrainedModel,
+)
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
-from transformers.utils import ModelOutput
-from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward


class GatingNetwork(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool = True, temperature: float = 10,
-        logit_scale: float = 1., hidden_dim: int = 1024, n_hidden: int = 3):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        temperature: float = 10,
+        logit_scale: float = 1.0,
+        hidden_dim: int = 1024,
+        n_hidden: int = 3,
+    ):
super().__init__()
self.temperature = temperature
self.logit_scale = nn.Parameter(torch.ones(1) * logit_scale)
@@ -41,12 +52,14 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
token_pattern = [128009, 128006, 78191, 128007, 271]


-def find_token_for_gating(lst, ):
+def find_token_for_gating(
+    lst,
+):
"""Find the last occurrence of a token_pattern in a list."""
token_pattern_len = len(token_pattern)
search_end = len(lst)
for j in range(search_end - token_pattern_len, -1, -1):
-        if lst[j:j + token_pattern_len] == token_pattern:
+        if lst[j : j + token_pattern_len] == token_pattern:
return j
raise ValueError("Token pattern not found in the list.")

@@ -87,20 +100,23 @@ def __init__(self, config):
config_dict = config.to_dict()
self.num_objectives = config_dict.get("num_objectives", 19)
self.num_quantiles = config.num_quantiles
-        self.quantiles = torch.linspace(0., 1., config.num_quantiles + 2)[1:-1]
+        self.quantiles = torch.linspace(0.0, 1.0, config.num_quantiles + 2)[1:-1]
self.regression_layer = nn.Linear(config.hidden_size, config.num_quantiles * self.num_objectives, bias=False)
self.post_init()
# Not using torch.eye because it is not supported in BF16
t = torch.zeros(self.num_objectives, self.num_objectives)
-        t[range(self.num_objectives), range(self.num_objectives)] = 1.
+        t[range(self.num_objectives), range(self.num_objectives)] = 1.0
self.reward_transform_matrix = nn.Parameter(t)
self.reward_transform_matrix.requires_grad = False

# Initialize weights and apply final processing
-        self.gating = GatingNetwork(config.hidden_size, config.num_objectives,
-            temperature=config_dict.get("gating_temperature", 10),
-            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
-            n_hidden=config_dict.get("gating_n_hidden", 3))
+        self.gating = GatingNetwork(
+            config.hidden_size,
+            config.num_objectives,
+            temperature=config_dict.get("gating_temperature", 10),
+            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
+            n_hidden=config_dict.get("gating_n_hidden", 3),
+        )

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -112,17 +128,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
) -> CustomOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -170,11 +186,12 @@ def forward(
gating_output = self.gating(prompt_embedding.float())

reward_quantiles_all_adjusted = torch.matmul(
-            torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix)
+            torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix
+        )
# [B, num_quantiles, num_objectives]
reward_quantiles = torch.mul(
gating_output.unsqueeze(-1).repeat(1, 1, self.num_objectives),
-            torch.transpose(reward_quantiles_all_adjusted, 1, 2)
+            torch.transpose(reward_quantiles_all_adjusted, 1, 2),
).sum(1)

rewards_expectation = rewards.float().mean(dim=2)
@@ -201,42 +218,43 @@ def __init__(self, config):
config_dict = config.to_dict()
self.num_objectives = config_dict.get("num_objectives", 19)
self.num_quantiles = config.num_quantiles
-        self.quantiles = torch.linspace(0., 1., config.num_quantiles + 2)[1:-1]
+        self.quantiles = torch.linspace(0.0, 1.0, config.num_quantiles + 2)[1:-1]
self.regression_layer = nn.Linear(config.hidden_size, config.num_quantiles * self.num_objectives, bias=False)
self.post_init()
# Not using torch.eye because it is not supported in BF16
t = torch.zeros(self.num_objectives, self.num_objectives).to(torch.bfloat16)
-        t[range(self.num_objectives), range(self.num_objectives)] = 1.
+        t[range(self.num_objectives), range(self.num_objectives)] = 1.0
self.reward_transform_matrix = nn.Parameter(t)
self.reward_transform_matrix.requires_grad = False

# Initialize weights and apply final processing
-        self.gating = GatingNetwork(config.hidden_size, config.num_objectives,
-            temperature=config_dict.get("gating_temperature", 10),
-            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
-            n_hidden=config_dict.get("gating_n_hidden", 3))
+        self.gating = GatingNetwork(
+            config.hidden_size,
+            config.num_objectives,
+            temperature=config_dict.get("gating_temperature", 10),
+            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
+            n_hidden=config_dict.get("gating_n_hidden", 3),
+        )

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
return super().from_pretrained(
-            pretrained_model_name_or_path,
-            trust_remote_code=True,
-            attn_implementation="flash_attention_2"
+            pretrained_model_name_or_path, trust_remote_code=True, attn_implementation="flash_attention_2"
)

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
) -> CustomOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -283,11 +301,12 @@ def forward(

with torch.autocast(device_type=rewards.device.type, dtype=torch.float32):
reward_quantiles_all_adjusted = torch.matmul(
-                torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix)
+                torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix
+            )
# [B, num_quantiles, num_objectives]
reward_quantiles = torch.mul(
gating_output.unsqueeze(-1).repeat(1, 1, self.num_objectives),
-                torch.transpose(reward_quantiles_all_adjusted, 1, 2)
+                torch.transpose(reward_quantiles_all_adjusted, 1, 2),
).sum(1)

rewards_expectation = rewards.float().mean(dim=2)
@@ -301,4 +320,4 @@ def forward(
gating_output=gating_output,
score=score,
logits=score,
-            )
+        )
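
For context on the code reformatted above: the QRM forward pass predicts per-objective reward quantiles, weights the objectives with a prompt-dependent gating network, and reduces the result to a scalar score. The following is a minimal, self-contained sketch of that combination using toy sizes and random stand-ins for the model outputs. The sizes, the softmax gate, and the final mean-over-quantiles score are illustrative assumptions, and plain broadcasting stands in for the unsqueeze/repeat used in qrm.py; it is not the exact implementation.

import torch

B, num_objectives, num_quantiles = 2, 19, 5

# [B, num_objectives, num_quantiles]: quantile estimates per objective
# (stand-in for the reshaped output of the regression layer).
rewards = torch.randn(B, num_objectives, num_quantiles)

# [B, num_objectives]: gating weights over objectives
# (stand-in for the GatingNetwork applied to the prompt embedding).
gating_output = torch.softmax(torch.randn(B, num_objectives), dim=-1)

# Identity transform built without torch.eye, mirroring the BF16-friendly
# construction in qrm.py.
t = torch.zeros(num_objectives, num_objectives)
t[range(num_objectives), range(num_objectives)] = 1.0

# [B, num_quantiles, num_objectives] after transposing and applying the transform.
adjusted = torch.matmul(rewards.transpose(1, 2), t)

# Gate-weighted sum over objectives -> [B, num_quantiles]: quantiles of the
# mixed reward for each example.
reward_quantiles = (gating_output.unsqueeze(1) * adjusted).sum(-1)

# One plausible scalar score per example: the mean over the quantile estimates -> [B].
score = reward_quantiles.mean(dim=-1)
print(reward_quantiles.shape, score.shape)  # torch.Size([2, 5]) torch.Size([2])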
