From 2a3906d5a6ae637500bbc0e2259aa8e0089580c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 22 Aug 2024 09:18:50 +0200 Subject: [PATCH] Add `URIAL` task (#921) * Initial work for `URIAL` * Update template * Fix checking last message * Add `format_output` logic * Refine `format_output` and add docstring * Add `References` * Add `URIAL` unit tests --- src/distilabel/steps/tasks/__init__.py | 2 + .../steps/tasks/templates/urial.jinja2 | 16 +++ src/distilabel/steps/tasks/text_generation.py | 15 ++- src/distilabel/steps/tasks/ultrafeedback.py | 15 +-- src/distilabel/steps/tasks/urial.py | 125 ++++++++++++++++++ .../unit/steps/tasks/test_text_generation.py | 5 +- tests/unit/steps/tasks/test_urial.py | 72 ++++++++++ 7 files changed, 231 insertions(+), 19 deletions(-) create mode 100644 src/distilabel/steps/tasks/templates/urial.jinja2 create mode 100644 src/distilabel/steps/tasks/urial.py create mode 100644 tests/unit/steps/tasks/test_urial.py diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 0b3a69596b..7bd96c3ce0 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -46,6 +46,7 @@ from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback +from distilabel.steps.tasks.urial import URIAL __all__ = [ "GeneratorTask", @@ -79,4 +80,5 @@ "ChatItem", "ChatType", "UltraFeedback", + "URIAL", ] diff --git a/src/distilabel/steps/tasks/templates/urial.jinja2 b/src/distilabel/steps/tasks/templates/urial.jinja2 new file mode 100644 index 0000000000..09a45bcc58 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/urial.jinja2 @@ -0,0 +1,16 @@ +# Instruction + +Below is a list of conversations between a human and an AI assistant (you). +Users place their queries under "# User:", and your responses are under "# Assistant:". +You are a helpful, respectful, and honest assistant. +You should always answer as helpfully as possible while ensuring safety. +Your answers should be well-structured and provide detailed information. They should also have an engaging tone. +Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. +Your response must be socially responsible, and thus you can refuse to answer some controversial topics. + +{% for message in messages %} +# {{ message.role | capitalize }}: + +{{ message.content }} +{% endfor %} +# Assistant: diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index f5c4659651..aeb74c9ec7 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -13,12 +13,15 @@ # limitations under the License. import warnings -from typing import Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType from distilabel.utils.chat import is_openai_format +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.typing import StepColumns + class TextGeneration(Task): """Simple text generation with an `LLM` given an instruction. @@ -78,11 +81,11 @@ class TextGeneration(Task): use_system_prompt: bool = True @property - def inputs(self) -> List[str]: + def inputs(self) -> "StepColumns": """The input for the task is the `instruction`.""" return ["instruction"] - def format_input(self, input: Dict[str, Any]) -> ChatType: + def format_input(self, input: Dict[str, Any]) -> "ChatType": """The input is formatted as a `ChatType` assuming that the instruction is the first interaction from the user within a conversation.""" @@ -189,7 +192,7 @@ def inputs(self) -> List[str]: """The input for the task are the `messages`.""" return ["messages"] - def format_input(self, input: Dict[str, Any]) -> ChatType: + def format_input(self, input: Dict[str, Any]) -> "ChatType": """The input is formatted as a `ChatType` assuming that the messages provided are already formatted that way i.e. following the OpenAI chat format.""" @@ -213,7 +216,7 @@ def outputs(self) -> List[str]: return ["generation", "model_name"] def format_output( - self, output: Union[str, None], input: Dict[str, Any] + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: """The output is formatted as a dictionary with the `generation`. The `model_name` will be automatically included within the `process` method of `Task`.""" diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py index eec232aabd..dae68bb48f 100644 --- a/src/distilabel/steps/tasks/ultrafeedback.py +++ b/src/distilabel/steps/tasks/ultrafeedback.py @@ -12,14 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.resources as importlib_resources import re -import sys - -if sys.version_info < (3, 9): - import importlib_resources -else: - import importlib.resources as importlib_resources - from typing import Any, Dict, List, Literal, Optional, Union import orjson @@ -264,7 +258,7 @@ def outputs(self) -> List[str]: return columns + ["model_name"] def format_output( - self, output: Union[str, None], input: Dict[str, Any] + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: """The output is formatted as a dictionary with the `ratings` and `rationales` for each of the provided `generations` for the given `instruction`. The `model_name` @@ -281,12 +275,15 @@ def format_output( `ratings`, and `rationales-for-ratings` for each of the provided `generations` for the given `instruction` if the provided aspect is either `helpfulness` or `truthfulness`. """ + assert input is not None, "Input is required to format the output." + if self.aspect in [ "honesty", "instruction-following", "overall-rating", ]: return self._format_ratings_rationales_output(output, input) + return self._format_types_ratings_rationales_output(output, input) def _format_ratings_rationales_output( @@ -450,7 +447,7 @@ class SchemaUltraFeedbackWithType(BaseModel): def _format_structured_output( self, output: str, input: Dict[str, Any] - ) -> Dict[str, str]: + ) -> Dict[str, Any]: """Parses the structured response, which should correspond to a dictionary with either `positive`, or `positive` and `negative` keys. diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py new file mode 100644 index 0000000000..ed0e72d969 --- /dev/null +++ b/src/distilabel/steps/tasks/urial.py @@ -0,0 +1,125 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.resources as importlib_resources +from typing import TYPE_CHECKING, Any, Dict, Union + +from jinja2 import Template + +from distilabel.steps.tasks import Task + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.typing import StepColumns + + +class URIAL(Task): + """Generates a response using a non-instruct fine-tuned model. + + `URIAL` is a pre-defined task that generates a response using a non-instruct fine-tuned + model. This task is used to generate a response based on the conversation provided as + input. + + Input columns: + - instruction (`str`, optional): The instruction to generate a response from. + - conversation (`List[Dict[str, str]]`, optional): The conversation to generate + a response from (the last message must be from the user). + + Output columns: + - generation (`str`): The generated response. + - model_name (`str`): The name of the model used to generate the response. + + Categories: + - text-generation + + References: + - [The Unlocking Spell on Base LLMs: Rethinking Alignment via In-Context Learning](https://arxiv.org/abs/2312.01552) + + Examples: + + Generate text from an instruction: + + ```python + from distilabel.llms import vLLM + from distilabel.steps.tasks import URIAL + + step = URIAL( + llm=vLLM( + model="meta-llama/Meta-Llama-3.1-8B", + generation_kwargs={"temperature": 0.7}, + ), + ) + + step.load() + + results = next( + step.process(inputs=[{"instruction": "What's the most most common type of cloud?"}]) + ) + # [ + # { + # 'instruction': "What's the most most common type of cloud?", + # 'generation': 'Clouds are classified into three main types, high, middle, and low. The most common type of cloud is the middle cloud.', + # 'distilabel_metadata': {...}, + # 'model_name': 'meta-llama/Meta-Llama-3.1-8B' + # } + # ] + ``` + """ + + def load(self) -> None: + """Loads the Jinja2 template for the given `aspect`.""" + super().load() + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "urial.jinja2" + ) + + self._template = Template(open(_path).read()) + + @property + def inputs(self) -> "StepColumns": + return {"instruction": False, "conversation": False} + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + messages = ( + [{"role": "user", "content": input["instruction"]}] + if "instruction" in input + else input["conversation"] + ) + + if messages[-1]["role"] != "user": + raise ValueError("The last message must be from the user.") + + return [{"role": "user", "content": self._template.render(messages=messages)}] + + @property + def outputs(self) -> "StepColumns": + return ["generation", "model_name"] + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + if output is None: + return {"generation": None} + + response = output.split("\n\n# User")[0] + if response.startswith("\n\n"): + response = response[2:] + response = response.strip() + + return {"generation": response} diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index dd43530cd9..2ed399b237 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -21,11 +21,8 @@ class TestTextGeneration: def test_format_input(self) -> None: - pipeline = Pipeline(name="unit-test-pipeline") llm = DummyLLM() - task = TextGeneration( - name="task", llm=llm, pipeline=pipeline, use_system_prompt=False - ) + task = TextGeneration(name="task", llm=llm, use_system_prompt=False) assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [ {"role": "user", "content": "test"} diff --git a/tests/unit/steps/tasks/test_urial.py b/tests/unit/steps/tasks/test_urial.py new file mode 100644 index 0000000000..2075d98e6e --- /dev/null +++ b/tests/unit/steps/tasks/test_urial.py @@ -0,0 +1,72 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from distilabel.steps.tasks.urial import URIAL + +from tests.unit.conftest import DummyLLM + + +class TestURIAL: + def test_format_input(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + assert task.format_input({"instruction": "test"}) == [ + { + "role": "user", + "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:', + } + ] + + def test_format_input_with_conversation(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + assert task.format_input( + { + "conversation": [ + {"role": "user", "content": "test"}, + {"role": "assistant", "content": "test"}, + {"role": "user", "content": "test"}, + ] + } + ) == [ + { + "role": "user", + "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:\n\ntest\n\n# User:\n\ntest\n\n# Assistant:', + } + ] + + def test_format_input_raise_valueerror(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + + with pytest.raises(ValueError, match="The last message must be from the user."): + assert task.format_input( + { + "conversation": [ + {"role": "user", "content": "test"}, + {"role": "assistant", "content": "test"}, + ] + } + ) + + def test_format_output(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + + assert task.format_output( + output=" \n\noutput\n\n# User:", input={"instruction": "test"} + ) == { + "generation": "output", + }