Add GenerateSentencePair task (#689)
* Add `GenerateSentencePair` task

* Update task to use system prompt

* Fix `setup_logging` file location

* Update `add_raw_output` to be `RuntimeParameter` and `True` by default

* Fix system prompt for negative sentences

* Add `GenerateSentencePair` unit tests

* Fix unit tests after updating `add_raw_output`

* Update docs to mention `add_raw_output` attribute

* Update `add_raw_output` description

Co-authored-by: alvarobartt <[email protected]>

* Fix columns

Co-authored-by: alvarobartt <[email protected]>

* Add missing docstrings

* Fix tests

* Add `answer` generation action

* Fix examples not being correctly rendered

* Add examples

---------

Co-authored-by: alvarobartt <[email protected]>
gabrielmbmb and alvarobartt authored Jun 4, 2024
1 parent e61b598 commit e4a9609
Showing 17 changed files with 473 additions and 17 deletions.
14 changes: 13 additions & 1 deletion docs/sections/learn/tutorial/task/index.md
@@ -8,6 +8,7 @@ The subclasses of [`Task`][distilabel.steps.tasks.Task] are intended to be used

For example, the most basic task is the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task, which generates text based on a given instruction, and it can be used standalone as well as within a [`Pipeline`][distilabel.pipeline.Pipeline].

```python
from distilabel.steps.tasks import TextGeneration

@@ -18,12 +19,23 @@ task = TextGeneration(
task.load()

next(task.process([{"instruction": "What's the capital of Spain?"}]))
# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid.", "model_name": "gpt-4"}]
# [
# {
# "instruction": "What's the capital of Spain?",
# "generation": "The capital of Spain is Madrid.",
# "model_name": "gpt-4",
# "distilabel_metadata": {
# "raw_output_text-generation": "The capital of Spain is Madrid"
# }
# }
# ]
```

!!! NOTE
    The `load` method always needs to be called when using a task standalone. If the [`Pipeline`][distilabel.pipeline.Pipeline] context manager is used instead, there is no need to call it, since it will be called automatically on `Pipeline.run`. In any other case, `load` needs to be called from the parent class too: e.g. a [`Task`][distilabel.steps.tasks.Task] with an [`LLM`][distilabel.llms.LLM] will need to call `Task.load` to load both the task and the LLM.

As the comment in the code snippet above shows, the task has enriched the input dictionaries by adding the `generation`, the `model_name` that was used to generate it, and finally the `distilabel_metadata` dictionary that contains the raw output (without post-processing) from the LLM. In this case, the `TextGeneration` task does no post-processing, so the `generation` and the raw output are the same, but other tasks do post-process the LLM output, which can sometimes fail. That's why it is useful to have the raw output available in the `distilabel_metadata` dictionary. If this default behaviour is not desired, every `Task` has an `add_raw_output` attribute that can be set to `False` when creating the instance of the task or at run time.
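For instance, a minimal sketch of disabling the raw output at instantiation time (assuming an `OpenAILLM` setup like the one implied by the snippet above):

```python
from distilabel.llms import OpenAILLM
from distilabel.steps.tasks import TextGeneration

# Sketch: with `add_raw_output=False` the output rows no longer carry the
# raw LLM output inside the `distilabel_metadata` column.
task = TextGeneration(
    name="text-generation",
    llm=OpenAILLM(model="gpt-4"),
    add_raw_output=False,
)
task.load()

next(task.process([{"instruction": "What's the capital of Spain?"}]))
# [{'instruction': "What's the capital of Spain?", 'generation': 'The capital of Spain is Madrid.', 'model_name': 'gpt-4'}]
```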

## Defining custom Tasks

In order to define custom tasks, we need to inherit from the [`Task`][distilabel.steps.tasks.Task] class and implement the `format_input` and `format_output` methods, as well as setting the properties `inputs` and `outputs`, as for [`Step`][distilabel.steps.Step] subclasses.
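A minimal illustrative sketch of such a subclass (the `SummarizeText` task, its prompt, and its column names are hypothetical, not part of the library):

```python
from typing import Any, Dict, List, Optional, Union

from distilabel.steps.tasks import Task
from distilabel.steps.tasks.typing import ChatType


class SummarizeText(Task):
    """A hypothetical custom task that asks the LLM for a summary."""

    @property
    def inputs(self) -> List[str]:
        # Column expected in the input dictionaries.
        return ["text"]

    def format_input(self, input: Dict[str, Any]) -> ChatType:
        # Turn the input row into the chat messages sent to the LLM.
        return [{"role": "user", "content": f"Summarize: {input['text']}"}]

    @property
    def outputs(self) -> List[str]:
        # Columns added to the output dictionaries.
        return ["summary", "model_name"]

    def format_output(
        self, output: Union[str, None], input: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Map the raw generation to the `summary` output column.
        return {"summary": output}
```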
4 changes: 2 additions & 2 deletions src/distilabel/mixins/runtime_parameters.py
@@ -115,13 +115,13 @@ def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
                    name, runtime_parameters_names, cutoff=0.5
                )
                msg = (
                    f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'."
                    f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'."  # type: ignore
                )
                if closest:
                    msg += f" Did you mean any of: {closest}"
                else:
                    msg += f" Available runtime parameters for the step: {runtime_parameters_names}."
                self.pipeline._logger.warning(msg)
                self.pipeline._logger.warning(msg)  # type: ignore
                continue

            attr = getattr(self, name)
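The call truncated at the top of this hunk appears to be `difflib.get_close_matches` — an assumption inferred from the visible arguments, as the call itself is not shown in the diff. A standalone sketch of the suggestion behaviour:

```python
import difflib

# Suggest close matches for a mistyped runtime parameter name
# (hypothetical values; the real names come from the step definition).
runtime_parameters_names = ["add_raw_output", "num_generations"]
closest = difflib.get_close_matches(
    "num_generation", runtime_parameters_names, cutoff=0.5
)
print(closest)  # ['num_generations']
```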
13 changes: 10 additions & 3 deletions src/distilabel/pipeline/base.py
@@ -298,11 +298,18 @@ def run(
            The `Distiset` created by the pipeline.
        """

        setup_logging(**self._logging_parameters)

        # Set the runtime parameters that will be used during the pipeline execution
        # Set the runtime parameters that will be used during the pipeline execution.
        # They are used to generate the signature of the pipeline that is used to hit the
        # cache when the pipeline is run, so it's important to do it first.
        self._set_runtime_parameters(parameters or {})

        setup_logging(
            **{
                **self._logging_parameters,
                "filename": str(self._cache_location["log_file"]),
            }
        )

        # Validate the pipeline DAG to check that all the steps are chainable, there are
        # no missing runtime parameters, batch sizes are correct, etc.
        self.dag.validate()
10 changes: 6 additions & 4 deletions src/distilabel/steps/tasks/__init__.py
@@ -30,17 +30,15 @@
from distilabel.steps.tasks.prometheus_eval import PrometheusEval
from distilabel.steps.tasks.quality_scorer import QualityScorer
from distilabel.steps.tasks.self_instruct import SelfInstruct
from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair
from distilabel.steps.tasks.structured_generation import StructuredGeneration
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
from distilabel.steps.tasks.typing import ChatItem, ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback

__all__ = [
"Task",
"GeneratorTask",
"ChatGeneration",
"ChatItem",
"ChatType",
"Task",
"ComplexityScorer",
"EvolInstruct",
"EvolComplexity",
@@ -54,7 +52,11 @@
"PrometheusEval",
"QualityScorer",
"SelfInstruct",
"GenerateSentencePair",
"StructuredGeneration",
"ChatGeneration",
"TextGeneration",
"ChatItem",
"ChatType",
"UltraFeedback",
]
8 changes: 7 additions & 1 deletion src/distilabel/steps/tasks/base.py
@@ -52,7 +52,13 @@ class _Task(_Step, ABC):
    llm: LLM

    group_generations: bool = False
    add_raw_output: bool = False
    add_raw_output: RuntimeParameter[bool] = Field(
        default=True,
        description=(
            "Whether to include the raw output of the LLM in the key `raw_output_<TASK_NAME>`"
            " of the `distilabel_metadata` dictionary output column"
        ),
    )
    num_generations: RuntimeParameter[int] = Field(
        default=1, description="The number of generations to be produced per input."
    )
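With `add_raw_output` now a `RuntimeParameter`, the behaviour can also be changed per run rather than per instance. A sketch, assuming a pipeline with a step named `text_generation` (all names here are illustrative):

```python
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="demo") as pipeline:
    loader = LoadDataFromDicts(data=[{"instruction": "What's the capital of Spain?"}])
    task = TextGeneration(name="text_generation", llm=OpenAILLM(model="gpt-4"))
    loader >> task

# Runtime parameters are passed per step name, so the raw output can be
# disabled for this run without touching the task definition.
distiset = pipeline.run(
    parameters={
        "text_generation": {"add_raw_output": False},
    },
)
```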
254 changes: 254 additions & 0 deletions src/distilabel/steps/tasks/sentence_transformers.py
@@ -0,0 +1,254 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import sys
from typing import TYPE_CHECKING, Any, Dict, Final, List, Literal, Optional, Union

from jinja2 import Template

from distilabel.steps.tasks.base import Task

if sys.version_info < (3, 9):
    import importlib_resources
else:
    import importlib.resources as importlib_resources

if TYPE_CHECKING:
    from distilabel.steps.tasks.typing import ChatType

GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"]

POSITIVE_NEGATIVE_PAIR_REGEX = re.compile(
    r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
    re.DOTALL,
)

GENERATION_ACTION_SENTENCES: Final[Dict[GenerationAction, str]] = {
    "paraphrase": "paraphrase",
    "semantically-similar": "be semantically similar to",
    "query": "be a query for",
    "answer": "be an answer for",
}

POSITIVE_SYSTEM_PROMPT: str = (
    "Your task is to generate a positive sentence given an anchor sentence. The positive"
    " sentence has to {action_sentence} the anchor sentence. You must output only one new"
    " section: `## Positive`."
)

POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
    "Your task is to generate a positive and a negative sentence given an anchor sentence."
    " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
    " sentence can use similar words but must not be related to the anchor sentence. You"
    " must output only two new sections: `## Positive` and `## Negative`."
)


class GenerateSentencePair(Task):
    """Generate a positive and (optionally) a negative sentence given an anchor sentence.

    `GenerateSentencePair` is a pre-defined task that, given an anchor sentence, generates
    a positive sentence related to the anchor and optionally a negative sentence unrelated
    to the anchor. This task is useful for generating training datasets for embedding
    models.

    Attributes:
        triplet: a flag to indicate if the task should generate a triplet of sentences
            (anchor, positive, negative). Defaults to `False`.
        action: the action to perform to generate the positive sentence.

    Input columns:
        - anchor (`str`): The anchor sentence used to generate the positive and negative sentences.

    Output columns:
        - positive (`str`): The positive sentence related to the `anchor`.
        - negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`.
        - model_name (`str`): The name of the model that was used to generate the sentences.

    Categories:
        - embedding

    Examples:

        Paraphrasing:

        ```python
        from distilabel.steps.tasks import GenerateSentencePair
        from distilabel.llms import InferenceEndpointsLLM

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="paraphrase",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
        ```

        Generating semantically similar sentences:

        ```python
        from distilabel.llms import InferenceEndpointsLLM
        from distilabel.steps.tasks import GenerateSentencePair

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="semantically-similar",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "How does 3D printing work?"}])
        ```

        Generating queries:

        ```python
        from distilabel.steps.tasks import GenerateSentencePair
        from distilabel.llms import InferenceEndpointsLLM

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="query",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "Argilla is an open-source data curation platform for LLMs. Using Argilla, ..."}])
        ```

        Generating answers:

        ```python
        from distilabel.steps.tasks import GenerateSentencePair
        from distilabel.llms import InferenceEndpointsLLM

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="answer",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
        ```
    """

    triplet: bool = False
    action: GenerationAction

    def load(self) -> None:
        """Loads the Jinja2 template."""
        super().load()

        _path = str(
            importlib_resources.files("distilabel")
            / "steps"
            / "tasks"
            / "templates"
            / "generate-sentence-pair.jinja2"
        )

        self._template = Template(open(_path).read())

    @property
    def inputs(self) -> List[str]:
        """The input for the task is the `anchor` sentence."""
        return ["anchor"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """The inputs are formatted as a `ChatType`, with a system prompt describing the
        task of generating positive and negative sentences for the anchor sentence. The
        anchor is provided as the first user interaction in the conversation.

        Args:
            input: The input containing the `anchor` sentence.

        Returns:
            A list of dictionaries containing the system and user interactions.
        """
        action_sentence = GENERATION_ACTION_SENTENCES[self.action]
        system_prompt = (
            POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
        ).format(action_sentence=action_sentence)

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": self._template.render(anchor=input["anchor"])},
        ]

    @property
    def outputs(self) -> List[str]:
        """The outputs for the task are the `positive` and `negative` sentences, as well
        as the `model_name` used to generate the sentences."""
        columns = ["positive", "negative"] if self.triplet else ["positive"]
        columns += ["model_name"]
        return columns

    def format_output(
        self, output: Union[str, None], input: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Formats the output of the LLM to extract the `positive` and `negative` sentences
        generated. If the output is `None` or the regex doesn't match, then the outputs
        will be set to `None` as well.

        Args:
            output: The output of the LLM.
            input: The input used to generate the output.

        Returns:
            The formatted output containing the `positive` and `negative` sentences.
        """
        if output is None:
            return {"positive": None, "negative": None}

        match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output)
        if match is None:
            formatted_output = {"positive": None}
            if self.triplet:
                formatted_output["negative"] = None
            return formatted_output

        groups = match.groups()
        if self.triplet:
            return {
                "positive": groups[0].strip(),
                "negative": groups[1].strip()
                if len(groups) > 1 and groups[1] is not None
                else None,
            }

        return {"positive": groups[0].strip()}
4 changes: 4 additions & 0 deletions src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2
@@ -0,0 +1,4 @@
## Anchor

{{ anchor }}

@@ -56,7 +56,7 @@
{% for example_title, code in step.docstring.examples.items() %}
#### {{ example_title }}
```python
{{ code | e }}
{{ code | replace("\\n", "\n") }}
```
{% endfor %}
{% endif %}