Commit: Pretty print (#934)
* Add integration test to showcase the prompts

* Add a base print method so the Tasks can pretty print their prompts easily

* Update base method to allow automatic pretty printing

* Add optional argument instead of the default, update method name and return type

* Fix type hint

* Add example in docstrings

* Add section in docs for the print method
plaguss authored Oct 7, 2024
1 parent 4b056ff commit 87683f0
Showing 12 changed files with 294 additions and 8 deletions.
69 changes: 69 additions & 0 deletions docs/sections/how_to_guides/basic/task/index.md
@@ -57,6 +57,75 @@ As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] ta
)
```

### Task.print

!!! Info
    New in version `1.4.0`: the [`Task.print`][distilabel.steps.tasks.base._Task.print] method.

The `Task`s include a handy method to show what the prompt formatted for an `LLM` would look like. Let's see an example with [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback], but the same applies to any other `Task`.

```python
from distilabel.steps.tasks import UltraFeedback
from distilabel.llms.huggingface import InferenceEndpointsLLM

uf = UltraFeedback(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    ),
)
uf.load()
uf.print()
```

The result will be the rendered prompt, with the system prompt (if the task defines one) and the user prompt, rendered with `rich` (it will look exactly the same in a Jupyter notebook).

![task-print](../../../../assets/images/sections/how_to_guides/tasks/task_print.png)

In case you want to test with a custom input, you can pass an example to the task's `format_input` method (or generate it on your own, depending on the task), and pass it to the `print` method so that it shows your example:


```python
uf.print(
    uf.format_input({"instruction": "test", "generations": ["1", "2"]})
)
```
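
Under the hood, `print` simply receives a chat in the standard `role`/`content` format (the `ChatType` used across `distilabel`), so you could also craft one by hand. A quick sketch, with made-up placeholder strings:

```python
uf.print(
    [
        {"role": "system", "content": "<SYSTEM_PROMPT>"},
        {"role": "user", "content": "<USER_PROMPT>"},
    ]
)
```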

??? "Using a DummyLLM to avoid loading one"

    In case you don't want to load an LLM to render the template, you can create a dummy one, like the ones we use for testing.

    ```python
    from typing import TYPE_CHECKING, Any

    from distilabel.llms.base import LLM
    from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin

    if TYPE_CHECKING:
        from distilabel.llms.typing import GenerateOutput
        from distilabel.steps.tasks.typing import FormattedInput


    class DummyLLM(LLM, MagpieChatTemplateMixin):
        structured_output: Any = None
        magpie_pre_query_template: str = "llama3"

        def load(self) -> None:
            pass

        @property
        def model_name(self) -> str:
            return "test"

        def generate(
            self, input: "FormattedInput", num_generations: int = 1
        ) -> "GenerateOutput":
            return ["output" for _ in range(num_generations)]
    ```

    You can use this `LLM` just like any of the other ones to `load` your task and call `print`:

    ```python
    uf = UltraFeedback(llm=DummyLLM())
    uf.load()
    uf.print()
    ```

!!! Note
    When creating a custom task, the `print` method will be available by default, but it is limited to the most common scenarios for the inputs. If you test your new task and find it's not working as expected (for example, if your task contains one input consisting of a list of texts instead of a single one), you should override the `_sample_input` method. You can inspect the [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback] source code for an example, or see the sketch below.
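
A rough sketch of such an override (the `MyListInputTask` name and its `texts` input column are hypothetical):

```python
from typing_extensions import override

from distilabel.steps.tasks import Task


class MyListInputTask(Task):
    # ... inputs, format_input, format_output, etc. ...

    @override
    def _sample_input(self) -> "ChatType":
        # The default placeholder map (str -> str) wouldn't match this task's
        # expected input, so build a sample that `format_input` understands.
        return self.format_input(
            {"texts": [f"<PLACEHOLDER_TEXT_{i}>" for i in range(2)]}
        )
```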

## Specifying the number of generations and grouping generations

All the `Task`s have a `num_generations` attribute that allows defining the number of generations that we want to have per input. We can update the example above to generate 3 completions per input:
91 changes: 89 additions & 2 deletions src/distilabel/steps/tasks/base.py
@@ -14,7 +14,7 @@

import importlib
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import Field, PrivateAttr
from typing_extensions import override
@@ -34,7 +34,7 @@

if TYPE_CHECKING:
     from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import FormattedInput
+    from distilabel.steps.tasks.typing import ChatType, FormattedInput
     from distilabel.steps.typing import StepOutput


@@ -276,6 +276,93 @@ def get_structured_output(self) -> Union[Dict[str, Any], None]:
"""
return None

    def _sample_input(self) -> "ChatType":
        """Returns a sample input to be used in the `print` method.
        Tasks whose `format_input` doesn't expect a map of the type
        str -> str should override this method to return a sample input.
        """
        return self.format_input(
            {input: f"<PLACEHOLDER_{input.upper()}>" for input in self.inputs}
        )

    def print(self, sample_input: Optional["ChatType"] = None) -> None:
        """Prints a sample input to the console using the `rich` library.
        Helper method to visualize the prompt of the task.

        Args:
            sample_input: A sample input to be printed. If not provided, a default will be
                generated using the `_sample_input` method, which can be overridden by
                subclasses. This should correspond to the same example you could pass to
                the `format_input` method.
                The variables will be named `<PLACEHOLDER_VARIABLE_NAME>` by default.

        Examples:
            Print the URIAL prompt:

            ```python
            from distilabel.steps.tasks import URIAL
            from distilabel.llms.huggingface import InferenceEndpointsLLM

            # Consider this as a placeholder for your actual LLM.
            urial = URIAL(
                llm=InferenceEndpointsLLM(
                    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
                ),
            )
            urial.load()
            urial.print()
╭─────────────────────────────────────── Prompt: URIAL ────────────────────────────────────────╮
│ ╭────────────────────────────────────── User Message ───────────────────────────────────────╮ │
│ │ # Instruction │ │
│ │ │ │
│ │ Below is a list of conversations between a human and an AI assistant (you). │ │
│ │ Users place their queries under "# User:", and your responses are under "# Assistant:". │ │
│ │ You are a helpful, respectful, and honest assistant. │ │
│ │ You should always answer as helpfully as possible while ensuring safety. │ │
│ │ Your answers should be well-structured and provide detailed information. They should also │ │
│ │ have an engaging tone. │ │
│ │ Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, │ │
│ │ dangerous, or illegal content, even if it may be helpful. │ │
│ │ Your response must be socially responsible, and thus you can refuse to answer some │ │
│ │ controversial topics. │ │
│ │ │ │
│ │ │ │
│ │ # User: │ │
│ │ │ │
│ │ <PLACEHOLDER_INSTRUCTION> │ │
│ │ │ │
│ │ # Assistant: │ │
│ ╰───────────────────────────────────────────────────────────────────────────────────────────╯ │
╰───────────────────────────────────────────────────────────────────────────────────────────────╯
            ```
        """
        from rich.console import Console, Group
        from rich.panel import Panel
        from rich.text import Text

        console = Console()
        sample_input = sample_input or self._sample_input()

        panels = []
        for item in sample_input:
            content = Text.assemble((item.get("content", ""),))
            panel = Panel(
                content,
                title=f"[bold][magenta]{item.get('role', '').capitalize()} Message[/magenta][/bold]",
                border_style="light_cyan3",
            )
            panels.append(panel)

        # Group the per-message panels and wrap them in an outer panel.
        outer_panel = Panel(
            Group(*panels),
            title=f"[bold][magenta]Prompt: {type(self).__name__} [/magenta][/bold]",
            border_style="light_cyan3",
            expand=False,
        )
        console.print(outer_panel)


class Task(_Task, Step):
"""Task is a class that implements the `_Task` abstract class and adds the `Step`
14 changes: 14 additions & 0 deletions src/distilabel/steps/tasks/complexity_scorer.py
@@ -239,3 +239,17 @@ def _format_structured_output(
            return orjson.loads(output)
        except orjson.JSONDecodeError:
            return {"scores": [None] * len(input["instructions"])}

    @override
    def _sample_input(self) -> "ChatType":
        """Returns a sample input to be used in the `print` method.
        Tasks whose `format_input` doesn't expect a map of the type
        str -> str should override this method to return a sample input.
        """
        return self.format_input(
            {
                "instructions": [
                    f"<PLACEHOLDER_{f'INSTRUCTION_{i}'.upper()}>" for i in range(2)
                ],
            }
        )
6 changes: 6 additions & 0 deletions src/distilabel/steps/tasks/evol_instruct/base.py
@@ -388,3 +388,9 @@ def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
            ):
                input.update(self.format_output(instruction, answers[idx]))
            yield inputs

    @override
    def _sample_input(self) -> ChatType:
        return self.format_input(
            self._apply_random_mutation("<PLACEHOLDER_INSTRUCTION>")
        )
4 changes: 4 additions & 0 deletions src/distilabel/steps/tasks/evol_instruct/generator.py
@@ -347,3 +347,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput":  # type: ignore
                ],
                True,
            )

    @override
    def _sample_input(self) -> "ChatType":
        return self._apply_random_mutation(iter_no=0)[0]
4 changes: 4 additions & 0 deletions src/distilabel/steps/tasks/evol_quality/base.py
@@ -271,3 +271,7 @@ def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        yield inputs

        self._logger.info(f"🎉 Finished evolving {len(responses)} instructions!")

    @override
    def _sample_input(self) -> ChatType:
        return self.format_input("<PLACEHOLDER_INSTRUCTION>")
15 changes: 9 additions & 6 deletions src/distilabel/steps/tasks/improving_text_embeddings.py
@@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import importlib.resources as importlib_resources
 import random
 import re
-import sys
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Literal, Optional, Union

-if sys.version_info < (3, 9):
-    import importlib_resources
-else:
-    import importlib.resources as importlib_resources
-
 from jinja2 import Template
 from pydantic import Field, PrivateAttr
 from typing_extensions import override
@@ -232,6 +227,10 @@ def process(self, offset: int = 0) -> GeneratorStepOutput:  # type: ignore
        )
        yield task_outputs, True

    @override
    def _sample_input(self) -> ChatType:
        return self.prompt


# IMPLEMENTED TASKS
class EmbeddingTaskGenerator(GeneratorTask):
@@ -402,6 +401,10 @@ def format_output(
            pass
        return {"tasks": output}

    @override
    def _sample_input(self) -> ChatType:
        return self.prompt


class GenerateTextRetrievalData(_EmbeddingDataGeneration):
"""Generate text retrieval data with an `LLM` to later on train an embedding model.
6 changes: 6 additions & 0 deletions src/distilabel/steps/tasks/magpie/generator.py
@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any, Dict, Union

 from pydantic import Field
+from typing_extensions import override

from distilabel.errors import DistilabelUserError
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
@@ -23,6 +24,7 @@
from distilabel.steps.tasks.magpie.base import MagpieBase

 if TYPE_CHECKING:
+    from distilabel.steps.tasks.typing import ChatType
     from distilabel.steps.typing import GeneratorStepOutput, StepColumns


@@ -312,3 +314,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput":
            )
            generated += rows_to_generate  # type: ignore
            yield (conversations, generated == self.num_rows)

    @override
    def _sample_input(self) -> "ChatType":
        return self._generate_with_pre_query_template(inputs=[{}])
11 changes: 11 additions & 0 deletions src/distilabel/steps/tasks/quality_scorer.py
@@ -262,3 +262,14 @@ def _format_structured_output(
            return orjson.loads(output)
        except orjson.JSONDecodeError:
            return {"scores": [None] * len(input["responses"])}

    @override
    def _sample_input(self) -> ChatType:
        return self.format_input(
            {
                "instruction": f"<PLACEHOLDER_{'instruction'.upper()}>",
                "responses": [
                    f"<PLACEHOLDER_{f'RESPONSE_{i}'.upper()}>" for i in range(2)
                ],
            }
        )
11 changes: 11 additions & 0 deletions src/distilabel/steps/tasks/ultrafeedback.py
@@ -480,3 +480,14 @@ def _format_structured_output(
"types": [None] * len(input["generations"]),
"rationales-for-ratings": [None] * len(input["generations"]),
}

    @override
    def _sample_input(self) -> ChatType:
        return self.format_input(
            {
                "instruction": f"<PLACEHOLDER_{'instruction'.upper()}>",
                "generations": [
                    f"<PLACEHOLDER_{f'GENERATION_{i}'.upper()}>" for i in range(2)
                ],
            }
        )
71 changes: 71 additions & 0 deletions tests/integration/test_prints.py
@@ -0,0 +1,71 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import pytest

from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps import tasks as tasks_
from tests.unit.conftest import DummyLLM

# The tasks not listed here don't have a print method (or don't have a print method that works)
tasks = [
    tasks_.ComplexityScorer,
    partial(tasks_.EvolInstruct, num_evolutions=1),
    partial(tasks_.EvolComplexity, num_evolutions=1),
    partial(tasks_.EvolComplexityGenerator, num_instructions=1),
    partial(tasks_.EvolInstructGenerator, num_instructions=1),
    partial(tasks_.EvolQuality, num_evolutions=1),
    tasks_.Genstruct,
    partial(
        tasks_.BitextRetrievalGenerator,
        source_language="English",
        target_language="Spanish",
        unit="sentence",
        difficulty="elementary school",
        high_score="4",
        low_score="2.5",
    ),
    partial(tasks_.EmbeddingTaskGenerator, category="text-retrieval"),
    tasks_.GenerateLongTextMatchingData,
    tasks_.GenerateShortTextMatchingData,
    tasks_.GenerateTextClassificationData,
    tasks_.GenerateTextRetrievalData,
    tasks_.MonolingualTripletGenerator,
    tasks_.InstructionBacktranslation,
    tasks_.Magpie,
    tasks_.MagpieGenerator,
    partial(tasks_.PrometheusEval, mode="absolute", rubric="factual-validity"),
    tasks_.QualityScorer,
    tasks_.SelfInstruct,
    partial(tasks_.GenerateSentencePair, action="paraphrase"),
    tasks_.UltraFeedback,
    tasks_.URIAL,
]


class TestLLM(DummyLLM, MagpieChatTemplateMixin):
    magpie_pre_query_template: str = "llama3"


llm = TestLLM()


@pytest.mark.parametrize("task", tasks)
def test_prints(task):
    t = task(llm=llm)
    t.load()
    t.print()
    t.unload()
