From 20bd1e331088102764b9f4507ba60c343efbe5ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 30 Jul 2024 09:59:16 +0200 Subject: [PATCH] Add `RewardModelScore` step (#840) * Add `RewardModelScore` step * Use logits * Update docstring * Fix unit test * Adjust abs tolerance --- .../llms/huggingface/inference_endpoints.py | 9 +- .../llms/huggingface/transformers.py | 13 +- src/distilabel/steps/__init__.py | 2 + src/distilabel/steps/reward_model.py | 220 ++++++++++++++++++ src/distilabel/utils/huggingface.py | 11 +- tests/unit/steps/test_reward_model.py | 93 ++++++++ 6 files changed, 331 insertions(+), 17 deletions(-) create mode 100644 src/distilabel/steps/reward_model.py create mode 100644 tests/unit/steps/test_reward_model.py diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index f2be978620..95d0020f8f 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -37,10 +37,7 @@ StandardInput, StructuredOutputType, ) -from distilabel.utils.huggingface import ( - _INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME, - get_hf_token, -) +from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR, get_hf_token if TYPE_CHECKING: from huggingface_hub import AsyncInferenceClient @@ -162,7 +159,7 @@ class User(BaseModel): description="The base URL to use for the Inference Endpoints API requests.", ) api_key: Optional[RuntimeParameter[SecretStr]] = Field( - default=os.getenv(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME), + default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR), description="The API key to authenticate the requests to the Inference Endpoints API.", ) @@ -178,7 +175,7 @@ class User(BaseModel): _model_name: Optional[str] = PrivateAttr(default=None) _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) - _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) + _api_key_env_var: str = PrivateAttr(HF_TOKEN_ENV_VAR) _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) @model_validator(mode="after") # type: ignore diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 06d6b71422..6b8ad25e2f 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -15,7 +15,7 @@ import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union -from pydantic import Field, PrivateAttr, validate_call +from pydantic import Field, PrivateAttr, SecretStr, validate_call from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE @@ -24,6 +24,7 @@ from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput +from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR if TYPE_CHECKING: from transformers import Pipeline @@ -46,8 +47,6 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): Defaults to `"auto"`. trust_remote_code: whether to allow fetching and executing remote code fetched from the repository in the Hub. Defaults to `False`. - trust_remote_code: whether to trust or not remote (code in the Hugging Face Hub - repository) code to load the model. Defaults to `False`. model_kwargs: additional dictionary of keyword arguments that will be passed to the `from_pretrained` method of the model. tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing @@ -103,7 +102,9 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): chat_template: Optional[str] = None device: Optional[Union[str, int]] = None device_map: Optional[Union[str, Dict[str, Any]]] = None - token: Optional[str] = None + token: Optional[SecretStr] = Field( + default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR) + ) structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( default=None, description="The structured output format to use across all the generations.", @@ -125,6 +126,8 @@ def load(self) -> None: "Transformers is not installed. Please install it using `pip install transformers`." ) from ie + token = self.token.get_secret_value() if self.token is not None else self.token + self._pipeline = pipeline( "text-generation", model=self.model, @@ -136,7 +139,7 @@ def load(self) -> None: use_fast=self.use_fast, device=self.device, device_map=self.device_map, - token=self.token or os.getenv("HF_TOKEN"), + token=token, return_full_text=False, ) diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 8ba07a4710..bd8fde2251 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -46,6 +46,7 @@ ) from distilabel.steps.generators.utils import make_generator_step from distilabel.steps.globals.huggingface import PushToHub +from distilabel.steps.reward_model import RewardModelScore from distilabel.steps.typing import GeneratorStepOutput, StepOutput __all__ = [ @@ -75,6 +76,7 @@ "PushToHub", "Step", "StepInput", + "RewardModelScore", "GeneratorStepOutput", "StepOutput", "step", diff --git a/src/distilabel/steps/reward_model.py b/src/distilabel/steps/reward_model.py new file mode 100644 index 0000000000..72e58173a5 --- /dev/null +++ b/src/distilabel/steps/reward_model.py @@ -0,0 +1,220 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from pydantic import Field, PrivateAttr, SecretStr + +from distilabel.steps.base import Step, StepInput +from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR + +if TYPE_CHECKING: + import torch + from transformers import PreTrainedModel, PreTrainedTokenizer + + from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.typing import StepOutput + + +class RewardModelScore(Step): + """Assign a score to a response using a Reward Model. + + `RewardModelScore` is a `Step` that using a Reward Model (RM) loaded using `transformers`, + assigns an score to a response generated for an instruction, or a score to a multi-turn + conversation. + + Attributes: + model: the model Hugging Face Hub repo id or a path to a directory containing the + model weights and configuration files. + revision: if `model` refers to a Hugging Face Hub repository, then the revision + (e.g. a branch name or a commit id) to use. Defaults to `"main"`. + torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc. + Defaults to `"auto"`. + trust_remote_code: whether to allow fetching and executing remote code fetched + from the repository in the Hub. Defaults to `False`. + device_map: a dictionary mapping each layer of the model to a device, or a mode like `"sequential"` or `"auto"`. Defaults to `None`. + token: the Hugging Face Hub token that will be used to authenticate to the Hugging + Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package + local configuration will be used. Defaults to `None`. + truncation: whether to truncate sequences at the maximum length. Defaults to `False`. + max_length: maximun length to use for padding or truncation. Defaults to `None`. + + Input columns: + - instruction (`str`, optional): the instruction used to generate a `response`. + If provided, then `response` must be provided too. + - response (`str`, optional): the response generated for `instruction`. If provided, + then `instruction` must be provide too. + - conversation (`ChatType`, optional): a multi-turn conversation. If not provided, + then `instruction` and `response` columns must be provided. + + Output columns: + - score (`float`): the score given by the reward model for the instruction-response + pair or the conversation. + + Categories: + - scorer + + Examples: + + Assigning an score for an instruction-response pair: + + ```python + from distilabel.steps import RewardModelScore + + step = RewardModelScore( + model="RLHFlow/ArmoRM-Llama3-8B-v0.1", device_map="auto", trust_remote_code=True + ) + + step.load() + + result = next( + step.process( + inputs=[ + { + "instruction": "How much is 2+2?", + "response": "The output of 2+2 is 4", + }, + {"instruction": "How much is 2+2?", "response": "4"}, + ] + ) + ) + # [ + # {'instruction': 'How much is 2+2?', 'response': 'The output of 2+2 is 4', 'score': 0.11690367758274078}, + # {'instruction': 'How much is 2+2?', 'response': '4', 'score': 0.10300665348768234} + # ] + ``` + + Assigning an score for a multi-turn conversation: + + ```python + from distilabel.steps import RewardModelScore + + step = RewardModelScore( + model="RLHFlow/ArmoRM-Llama3-8B-v0.1", device_map="auto", trust_remote_code=True + ) + + step.load() + + result = next( + step.process( + inputs=[ + { + "conversation": [ + {"role": "user", "content": "How much is 2+2?"}, + {"role": "assistant", "content": "The output of 2+2 is 4"}, + ], + }, + { + "conversation": [ + {"role": "user", "content": "How much is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + }, + ] + ) + ) + # [ + # {'conversation': [{'role': 'user', 'content': 'How much is 2+2?'}, {'role': 'assistant', 'content': 'The output of 2+2 is 4'}], 'score': 0.11690367758274078}, + # {'conversation': [{'role': 'user', 'content': 'How much is 2+2?'}, {'role': 'assistant', 'content': '4'}], 'score': 0.10300665348768234} + # ] + ``` + """ + + model: str + revision: str = "main" + torch_dtype: str = "auto" + trust_remote_code: bool = False + device_map: Union[str, Dict[str, Any], None] = None + token: Union[SecretStr, None] = Field( + default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR), description="" + ) + truncation: bool = False + max_length: Union[int, None] = None + + _model: Union["PreTrainedModel", None] = PrivateAttr(None) + _tokenizer: Union["PreTrainedTokenizer", None] = PrivateAttr(None) + + def load(self) -> None: + super().load() + + try: + from transformers import AutoModelForSequenceClassification, AutoTokenizer + except ImportError as e: + raise ImportError( + "`transformers` is not installed. Please install it using `pip install transformers`." + ) from e + + token = self.token.get_secret_value() if self.token is not None else self.token + + self._model = AutoModelForSequenceClassification.from_pretrained( + self.model, + revision=self.revision, + torch_dtype=self.torch_dtype, + trust_remote_code=self.trust_remote_code, + device_map=self.device_map, + token=token, + ) + self._tokenizer = AutoTokenizer.from_pretrained( + self.model, + revision=self.revision, + torch_dtype=self.torch_dtype, + trust_remote_code=self.trust_remote_code, + token=token, + ) + + @property + def inputs(self) -> List[str]: + """Either `response` and `instruction`, or a `conversation` columns.""" + return [] + + @property + def outputs(self) -> List[str]: + """The `score` given by the reward model.""" + return ["score"] + + def _prepare_conversation(self, input: Dict[str, Any]) -> "ChatType": + if "instruction" in input and "response" in input: + return [ + {"role": "user", "content": input["instruction"]}, + {"role": "assistant", "content": input["response"]}, + ] + + return input["conversation"] + + def _prepare_inputs(self, inputs: List[Dict[str, Any]]) -> "torch.Tensor": + return self._tokenizer.apply_chat_template( # type: ignore + [self._prepare_conversation(input) for input in inputs], # type: ignore + return_tensors="pt", + padding=True, + truncation=self.truncation, + max_length=self.max_length, + ).to(self._model.device) # type: ignore + + def _inference(self, inputs: List[Dict[str, Any]]) -> List[float]: + import torch + + input_ids = self._prepare_inputs(inputs) + with torch.no_grad(): + output = self._model(input_ids) # type: ignore + logits = output.logits + if logits.shape == (2, 1): + logits = logits.squeeze(-1) + return logits.tolist() + + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + scores = self._inference(inputs) + for input, score in zip(inputs, scores): + input["score"] = score + yield inputs diff --git a/src/distilabel/utils/huggingface.py b/src/distilabel/utils/huggingface.py index 7a637a831c..70be04418e 100644 --- a/src/distilabel/utils/huggingface.py +++ b/src/distilabel/utils/huggingface.py @@ -18,7 +18,7 @@ from huggingface_hub import constants -_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME: Final[str] = "HF_TOKEN" +HF_TOKEN_ENV_VAR: Final[str] = "HF_TOKEN" def get_hf_token(cls_name: str, token_arg: str) -> str: @@ -39,14 +39,13 @@ def get_hf_token(cls_name: str, token_arg: str) -> str: Returns: The token for the hugging face API. """ - token = os.getenv(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) + token = os.getenv(HF_TOKEN_ENV_VAR) if token is None: if not Path(constants.HF_TOKEN_PATH).exists(): raise ValueError( - f"To use `{cls_name}` an API key must be provided via" - f" `{token_arg}`, set the environment variable" - f" `{_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME}` or use the `huggingface-hub` CLI to login" - " with `huggingface-cli login`." + f"To use `{cls_name}` an API key must be provided via `{token_arg}`," + f" set the environment variable `{HF_TOKEN_ENV_VAR}` or use the" + " `huggingface-hub` CLI to login with `huggingface-cli login`." ) with open(constants.HF_TOKEN_PATH) as f: token = f.read().strip() diff --git a/tests/unit/steps/test_reward_model.py b/tests/unit/steps/test_reward_model.py new file mode 100644 index 0000000000..cb6db69624 --- /dev/null +++ b/tests/unit/steps/test_reward_model.py @@ -0,0 +1,93 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from distilabel.steps.reward_model import RewardModelScore + + +class TestRewardModelScore: + def test_process(self) -> None: + step = RewardModelScore( + model="OpenAssistant/reward-model-deberta-v3-large-v2", + ) + + step.load() + + result = next( + step.process( + inputs=[ + { + "instruction": "How much is 2+2?", + "response": "The output of 2+2 is 4", + }, + {"instruction": "How much is 2+2?", "response": "4"}, + ] + ) + ) + + assert result == [ + { + "instruction": "How much is 2+2?", + "response": "The output of 2+2 is 4", + "score": pytest.approx(-0.5738837122917175, abs=1e-6), + }, + { + "instruction": "How much is 2+2?", + "response": "4", + "score": pytest.approx(-0.6376492977142334, abs=1e-6), + }, + ] + + def test_process_with_conversation(self) -> None: + step = RewardModelScore( + model="OpenAssistant/reward-model-deberta-v3-large-v2", + ) + + step.load() + + result = next( + step.process( + inputs=[ + { + "conversation": [ + {"role": "user", "content": "How much is 2+2?"}, + {"role": "assistant", "content": "The output of 2+2 is 4"}, + ], + }, + { + "conversation": [ + {"role": "user", "content": "How much is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + }, + ] + ) + ) + + assert result == [ + { + "conversation": [ + {"role": "user", "content": "How much is 2+2?"}, + {"role": "assistant", "content": "The output of 2+2 is 4"}, + ], + "score": pytest.approx(-0.5738837122917175, abs=1e-6), + }, + { + "conversation": [ + {"role": "user", "content": "How much is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "score": pytest.approx(-0.6376492977142334, abs=1e-6), + }, + ]