Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add approach to flexibly inject system_messages #1087

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/distilabel/steps/tasks/argilla_labeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ArgillaLabeller(Task):
- question (`Optional[Dict[str, Any]]`): The question settings for the question to be answered.
- example_records (`Optional[List[Dict[str, Any]]]`): The few shot example records with responses to be used to answer the question.
- guidelines (`Optional[str]`): The guidelines for the annotation task.
- system_prompt (`Optional[str]`): The system prompt for the annotation task.

Output columns:
- suggestion (`Dict[str, Any]`): The final suggestion for annotation.

Expand Down Expand Up @@ -205,7 +205,7 @@ class ArgillaLabeller(Task):
```
"""

system_prompt: str = (
system_prompt: Optional[str] = Field(
"You are an expert annotator and labelling assistant that understands complex domains and natural language processing. "
"You are given input fields and a question. "
"You should create a valid JSON object as an response to the question based on the input fields. "
Expand Down Expand Up @@ -270,6 +270,7 @@ def inputs(self) -> Dict[str, bool]:
"question": False,
"example_records": False,
"guidelines": False,
"system_prompt": False,
}

def _format_record(
Expand Down Expand Up @@ -421,7 +422,9 @@ def format_input(
)

messages = []
if self.system_prompt:
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({"role": "user", "content": prompt})
return messages
Expand Down
5 changes: 5 additions & 0 deletions src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,11 @@ class Task(_Task, Step):
num_generations: The number of generations to be produced per input.
"""

system_prompt: Optional[str] = Field(
default=None,
description="The system prompt for the task.",
)

@abstractmethod
def format_input(self, input: Dict[str, Any]) -> "FormattedInput":
"""Abstract method to format the inputs of the task. It needs to receive an input
Expand Down
19 changes: 14 additions & 5 deletions src/distilabel/steps/tasks/clair.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@ class CLAIR(Task):
CLAIR uses an AI system to minimally revise a solution A→A′ such that the resulting
preference A `preferred` A′ is much more contrastive and precise.

Attributes:
system_prompt: The system prompt for the CLAIR task.

Input columns:
- task (`str`): The task or instruction.
- student_solution (`str`): An answer to the task that is to be revised.
- system_prompt (`Optional[str]`): The system prompt for the CLAIR task.

Output columns:
- revision (`str`): The revised text.
Expand Down Expand Up @@ -127,7 +131,7 @@ def load(self) -> None:

@property
def inputs(self) -> "StepColumns":
return ["task", "student_solution"]
return {"task": True, "student_solution": True, "system_prompt": False}

@property
def outputs(self) -> "StepColumns":
Expand All @@ -136,15 +140,20 @@ def outputs(self) -> "StepColumns":
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the instruction
is the first interaction from the user within a conversation."""
return [
{"role": "system", "content": self.system_prompt},
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{
"role": "user",
"content": self._template.render(
task=input["task"], student_solution=input["student_solution"]
),
},
]
}
)
return messages

def format_output(
self, output: Union[str, None], input: Dict[str, Any]
Expand Down
12 changes: 10 additions & 2 deletions src/distilabel/steps/tasks/complexity_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ class ComplexityScorer(Task):
in Instruction Tuning'.

Attributes:
system_prompt: The system prompt for the complexity scorer task.
_template: a Jinja2 template used to format the input for the LLM.

Input columns:
- instructions (`List[str]`): The list of instructions to be scored.
- system_prompt (`Optional[str]`): The system prompt for the complexity scorer task.

Output columns:
- scores (`List[float]`): The score for each instruction.
Expand Down Expand Up @@ -151,12 +153,18 @@ def inputs(self) -> List[str]:
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the instruction
is the first interaction from the user within a conversation."""
return [
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{
"role": "user",
"content": self._template.render(instructions=input["instructions"]), # type: ignore
}
]
)
return messages

@property
def outputs(self) -> List[str]:
Expand Down
22 changes: 18 additions & 4 deletions src/distilabel/steps/tasks/generate_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ class GenerateEmbeddings(Step):

Attributes:
llm: The `LLM` to use to generate the embeddings.
system_prompt: The system prompt for the embedding generation task.

Input columns:
- text (`str`, `List[Dict[str, str]]`): The input text or conversation to generate
embeddings for.
- system_prompt (`Optional[str]`): The system prompt for the embedding generation task.

Output columns:
- embedding (`List[float]`): The embedding of the input text or conversation.
Expand Down Expand Up @@ -101,7 +103,7 @@ def load(self) -> None:
def inputs(self) -> "StepColumns":
"""The inputs for the task is a `text` column containing either a string or a
list of dictionaries in OpenAI chat-like format."""
return ["text"]
return {"text": True, "system_prompt": False}

@property
def outputs(self) -> "StepColumns":
Expand All @@ -120,14 +122,26 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
Returns:
The OpenAI chat-like format of the input.
"""
text = input["text"] = input["text"]
text = input["text"]

# input is in `ChatType` format
if isinstance(text, str):
return [{"role": "user", "content": text}]
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({"role": "user", "content": text})
return messages

if is_openai_format(text):
return text
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({"role": "user", "content": text})
return messages

raise DistilabelUserError(
f"Couldn't format input for step {self.name}. The `text` input column has to"
Expand Down
14 changes: 11 additions & 3 deletions src/distilabel/steps/tasks/genstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ class Genstruct(Task):
for this specific task.

Attributes:
system_prompt: The system prompt for the instruction generation task.
_template: a Jinja2 template used to format the input for the LLM.

Input columns:
- title (`str`): The title of the document.
- content (`str`): The content of the document.
- system_prompt (`Optional[str]`): The system prompt for the instruction generation task.

Output columns:
- user (`str`): The user's instruction based on the document.
Expand Down Expand Up @@ -136,19 +138,25 @@ def load(self) -> None:
@property
def inputs(self) -> List[str]:
"""The inputs for the task are the `title` and the `content`."""
return ["title", "content"]
return {"title": True, "content": True, "system_prompt": False}

def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the instruction
is the first interaction from the user within a conversation."""
return [
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{
"role": "user",
"content": self._template.render( # type: ignore
title=input["title"], content=input["content"]
),
}
]
)
return messages

@property
def outputs(self) -> List[str]:
Expand Down
46 changes: 37 additions & 9 deletions src/distilabel/steps/tasks/improving_text_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class _EmbeddingDataGeneration(_JSONFormatter, Task, ABC):
keeping the `format_input` as an abstract method to be implemented in the subclasses.

Attributes:
system_prompt: The system prompt for the task.
seed: The random seed to be set in case there's any sampling within the `format_input` method.
_template: The Jinja2 template to be rendered within the `format_input` method with the
provided arguments.
Expand Down Expand Up @@ -160,6 +161,7 @@ class _EmbeddingDataGenerator(_JSONFormatter, GeneratorTask, ABC):
keeping the `format_input` as an abstract method to be implemented in the subclasses.

Attributes:
system_prompt: The system prompt for the task.
seed: The random seed to be set in case there's any sampling within the `format_input` method.
_template: The Jinja2 template to be rendered within the `format_input` method with the
provided arguments.
Expand Down Expand Up @@ -247,6 +249,7 @@ class EmbeddingTaskGenerator(GeneratorTask):
- `text-classification`: Generate task descriptions for text classification tasks.

Attributes:
system_prompt: The system prompt for the task.
category: The category of the task to be generated, which can either be `text-retrieval`,
`text-matching-short`, `text-matching-long`, or `text-classification`.
flatten_tasks: Whether to flatten the tasks i.e. since a list of tasks is generated by the
Expand Down Expand Up @@ -504,7 +507,12 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
Returns:
A list with a single chat containing the user's message with the rendered `_template`.
"""
return [
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{
"role": "user",
"content": self._template.render( # type: ignore
Expand All @@ -526,7 +534,8 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
or random.choice([50, 100, 200, 300, 400, 500]),
).strip(),
}
]
)
return messages

@property
def keys(self) -> List[str]:
Expand Down Expand Up @@ -612,15 +621,21 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
Returns:
A list with a single chat containing the user's message with the rendered `_template`.
"""
return [
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{
"role": "user",
"content": self._template.render( # type: ignore
task=input["task"],
language=self.language,
).strip(),
}
]
)
return messages

@property
def keys(self) -> List[str]:
Expand Down Expand Up @@ -702,15 +717,21 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
Returns:
A list with a single chat containing the user's message with the rendered `_template`.
"""
return [
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{
"role": "user",
"content": self._template.render( # type: ignore
task=input["task"],
language=self.language,
).strip(),
}
]
)
return messages

@property
def keys(self) -> List[str]:
Expand Down Expand Up @@ -739,9 +760,10 @@ class GenerateTextClassificationData(_EmbeddingDataGeneration):
clarity: The clarity of the query to be generated, which can be `clear`, `understandable with some effort`,
or `ambiguous`. Defaults to `None`, meaning that it will be randomly sampled.
seed: The random seed to be set in case there's any sampling within the `format_input` method.
system_prompt: The system prompt for the text classification task.

Input columns:
- task (`str`): The task description to be used in the generation.
- system_prompt (`Optional[str]`): The system prompt for the text classification task.

Output columns:
- input_text (`str`): the input text generated by the `LLM`.
Expand Down Expand Up @@ -803,7 +825,12 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
Returns:
A list with a single chat containing the user's message with the rendered `_template`.
"""
return [
messages = []
if "system_prompt" in input:
messages.append({"role": "system", "content": input["system_prompt"]})
elif self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{
"role": "user",
"content": self._template.render( # type: ignore
Expand All @@ -817,7 +844,8 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
),
).strip(),
}
]
)
return messages

@property
def keys(self) -> List[str]:
Expand Down
Loading
Loading