Add URIAL task (#921)
* Initial work for `URIAL`

* Update template

* Fix checking last message

* Add `format_output` logic

* Refine `format_output` and add docstring

* Add `References`

* Add `URIAL` unit tests
gabrielmbmb authored Aug 22, 2024
1 parent 516909e commit 2a3906d
Showing 7 changed files with 231 additions and 19 deletions.
2 changes: 2 additions & 0 deletions src/distilabel/steps/tasks/__init__.py
@@ -46,6 +46,7 @@
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
from distilabel.steps.tasks.typing import ChatItem, ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback
from distilabel.steps.tasks.urial import URIAL

__all__ = [
    "GeneratorTask",
@@ -79,4 +80,5 @@
    "ChatItem",
    "ChatType",
    "UltraFeedback",
    "URIAL",
]
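With the re-export in place, downstream code can import the new task from the subpackage root like every other task. A minimal sketch:

```python
# After this change the task is re-exported at the subpackage root:
from distilabel.steps.tasks import URIAL

# The direct module path used by the unit tests works too:
# from distilabel.steps.tasks.urial import URIAL
```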
16 changes: 16 additions & 0 deletions src/distilabel/steps/tasks/templates/urial.jinja2
@@ -0,0 +1,16 @@
# Instruction

Below is a list of conversations between a human and an AI assistant (you).
Users place their queries under "# User:", and your responses are under "# Assistant:".
You are a helpful, respectful, and honest assistant.
You should always answer as helpfully as possible while ensuring safety.
Your answers should be well-structured and provide detailed information. They should also have an engaging tone.
Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.
Your response must be socially responsible, and thus you can refuse to answer some controversial topics.

{% for message in messages %}
# {{ message.role | capitalize }}:

{{ message.content }}
{% endfor %}
# Assistant:
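As a quick sanity check of what the prompt ends up looking like, the template can be rendered directly with `jinja2`; this mirrors what `URIAL.load()` does below. A minimal sketch (the file path and message list here are illustrative):

```python
from jinja2 import Template

# Illustrative stand-in for the packaged template path resolved in `URIAL.load()`.
with open("src/distilabel/steps/tasks/templates/urial.jinja2") as f:
    template = Template(f.read())

prompt = template.render(
    messages=[{"role": "user", "content": "What's the most common type of cloud?"}]
)
print(prompt)  # ends with "# Assistant:" so a base model continues from there
```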
15 changes: 9 additions & 6 deletions src/distilabel/steps/tasks/text_generation.py
@@ -13,12 +13,15 @@
# limitations under the License.

import warnings
from typing import Any, Dict, List, Union
from typing import TYPE_CHECKING, Any, Dict, List, Union

from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import ChatType
from distilabel.utils.chat import is_openai_format

if TYPE_CHECKING:
    from distilabel.steps.tasks.typing import ChatType
    from distilabel.steps.typing import StepColumns


class TextGeneration(Task):
    """Simple text generation with an `LLM` given an instruction.
@@ -78,11 +81,11 @@ class TextGeneration(Task):
    use_system_prompt: bool = True

    @property
    def inputs(self) -> List[str]:
    def inputs(self) -> "StepColumns":
        """The input for the task is the `instruction`."""
        return ["instruction"]

    def format_input(self, input: Dict[str, Any]) -> ChatType:
    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """The input is formatted as a `ChatType` assuming that the instruction
        is the first interaction from the user within a conversation."""

@@ -189,7 +192,7 @@ def inputs(self) -> List[str]:
"""The input for the task are the `messages`."""
return ["messages"]

def format_input(self, input: Dict[str, Any]) -> ChatType:
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the messages provided
are already formatted that way i.e. following the OpenAI chat format."""

@@ -213,7 +216,7 @@ def outputs(self) -> List[str]:
return ["generation", "model_name"]

def format_output(
self, output: Union[str, None], input: Dict[str, Any]
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
"""The output is formatted as a dictionary with the `generation`. The `model_name`
will be automatically included within the `process` method of `Task`."""
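Since `input` now defaults to `None`, `format_output` can be invoked with just the raw generation. A minimal sketch using the `DummyLLM` test helper from this repository (the returned dictionary is inferred from the docstring above, so treat it as an assumption):

```python
from distilabel.steps.tasks import TextGeneration
from tests.unit.conftest import DummyLLM  # test helper used elsewhere in this commit

task = TextGeneration(name="task", llm=DummyLLM(), use_system_prompt=False)
task.load()

# No `input` argument required anymore; `model_name` is attached later by `Task.process`.
assert task.format_output(output="Hello!") == {"generation": "Hello!"}
```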
15 changes: 6 additions & 9 deletions src/distilabel/steps/tasks/ultrafeedback.py
@@ -12,14 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.resources as importlib_resources
import re
import sys

if sys.version_info < (3, 9):
    import importlib_resources
else:
    import importlib.resources as importlib_resources

from typing import Any, Dict, List, Literal, Optional, Union

import orjson
@@ -264,7 +258,7 @@ def outputs(self) -> List[str]:
        return columns + ["model_name"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
        self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
    ) -> Dict[str, Any]:
        """The output is formatted as a dictionary with the `ratings` and `rationales` for
        each of the provided `generations` for the given `instruction`. The `model_name`
@@ -281,12 +275,15 @@ def format_output(
        `ratings`, and `rationales-for-ratings` for each of the provided `generations` for the
        given `instruction` if the provided aspect is either `helpfulness` or `truthfulness`.
        """
        assert input is not None, "Input is required to format the output."

        if self.aspect in [
            "honesty",
            "instruction-following",
            "overall-rating",
        ]:
            return self._format_ratings_rationales_output(output, input)

        return self._format_types_ratings_rationales_output(output, input)

    def _format_ratings_rationales_output(
@@ -450,7 +447,7 @@ class SchemaUltraFeedbackWithType(BaseModel):

    def _format_structured_output(
        self, output: str, input: Dict[str, Any]
    ) -> Dict[str, str]:
    ) -> Dict[str, Any]:
        """Parses the structured response, which should correspond to a dictionary
        with either `positive`, or `positive` and `negative` keys.
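Note that while `input` now defaults to `None` for signature compatibility, `UltraFeedback` still needs the original input (it pairs ratings with the provided `generations`), which is what the new `assert` guards. An illustrative sketch, assuming `task` is a loaded `UltraFeedback` instance and the output string is a plausible raw response:

```python
# Hypothetical loaded UltraFeedback instance `task`; calling without `input` now
# fails fast instead of crashing deeper in the formatting helpers.
try:
    task.format_output(output='{"ratings": [5], "rationales": ["ok"]}')
except AssertionError as exc:
    print(exc)  # "Input is required to format the output."
```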
125 changes: 125 additions & 0 deletions src/distilabel/steps/tasks/urial.py
@@ -0,0 +1,125 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.resources as importlib_resources
from typing import TYPE_CHECKING, Any, Dict, Union

from jinja2 import Template

from distilabel.steps.tasks.base import Task

if TYPE_CHECKING:
    from distilabel.steps.tasks.typing import ChatType
    from distilabel.steps.typing import StepColumns


class URIAL(Task):
    """Generates a response using a non-instruct fine-tuned model.

    `URIAL` is a pre-defined task that generates a response using a non-instruct fine-tuned
    model. This task is used to generate a response based on the conversation provided as
    input.

    Input columns:
        - instruction (`str`, optional): The instruction to generate a response from.
        - conversation (`List[Dict[str, str]]`, optional): The conversation to generate
            a response from (the last message must be from the user).

    Output columns:
        - generation (`str`): The generated response.
        - model_name (`str`): The name of the model used to generate the response.

    Categories:
        - text-generation

    References:
        - [The Unlocking Spell on Base LLMs: Rethinking Alignment via In-Context Learning](https://arxiv.org/abs/2312.01552)

    Examples:
        Generate text from an instruction:

        ```python
        from distilabel.llms import vLLM
        from distilabel.steps.tasks import URIAL

        step = URIAL(
            llm=vLLM(
                model="meta-llama/Meta-Llama-3.1-8B",
                generation_kwargs={"temperature": 0.7},
            ),
        )

        step.load()

        results = next(
            step.process(inputs=[{"instruction": "What's the most common type of cloud?"}])
        )
        # [
        #     {
        #         'instruction': "What's the most common type of cloud?",
        #         'generation': 'Clouds are classified into three main types, high, middle, and low. The most common type of cloud is the middle cloud.',
        #         'distilabel_metadata': {...},
        #         'model_name': 'meta-llama/Meta-Llama-3.1-8B'
        #     }
        # ]
        ```
    """

    def load(self) -> None:
        """Loads the Jinja2 template."""
        super().load()

        _path = str(
            importlib_resources.files("distilabel")
            / "steps"
            / "tasks"
            / "templates"
            / "urial.jinja2"
        )

        self._template = Template(open(_path).read())

    @property
    def inputs(self) -> "StepColumns":
        return {"instruction": False, "conversation": False}

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        messages = (
            [{"role": "user", "content": input["instruction"]}]
            if "instruction" in input
            else input["conversation"]
        )

        if messages[-1]["role"] != "user":
            raise ValueError("The last message must be from the user.")

        return [{"role": "user", "content": self._template.render(messages=messages)}]

    @property
    def outputs(self) -> "StepColumns":
        return ["generation", "model_name"]

    def format_output(
        self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
    ) -> Dict[str, Any]:
        if output is None:
            return {"generation": None}

        response = output.split("\n\n# User")[0]
        if response.startswith("\n\n"):
            response = response[2:]
        response = response.strip()

        return {"generation": response}
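Besides a single `instruction`, the task accepts a `conversation` column whose last message must come from the user. A sketch mirroring the docstring example above, with an illustrative multi-turn conversation:

```python
from distilabel.llms import vLLM
from distilabel.steps.tasks import URIAL

step = URIAL(
    llm=vLLM(
        model="meta-llama/Meta-Llama-3.1-8B",
        generation_kwargs={"temperature": 0.7},
    ),
)
step.load()

# The whole conversation is rendered into the URIAL prompt; the last turn must
# be "user", otherwise `format_input` raises a ValueError.
results = next(
    step.process(
        inputs=[
            {
                "conversation": [
                    {"role": "user", "content": "What's the most common type of cloud?"},
                    {"role": "assistant", "content": "Cumulus clouds."},
                    {"role": "user", "content": "How do they form?"},
                ]
            }
        ]
    )
)
```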
5 changes: 1 addition & 4 deletions tests/unit/steps/tasks/test_text_generation.py
@@ -21,11 +21,8 @@

class TestTextGeneration:
    def test_format_input(self) -> None:
        pipeline = Pipeline(name="unit-test-pipeline")
        llm = DummyLLM()
        task = TextGeneration(
            name="task", llm=llm, pipeline=pipeline, use_system_prompt=False
        )
        task = TextGeneration(name="task", llm=llm, use_system_prompt=False)

        assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [
            {"role": "user", "content": "test"}
72 changes: 72 additions & 0 deletions tests/unit/steps/tasks/test_urial.py
@@ -0,0 +1,72 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from distilabel.steps.tasks.urial import URIAL

from tests.unit.conftest import DummyLLM


class TestURIAL:
    def test_format_input(self) -> None:
        task = URIAL(llm=DummyLLM())
        task.load()
        assert task.format_input({"instruction": "test"}) == [
            {
                "role": "user",
                "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:',
            }
        ]

    def test_format_input_with_conversation(self) -> None:
        task = URIAL(llm=DummyLLM())
        task.load()
        assert task.format_input(
            {
                "conversation": [
                    {"role": "user", "content": "test"},
                    {"role": "assistant", "content": "test"},
                    {"role": "user", "content": "test"},
                ]
            }
        ) == [
            {
                "role": "user",
                "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:\n\ntest\n\n# User:\n\ntest\n\n# Assistant:',
            }
        ]

    def test_format_input_raise_valueerror(self) -> None:
        task = URIAL(llm=DummyLLM())
        task.load()

        with pytest.raises(ValueError, match="The last message must be from the user."):
            assert task.format_input(
                {
                    "conversation": [
                        {"role": "user", "content": "test"},
                        {"role": "assistant", "content": "test"},
                    ]
                }
            )

    def test_format_output(self) -> None:
        task = URIAL(llm=DummyLLM())
        task.load()

        assert task.format_output(
            output=" \n\noutput\n\n# User:", input={"instruction": "test"}
        ) == {
            "generation": "output",
        }
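The trimming that `format_output` applies in the last test can be traced by hand. A minimal sketch replicating its steps on a raw completion that keeps generating past the assistant's turn:

```python
raw = " \n\noutput\n\n# User:\n\nanother question"

# Same steps as `URIAL.format_output`:
response = raw.split("\n\n# User")[0]  # keep only the assistant's own turn
if response.startswith("\n\n"):        # drop a leading blank line, if present
    response = response[2:]
response = response.strip()

assert response == "output"
```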
