Add context to guide the generate sentence pair task when provided (#706)
* Add context to guide the generate sentence pair task when provided

* Include example of how to add context to generate sentence pairs

* Invert order of anchor/context in prompt template
plaguss authored Jun 10, 2024
1 parent 893cfa3 commit 23b3b41
Showing 3 changed files with 144 additions and 14 deletions.
47 changes: 42 additions & 5 deletions src/distilabel/steps/tasks/sentence_transformers.py
@@ -43,31 +43,36 @@
}

POSITIVE_SYSTEM_PROMPT: str = (
"Your task is to generate a positive sentence given an anchor sentence. The positive"
"Your task is to generate a positive sentence given an anchor sentence.{context} The positive"
" sentence has to {action_sentence} the anchor sentence. You must output only one new"
" section: `## Positive`."
)

POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
"Your task is to generate a positive and a negative sentence given an anchor sentence."
"Your task is to generate a positive and a negative sentence given an anchor sentence.{context}"
" The positive sentence has to {action_sentence} the anchor sentence, while the negative"
" sentence can use similar words but must not be related to the anchor sentence. You"
" must output only two new sections: `## Positive` and `## Negative`."
)

CONTEXT_INTRO: Final[str] = " Take into account the context given."


class GenerateSentencePair(Task):
"""Generate a positive and negative (optionally) sentences given an anchor sentence.
`GenerateSentencePair` is a pre-defined task that given an anchor sentence generates
a positive sentence related to the anchor and optionally a negative sentence unrelated
to the anchor. This task is useful to generate training datasets for training embeddings
to the anchor. Optionally, you can give a context to guide the LLM towards more specific
behavior. This task is useful to generate training datasets for training embeddings
models.
Attributes:
triplet: a flag to indicate if the task should generate a triplet of sentences
(anchor, positive, negative). Defaults to `False`.
action: the action to perform to generate the positive sentence.
context: the context to use for the generation. Can be helpful to guide the LLM
towards more specific context. Not used by default.
Input columns:
- anchor (`str`): The anchor sentence to generate the positive and negative sentences.
@@ -165,10 +170,33 @@ class GenerateSentencePair(Task):
result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
```
Generating queries with context (**applies to every action**):
```python
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.llms import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
action="query",
context="Argilla is an open-source data curation platform for LLMs.",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
),
input_batch_size=10,
)
generate_sentence_pair.load()
result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
```
"""

triplet: bool = False
action: GenerationAction
context: str = ""

def load(self) -> None:
"""Loads the Jinja2 template."""
@@ -203,11 +231,20 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
action_sentence = GENERATION_ACTION_SENTENCES[self.action]
system_prompt = (
POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
).format(action_sentence=action_sentence)
).format(
action_sentence=action_sentence,
context=CONTEXT_INTRO if self.context else "",
)

return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": self._template.render(anchor=input["anchor"])},
{
"role": "user",
"content": self._template.render(
anchor=input["anchor"],
context=self.context if self.context else None,
),
},
]

@property
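To make the `format_input` change above concrete, here is a minimal standalone sketch of how the new `context` field alters the system prompt. It only reuses the constants shown in this diff; `build_system_prompt` is a hypothetical helper written for illustration and is not part of distilabel.

```python
# Constants copied from the diff above.
POSITIVE_SYSTEM_PROMPT = (
    "Your task is to generate a positive sentence given an anchor sentence.{context} The positive"
    " sentence has to {action_sentence} the anchor sentence. You must output only one new"
    " section: `## Positive`."
)
CONTEXT_INTRO = " Take into account the context given."


def build_system_prompt(action_sentence: str, context: str = "") -> str:
    # Hypothetical helper mirroring the formatting logic in `format_input`:
    # the {context} placeholder becomes CONTEXT_INTRO only when a context is
    # given, otherwise it collapses to an empty string.
    return POSITIVE_SYSTEM_PROMPT.format(
        action_sentence=action_sentence,
        context=CONTEXT_INTRO if context else "",
    )


print(build_system_prompt("be a query for"))
print(build_system_prompt("be a query for", context="Argilla docs"))
```

Because the placeholder sits right after the first sentence of the prompt, an empty `context` leaves the original wording untouched, which is why the existing tests below only needed `context=""` added to their expected prompts.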
7 changes: 7 additions & 0 deletions (GenerateSentencePair Jinja2 prompt template)
@@ -1,3 +1,10 @@
{% if context is not none -%}
## Context

{{ context }}

{% endif -%}

## Anchor

{{ anchor }}
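As a rough check of the template change, the snippet below renders an inline copy of the template reconstructed from the hunk above (treat it as an approximation of the actual template file) with and without a context, matching the strings asserted in the new unit test.

```python
from jinja2 import Template

# Inline reconstruction of the Jinja2 template shown in the diff above.
TEMPLATE = Template(
    "{% if context is not none -%}\n"
    "## Context\n"
    "\n"
    "{{ context }}\n"
    "\n"
    "{% endif -%}\n"
    "\n"
    "## Anchor\n"
    "\n"
    "{{ anchor }}\n"
)

# Without a context the `## Context` block is skipped entirely.
# Renders to: "## Anchor\n\nThis is a unit test\n"
print(TEMPLATE.render(anchor="This is a unit test", context=None))

# With a context it is prepended before the anchor section.
# Renders to: "## Context\n\nThis is your context.\n\n## Anchor\n\nThis is a unit test\n"
print(TEMPLATE.render(anchor="This is a unit test", context="This is your context."))
```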
104 changes: 95 additions & 9 deletions tests/unit/steps/tasks/test_sentence_transformers.py
@@ -16,6 +16,7 @@

import pytest
from distilabel.steps.tasks.sentence_transformers import (
CONTEXT_INTRO,
POSITIVE_NEGATIVE_SYSTEM_PROMPT,
POSITIVE_SYSTEM_PROMPT,
GenerateSentencePair,
@@ -32,50 +33,56 @@ class TestGenerateSentencePair:
(
"paraphrase",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="paraphrase", context=""
),
),
(
"paraphrase",
False,
POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=""),
),
(
"semantically-similar",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to"
action_sentence="be semantically similar to", context=""
),
),
(
"semantically-similar",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to"
action_sentence="be semantically similar to", context=""
),
),
(
"query",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be a query for"
action_sentence="be a query for", context=""
),
),
(
"query",
False,
POSITIVE_SYSTEM_PROMPT.format(action_sentence="be a query for"),
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be a query for", context=""
),
),
(
"answer",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be an answer for"
action_sentence="be an answer for", context=""
),
),
(
"answer",
False,
POSITIVE_SYSTEM_PROMPT.format(action_sentence="be an answer for"),
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be an answer for", context=""
),
),
],
)
@@ -84,10 +91,89 @@ def test_format_input(
) -> None:
task = GenerateSentencePair(llm=DummyLLM(), action=action, triplet=triplet)
task.load()
content = "## Anchor\n\nThis is a unit test\n"
assert task.format_input({"anchor": "This is a unit test"}) == [
{"role": "system", "content": system_prompt},
{"role": "user", "content": content},
]

@pytest.mark.parametrize(
"action,triplet,system_prompt",
[
(
"paraphrase",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="paraphrase", context=CONTEXT_INTRO
),
),
(
"paraphrase",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="paraphrase", context=CONTEXT_INTRO
),
),
(
"semantically-similar",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to", context=CONTEXT_INTRO
),
),
(
"semantically-similar",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to", context=CONTEXT_INTRO
),
),
(
"query",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be a query for", context=CONTEXT_INTRO
),
),
(
"query",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be a query for", context=CONTEXT_INTRO
),
),
(
"answer",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be an answer for", context=CONTEXT_INTRO
),
),
(
"answer",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be an answer for", context=CONTEXT_INTRO
),
),
],
)
def test_format_input_with_context(
self, action: GenerationAction, triplet: bool, system_prompt: str
) -> None:
context = "This is your context."
task = GenerateSentencePair(
llm=DummyLLM(),
action=action,
triplet=triplet,
context=context,
)
task.load()
content = f"## Context\n\n{context}\n\n## Anchor\n\nThis is a unit test\n"
# content = f"## Anchor\n\nThis is a unit test\n## Context\n\n{context}"
assert task.format_input({"anchor": "This is a unit test"}) == [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "## Anchor\n\nThis is a unit test\n"},
{"role": "user", "content": content},
]

@pytest.mark.parametrize(
