Update LLM.generate output to include statistics (#1034)
Co-authored-by: Gabriel Martín Blázquez <[email protected]>
plaguss and gabrielmbmb authored Nov 19, 2024
1 parent e830e25 commit 2469407
Showing 54 changed files with 2,061 additions and 645 deletions.
60 changes: 50 additions & 10 deletions docs/sections/how_to_guides/basic/llm/index.md
@@ -7,20 +7,45 @@ LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Ta
```python
from distilabel.models import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct")
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct"
)
llm.load()

llm.generate_outputs(
inputs=[
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# "The capital of Spain is Madrid."
# [
# {
# "generations": [
# "The capital of Spain is Madrid."
# ],
# "statistics": {
# "input_tokens": [
# 43
# ],
# "output_tokens": [
# 8
# ]
# }
# }
# ]
```

!!! NOTE
!!! Note
Always call the `LLM.load` or `Task.load` method when using LLMs standalone or as part of a `Task`. If using a `Pipeline`, this is done automatically in `Pipeline.run()`.

!!! Tip "New in version 1.5.0"
Since version `1.5.0`, the LLM output is a list of dictionaries (one per item in `inputs`),
each containing `generations`, with the text returned by the `LLM`, and a `statistics` field storing statistics related to the generation. Initially this includes
`input_tokens` and `output_tokens` when available, obtained from the API when it reports them, or computed with the model's tokenizer when one is available.
The corresponding `Task` moves this data to `distilabel_metadata` during pipeline processing, so it can be operated on later, for example to compute the number of tokens per dataset.

To access the plain text of the previous result, access the generations in the resulting dictionary: `result[0]["generations"]`.
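
For example, a minimal sketch (reusing the `llm` loaded above and assuming the output shape shown above) of reading both the text and the token counts back from the result:

```python
# Minimal sketch, based on the output shape shown above.
result = llm.generate_outputs(
    inputs=[
        [{"role": "user", "content": "What's the capital of Spain?"}],
    ],
)

text = result[0]["generations"][0]   # "The capital of Spain is Madrid."
stats = result[0]["statistics"]      # {"input_tokens": [43], "output_tokens": [8]}
total_tokens = stats["input_tokens"][0] + stats["output_tokens"][0]
```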

### Offline Batch Generation

By default, all `LLM`s generate text synchronously, i.e. sending inputs via the `generate_outputs` method blocks until the outputs are generated. Some `LLM`s (such as [OpenAILLM][distilabel.models.llms.openai.OpenAILLM]) implement what we denote as _offline batch generation_, which allows sending the inputs to the LLM-as-a-service, which generates the outputs asynchronously and returns a job id that we can use later to check the status and retrieve the generated outputs when they are ready. LLM-as-a-service platforms offer this feature as a way to save costs in exchange for waiting for the outputs to be generated.
@@ -56,7 +81,8 @@ llm.generate_outputs( # (4)
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# "The capital of Spain is Madrid."
# [{'generations': ['The capital of Spain is Madrid.'],
# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}]
```

1. At first the `jobs_ids` attribute is `None`.
@@ -81,7 +107,8 @@ llm.generate_outputs(
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# "The capital of Spain is Madrid."
# [{'generations': ['The capital of Spain is Madrid.'],
# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}]
```

### Within a Task
@@ -92,20 +119,30 @@ Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and th
from distilabel.models import OpenAILLM
from distilabel.steps.tasks import TextGeneration

llm = OpenAILLM(model="gpt-4")
llm = OpenAILLM(model="gpt-4o-mini")
task = TextGeneration(name="text_generation", llm=llm)

task.load()

next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}]))
# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid."}]
# [{'instruction': "What's the capital of Spain?",
# 'generation': 'The capital of Spain is Madrid.',
# 'distilabel_metadata': {'raw_output_text_generation': 'The capital of Spain is Madrid.',
# 'raw_input_text_generation': [{'role': 'user',
# 'content': "What's the capital of Spain?"}],
# 'statistics_text_generation': {'input_tokens': 13, 'output_tokens': 7}},
# 'model_name': 'gpt-4o-mini'}]
```

!!! Note
As mentioned in the *Working with LLMs* section, the raw output of the LLM (including the new `statistics`) is automatically moved to `distilabel_metadata` to avoid interfering with the common workflow, so the `statistics` are an extra piece of information available to the user, and nothing has to be changed in the defined pipelines.
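
As a minimal sketch (reusing the `task` defined above and the metadata keys shown in the example output), the statistics can be read back from `distilabel_metadata`:

```python
# Sketch: reading the statistics back from the output shown above.
result = next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}]))
stats = result[0]["distilabel_metadata"]["statistics_text_generation"]
print(stats["input_tokens"], stats["output_tokens"])  # e.g. 13 7
```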

### Runtime Parameters

LLMs can have runtime parameters, such as `generation_kwargs`, provided via the `Pipeline.run()` method using the `parameters` argument.

!!! NOTE
!!! Note
Runtime parameters can differ between LLM subclasses, due to the different functionalities offered by the LLM providers.

```python
@@ -122,7 +159,7 @@ with Pipeline(name="text-generation-pipeline") as pipeline:

text_generation = TextGeneration(
name="text_generation",
llm=OpenAILLM(model="gpt-4"),
llm=OpenAILLM(model="gpt-4o-mini"),
)

load_dataset >> text_generation
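
# Sketch (not in the original snippet): at run time, `generation_kwargs` would be
# provided through the `parameters` argument of `Pipeline.run`, e.g.:
#
# distiset = pipeline.run(
#     parameters={
#         text_generation.name: {
#             "llm": {"generation_kwargs": {"temperature": 0.3}},
#         },
#     },
# )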
@@ -200,9 +237,12 @@ To create custom LLMs, subclass either [`LLM`][distilabel.models.llms.LLM] for s

The keyword arguments of `generate` and `agenerate` (except `input` and `num_generations`) are considered `RuntimeParameter`s, so a value can be passed to them via the `parameters` argument of the `Pipeline.run` method.

!!! NOTE
!!! Note
The `validate_call` decorator is used to automatically coerce the arguments of `generate` and `agenerate` to the expected types, raising an error if the types are not correct. This is especially useful when providing a value for an argument of `generate` or `agenerate` from the CLI, since the CLI always provides the arguments as strings.

!!! Warning
Additional LLMs created in `distilabel` will have to take into account how the `statistics` are generated to properly include them in the LLM output.
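
For illustration only, here is a hedged, minimal sketch of the output shape a custom `LLM.generate` could return so that `statistics` travel along with the `generations`; the class is hypothetical and the exact base-class interface should be checked against [`LLM`][distilabel.models.llms.LLM]:

```python
from typing import TYPE_CHECKING, Any, List

from distilabel.models.llms import LLM

if TYPE_CHECKING:
    from distilabel.models.llms.typing import GenerateOutput
    from distilabel.steps.tasks.typing import FormattedInput


class EchoLLM(LLM):
    """Toy LLM that echoes the last user message and reports naive token counts."""

    def load(self) -> None:
        super().load()

    @property
    def model_name(self) -> str:
        return "echo-llm"

    def generate(
        self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
    ) -> List["GenerateOutput"]:
        outputs = []
        for input in inputs:
            text = input[-1]["content"]
            # One dict per input, mirroring the `generations`/`statistics`
            # layout introduced in version 1.5.0 (word counts stand in for tokens).
            outputs.append(
                {
                    "generations": [text] * num_generations,
                    "statistics": {
                        "input_tokens": [len(text.split())] * num_generations,
                        "output_tokens": [len(text.split())] * num_generations,
                    },
                }
            )
        return outputs
```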

## Available LLMs

[Our LLM gallery](../../../../components-gallery/llms/index.md) shows a list of the available LLMs that can be used within the `distilabel` library.
43 changes: 29 additions & 14 deletions docs/sections/how_to_guides/basic/task/index.md
@@ -21,26 +21,35 @@ task.load()

next(task.process([{"instruction": "What's the capital of Spain?"}]))
# [
# {
# 'instruction': "What's the capital of Spain?",
# 'generation': 'The capital of Spain is Madrid.',
# 'distilabel_metadata': {
# 'raw_output_text-generation': 'The capital of Spain is Madrid.',
# 'raw_input_text-generation': [
# {'role': 'user', 'content': "What's the capital of Spain?"}
# ]
# },
# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct'
# }
# {
# "instruction": "What's the capital of Spain?",
# "generation": "The capital of Spain is Madrid.",
# "distilabel_metadata": {
# "raw_output_text-generation": "The capital of Spain is Madrid.",
# "raw_input_text-generation": [
# {
# "role": "user",
# "content": "What's the capital of Spain?"
# }
# ],
# "statistics_text-generation": { # (1)
# "input_tokens": 18,
# "output_tokens": 8
# }
# },
# "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct"
# }
# ]
```

!!! NOTE
1. The `LLM` will not only return the text but also a `statistics_{STEP_NAME}` field containing statistics related to the generation. If available, at least the input and output tokens will be returned.

!!! Note
`Step.load()` always needs to be executed when the step is used standalone. Within a pipeline, this is done automatically during pipeline execution.

As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`.

!!! Tip
!!! Tip "New in version 1.2.0"
Since version `1.2.0`, we provide some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task.

Additionally, since version `1.4.0`, the formatted input can also be included, which can be helpful when testing
@@ -57,9 +66,12 @@
)
```

!!! Tip "New in version 1.5.0"
Since version `1.5.0`, `distilabel_metadata` includes a `statistics` field out of the box. The generation from the LLM will not only contain the text, but also statistics associated with the text if available, such as the input and output tokens. This field is named `statistics_{STEP_NAME}` to avoid collisions between different steps in the pipeline, similar to how `raw_output_{STEP_NAME}` works.
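
As a small sketch (assuming rows shaped like the example output above), these per-row statistics can be aggregated, for example to count the tokens spent on a dataset:

```python
# Sketch: summing token counts over rows shaped like the output above.
rows = next(task.process([{"instruction": "What's the capital of Spain?"}]))

total_input_tokens = sum(
    row["distilabel_metadata"]["statistics_text-generation"]["input_tokens"] for row in rows
)
total_output_tokens = sum(
    row["distilabel_metadata"]["statistics_text-generation"]["output_tokens"] for row in rows
)
print(total_input_tokens, total_output_tokens)  # e.g. 18 8
```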

### Task.print

!!! Info
!!! Info "New in version 1.4.0"
The [`Task.print`][distilabel.steps.tasks.base._Task.print] method is available since version `1.4.0`.

The `Task`s include a handy method to show what the prompt formatted for an `LLM` would look like. Let's see an example with [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback], but it applies to any other `Task`.
@@ -271,3 +283,6 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe
# Format the `LLM` output here
return {"output_field": output}
```

!!! Warning
Most `Task`s reuse the `Task.process` method to process the generations, but if a new `Task` defines a custom `process` method, as happens for example with [`Magpie`][distilabel.steps.tasks.magpie.base.Magpie], one has to deal with the `statistics` returned by the `LLM`.
29 changes: 22 additions & 7 deletions src/distilabel/models/llms/anthropic.py
@@ -30,13 +30,19 @@
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.typing import GenerateOutput
from distilabel.models.llms.utils import prepare_output
from distilabel.steps.tasks.typing import (
FormattedInput,
InstructorStructuredOutputType,
)

if TYPE_CHECKING:
from pydantic import BaseModel

from anthropic import AsyncAnthropic
from anthropic.types import Message

from distilabel.models.llms.typing import LLMStatistics


_ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY"
@@ -260,17 +266,26 @@ async def agenerate( # type: ignore
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output)

generations = []

completion = await self._aclient.messages.create(**kwargs) # type: ignore
completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
**kwargs
) # type: ignore
if structured_output:
generations.append(completion.model_dump_json())
return generations
# raw_response = completion._raw_response
return prepare_output(
[completion.model_dump_json()],
**self._get_llm_statistics(completion._raw_response),
)

if (content := completion.content[0].text) is None:
self._logger.warning(
f"Received no response using Anthropic client (model: '{self.model}')."
f" Finish reason was: {completion.stop_reason}"
)
generations.append(content)
return generations
return prepare_output([content], **self._get_llm_statistics(completion))

@staticmethod
def _get_llm_statistics(completion: "Message") -> "LLMStatistics":
return {
"input_tokens": [completion.usage.input_tokens],
"output_tokens": [completion.usage.output_tokens],
}
52 changes: 45 additions & 7 deletions src/distilabel/models/llms/base.py
@@ -21,6 +21,7 @@
import time
from abc import ABC, abstractmethod
from functools import cached_property
from itertools import islice
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -33,7 +34,6 @@
RuntimeParametersMixin,
)
from distilabel.utils.docstring import parse_google_docstring
from distilabel.utils.itertools import grouper
from distilabel.utils.notebook import in_notebook
from distilabel.utils.serialization import _Serializable

@@ -459,18 +459,16 @@ async def _agenerate(
)
for input in inputs
]
return await asyncio.gather(*tasks)
result = await asyncio.gather(*tasks)
return result

tasks = [
asyncio.create_task(self.agenerate(input=input, **kwargs))
for input in inputs
for _ in range(num_generations)
]
outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
return [
list(group)
for group in grouper(outputs, n=num_generations, incomplete="ignore")
]
outputs = await asyncio.gather(*tasks)
return merge_responses(outputs, n=num_generations)

def generate(
self,
@@ -590,3 +588,43 @@ def _prepare_kwargs(
},
)
return arguments


def merge_responses(
responses: List[Dict[str, Any]], n: int = 1
) -> List[Dict[str, Any]]:
"""Helper function to group the responses from `LLM.agenerate` method according
to the number of generations requested.
Args:
responses: the responses from the `LLM.agenerate` method.
n: number of responses to group together. Defaults to 1.
Returns:
List of merged responses, where each merged response contains n generations
and their corresponding statistics.
"""
if not responses:
return []

def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield list(islice(lst, i, i + n))

# Split responses into groups of size n
grouped_responses = list(chunks(responses, n))

result = []
for group in grouped_responses:
first = group[0]
merged = {
"generations": sum((r["generations"] for r in group), []),
"statistics": {
key: sum((r["statistics"][key] for r in group), [])
for key in first["statistics"]
},
}
result.append(merged)

return result