diff --git a/rewardbench/models/__init__.py b/rewardbench/models/__init__.py
index 24f7c7e..672a73b 100644
--- a/rewardbench/models/__init__.py
+++ b/rewardbench/models/__init__.py
@@ -32,7 +32,7 @@
 from .openbmb import LlamaRewardModel, OpenBMBPipeline
 from .pairrm import DebertaV2PairRM, PairRMPipeline
 from .pipeline import RewardBenchPipeline
-from .qrm import LlamaForRewardModelWithGating31, LlamaForRewardModelWithGating3
+from .qrm import LlamaForRewardModelWithGating3, LlamaForRewardModelWithGating31
 from .shp import SHPPipeline
 from .slicpairpm import SlicPairPMPipeline
 from .starling import (
diff --git a/rewardbench/models/qrm.py b/rewardbench/models/qrm.py
index 374cf7b..aa266f9 100644
--- a/rewardbench/models/qrm.py
+++ b/rewardbench/models/qrm.py
@@ -1,19 +1,30 @@
 from dataclasses import dataclass
-from typing import Optional, List
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from transformers import LlamaModel, LlamaPreTrainedModel, AutoModelForSequenceClassification
+from transformers import (
+    AutoModelForSequenceClassification,
+    LlamaModel,
+    LlamaPreTrainedModel,
+)
 from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
-from transformers.utils import ModelOutput
-from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward
 
 
 class GatingNetwork(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool = True, temperature: float = 10,
-                 logit_scale: float = 1., hidden_dim: int = 1024, n_hidden: int = 3):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        temperature: float = 10,
+        logit_scale: float = 1.0,
+        hidden_dim: int = 1024,
+        n_hidden: int = 3,
+    ):
         super().__init__()
         self.temperature = temperature
         self.logit_scale = nn.Parameter(torch.ones(1) * logit_scale)
@@ -41,12 +52,14 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
 token_pattern = [128009, 128006, 78191, 128007, 271]
 
 
-def find_token_for_gating(lst, ):
+def find_token_for_gating(
+    lst,
+):
     """Find the last occurrence of a token_pattern in a list."""
     token_pattern_len = len(token_pattern)
     search_end = len(lst)
     for j in range(search_end - token_pattern_len, -1, -1):
-        if lst[j:j + token_pattern_len] == token_pattern:
+        if lst[j : j + token_pattern_len] == token_pattern:
             return j
     raise ValueError("Token pattern not found in the list.")
 
@@ -87,20 +100,23 @@ def __init__(self, config):
         config_dict = config.to_dict()
         self.num_objectives = config_dict.get("num_objectives", 19)
         self.num_quantiles = config.num_quantiles
-        self.quantiles = torch.linspace(0., 1., config.num_quantiles + 2)[1:-1]
+        self.quantiles = torch.linspace(0.0, 1.0, config.num_quantiles + 2)[1:-1]
         self.regression_layer = nn.Linear(config.hidden_size, config.num_quantiles * self.num_objectives, bias=False)
         self.post_init()
         # Not using torch.eye because it is not supported in BF16
         t = torch.zeros(self.num_objectives, self.num_objectives)
-        t[range(self.num_objectives), range(self.num_objectives)] = 1.
+        t[range(self.num_objectives), range(self.num_objectives)] = 1.0
         self.reward_transform_matrix = nn.Parameter(t)
         self.reward_transform_matrix.requires_grad = False
 
         # Initialize weights and apply final processing
-        self.gating = GatingNetwork(config.hidden_size, config.num_objectives,
-                                    temperature=config_dict.get("gating_temperature", 10),
-                                    hidden_dim=config_dict.get("gating_hidden_dim", 1024),
-                                    n_hidden=config_dict.get("gating_n_hidden", 3))
+        self.gating = GatingNetwork(
+            config.hidden_size,
+            config.num_objectives,
+            temperature=config_dict.get("gating_temperature", 10),
+            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
+            n_hidden=config_dict.get("gating_n_hidden", 3),
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -112,17 +128,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> CustomOutput:
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -170,11 +186,12 @@ def forward(
         gating_output = self.gating(prompt_embedding.float())
 
         reward_quantiles_all_adjusted = torch.matmul(
-            torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix)
+            torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix
+        )
         # [B, num_quantiles, num_objectives]
         reward_quantiles = torch.mul(
             gating_output.unsqueeze(-1).repeat(1, 1, self.num_objectives),
-            torch.transpose(reward_quantiles_all_adjusted, 1, 2)
+            torch.transpose(reward_quantiles_all_adjusted, 1, 2),
         ).sum(1)
 
         rewards_expectation = rewards.float().mean(dim=2)
@@ -201,42 +218,43 @@ def __init__(self, config):
         config_dict = config.to_dict()
         self.num_objectives = config_dict.get("num_objectives", 19)
         self.num_quantiles = config.num_quantiles
-        self.quantiles = torch.linspace(0., 1., config.num_quantiles + 2)[1:-1]
+        self.quantiles = torch.linspace(0.0, 1.0, config.num_quantiles + 2)[1:-1]
         self.regression_layer = nn.Linear(config.hidden_size, config.num_quantiles * self.num_objectives, bias=False)
         self.post_init()
         # Not using torch.eye because it is not supported in BF16
         t = torch.zeros(self.num_objectives, self.num_objectives).to(torch.bfloat16)
-        t[range(self.num_objectives), range(self.num_objectives)] = 1.
+        t[range(self.num_objectives), range(self.num_objectives)] = 1.0
        self.reward_transform_matrix = nn.Parameter(t)
         self.reward_transform_matrix.requires_grad = False
 
         # Initialize weights and apply final processing
-        self.gating = GatingNetwork(config.hidden_size, config.num_objectives,
-                                    temperature=config_dict.get("gating_temperature", 10),
-                                    hidden_dim=config_dict.get("gating_hidden_dim", 1024),
-                                    n_hidden=config_dict.get("gating_n_hidden", 3))
+        self.gating = GatingNetwork(
+            config.hidden_size,
+            config.num_objectives,
+            temperature=config_dict.get("gating_temperature", 10),
+            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
+            n_hidden=config_dict.get("gating_n_hidden", 3),
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         return super().from_pretrained(
-            pretrained_model_name_or_path,
-            trust_remote_code=True,
-            attn_implementation="flash_attention_2"
+            pretrained_model_name_or_path, trust_remote_code=True, attn_implementation="flash_attention_2"
         )
 
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> CustomOutput:
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -283,11 +301,12 @@ def forward(
 
         with torch.autocast(device_type=rewards.device.type, dtype=torch.float32):
             reward_quantiles_all_adjusted = torch.matmul(
-                torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix)
+                torch.transpose(rewards.float(), 1, 2), self.reward_transform_matrix
+            )
             # [B, num_quantiles, num_objectives]
             reward_quantiles = torch.mul(
                 gating_output.unsqueeze(-1).repeat(1, 1, self.num_objectives),
-                torch.transpose(reward_quantiles_all_adjusted, 1, 2)
+                torch.transpose(reward_quantiles_all_adjusted, 1, 2),
             ).sum(1)
 
         rewards_expectation = rewards.float().mean(dim=2)
@@ -301,4 +320,4 @@ def forward(
             gating_output=gating_output,
             score=score,
             logits=score,
-        )
\ No newline at end of file
+        )