From 2469407aa3f71dcd8ee7a451cac8c4bd917c34f1 Mon Sep 17 00:00:00 2001 From: Agus Date: Tue, 19 Nov 2024 09:41:34 +0100 Subject: [PATCH] Update `LLM.generate` output to include `statistics` (#1034) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez --- .../sections/how_to_guides/basic/llm/index.md | 60 ++- .../how_to_guides/basic/task/index.md | 43 +- src/distilabel/models/llms/anthropic.py | 29 +- src/distilabel/models/llms/base.py | 52 ++- src/distilabel/models/llms/cohere.py | 40 +- src/distilabel/models/llms/groq.py | 21 +- .../llms/huggingface/inference_endpoints.py | 89 ++-- .../models/llms/huggingface/transformers.py | 20 +- src/distilabel/models/llms/litellm.py | 28 +- src/distilabel/models/llms/llamacpp.py | 13 +- src/distilabel/models/llms/mistral.py | 20 +- src/distilabel/models/llms/ollama.py | 12 +- src/distilabel/models/llms/openai.py | 34 +- src/distilabel/models/llms/typing.py | 21 +- src/distilabel/models/llms/utils.py | 62 +++ src/distilabel/models/llms/vertexai.py | 13 +- src/distilabel/models/llms/vllm.py | 103 ++++- src/distilabel/steps/tasks/base.py | 104 ++++- .../steps/tasks/evol_instruct/base.py | 51 ++- .../steps/tasks/evol_instruct/generator.py | 56 ++- .../steps/tasks/evol_quality/base.py | 23 +- src/distilabel/steps/tasks/magpie/base.py | 80 +++- src/distilabel/steps/typing.py | 20 +- src/distilabel/utils/dicts.py | 64 ++- .../test_offline_batch_generation.py | 11 +- tests/unit/conftest.py | 32 +- .../huggingface/test_inference_endpoints.py | 143 +++++-- .../llms/huggingface/test_transformers.py | 18 +- tests/unit/models/llms/test_anthropic.py | 42 +- tests/unit/models/llms/test_cohere.py | 42 +- tests/unit/models/llms/test_groq.py | 35 +- tests/unit/models/llms/test_litellm.py | 14 +- tests/unit/models/llms/test_llamacpp.py | 7 +- tests/unit/models/llms/test_mistral.py | 48 ++- tests/unit/models/llms/test_ollama.py | 22 +- tests/unit/models/llms/test_openai.py | 103 ++++- tests/unit/models/llms/test_vertexai.py | 24 +- tests/unit/models/llms/test_vllm.py | 168 +++++--- tests/unit/models/llms/utils.py | 9 +- .../steps/clustering/test_text_clustering.py | 15 +- .../unit/steps/tasks/apigen/test_generator.py | 13 +- .../steps/tasks/evol_instruct/test_base.py | 19 + .../tasks/evol_instruct/test_generator.py | 28 +- .../steps/tasks/evol_quality/test_base.py | 13 + tests/unit/steps/tasks/magpie/test_base.py | 227 +++++++++- .../tasks/structured_outputs/test_outlines.py | 5 +- tests/unit/steps/tasks/test_base.py | 392 ++++++++---------- tests/unit/steps/tasks/test_decorator.py | 8 +- .../tasks/test_improving_text_embeddings.py | 127 +++++- .../tasks/test_instruction_backtranslation.py | 16 +- .../steps/tasks/test_structured_generation.py | 15 +- .../steps/tasks/test_text_classification.py | 17 +- .../unit/steps/tasks/test_text_generation.py | 6 +- tests/unit/steps/tasks/test_ultrafeedback.py | 29 +- 54 files changed, 2061 insertions(+), 645 deletions(-) create mode 100644 src/distilabel/models/llms/utils.py diff --git a/docs/sections/how_to_guides/basic/llm/index.md b/docs/sections/how_to_guides/basic/llm/index.md index d5d5a37368..d715994cb4 100644 --- a/docs/sections/how_to_guides/basic/llm/index.md +++ b/docs/sections/how_to_guides/basic/llm/index.md @@ -7,7 +7,10 @@ LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Ta ```python from distilabel.models import InferenceEndpointsLLM -llm = InferenceEndpointsLLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct") +llm = InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct" +) llm.load() llm.generate_outputs( @@ -15,12 +18,34 @@ llm.generate_outputs( [{"role": "user", "content": "What's the capital of Spain?"}], ], ) -# "The capital of Spain is Madrid." +# [ +# { +# "generations": [ +# "The capital of Spain is Madrid." +# ], +# "statistics": { +# "input_tokens": [ +# 43 +# ], +# "output_tokens": [ +# 8 +# ] +# } +# } +# ] ``` -!!! NOTE +!!! Note Always call the `LLM.load` or `Task.load` method when using LLMs standalone or as part of a `Task`. If using a `Pipeline`, this is done automatically in `Pipeline.run()`. +!!! Tip "New in version 1.5.0" + Since version `1.5.0` the LLM output is a list of dictionaries (one per item in the `inputs`), + each containing `generations`, that reports the text returned by the `LLM`, and a `statistics` field that will store statistics related to the `LLM` generation. Initially, this will include + `input_tokens` and `output_tokens` when available, which will be obtained via the API when available, or if a tokenizer is available for the model used, using the tokenizer for the model. + This data will be moved by the corresponding `Task` during the pipeline processing and moved to `distilabel_metadata` so we can operate on this data if we want, like for example computing the number of tokens per dataset. + + To access to the previous result one just has to access to the generations in the resulting dictionary: `result[0]["generations"]`. + ### Offline Batch Generation By default, all `LLM`s will generate text in a synchronous manner i.e. send inputs using `generate_outputs` method that will get blocked until outputs are generated. There are some `LLM`s (such as [OpenAILLM][distilabel.models.llms.openai.OpenAILLM]) that implements what we denote as _offline batch generation_, which allows to send the inputs to the LLM-as-a-service which will generate the outputs asynchronously and give us a job id that we can use later to check the status and retrieve the generated outputs when they are ready. LLM-as-a-service platforms offers this feature as a way to save costs in exchange of waiting for the outputs to be generated. @@ -56,7 +81,8 @@ llm.generate_outputs( # (4) [{"role": "user", "content": "What's the capital of Spain?"}], ], ) -# "The capital of Spain is Madrid." +# [{'generations': ['The capital of Spain is Madrid.'], +# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}] ``` 1. At first the `jobs_ids` attribute is `None`. @@ -81,7 +107,8 @@ llm.generate_outputs( [{"role": "user", "content": "What's the capital of Spain?"}], ], ) -# "The capital of Spain is Madrid." +# [{'generations': ['The capital of Spain is Madrid.'], +# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}] ``` ### Within a Task @@ -92,20 +119,30 @@ Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and th from distilabel.models import OpenAILLM from distilabel.steps.tasks import TextGeneration -llm = OpenAILLM(model="gpt-4") +llm = OpenAILLM(model="gpt-4o-mini") task = TextGeneration(name="text_generation", llm=llm) task.load() next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}])) -# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid."}] +# [{'instruction': "What's the capital of Spain?", +# 'generation': 'The capital of Spain is Madrid.', +# 'distilabel_metadata': {'raw_output_text_generation': 'The capital of Spain is Madrid.', +# 'raw_input_text_generation': [{'role': 'user', +# 'content': "What's the capital of Spain?"}], +# 'statistics_text_generation': {'input_tokens': 13, 'output_tokens': 7}}, +# 'model_name': 'gpt-4o-mini'}] ``` +!!! Note + As mentioned in *Working with LLMs* section, the generation of an LLM is automatically moved to `distilabel_metadata` to avoid interference with the common workflow, so the addition of the `statistics` it's an extra component available for the user, but nothing has to be changed in the + defined pipelines. + ### Runtime Parameters LLMs can have runtime parameters, such as `generation_kwargs`, provided via the `Pipeline.run()` method using the `params` argument. -!!! NOTE +!!! Note Runtime parameters can differ between LLM subclasses, caused by the different functionalities offered by the LLM providers. ```python @@ -122,7 +159,7 @@ with Pipeline(name="text-generation-pipeline") as pipeline: text_generation = TextGeneration( name="text_generation", - llm=OpenAILLM(model="gpt-4"), + llm=OpenAILLM(model="gpt-4o-mini"), ) load_dataset >> text_generation @@ -200,9 +237,12 @@ To create custom LLMs, subclass either [`LLM`][distilabel.models.llms.LLM] for s `generate` and `agenerate` keyword arguments (but `input` and `num_generations`) are considered as `RuntimeParameter`s, so a value can be passed to them via the `parameters` argument of the `Pipeline.run` method. -!!! NOTE +!!! Note To have the arguments of the `generate` and `agenerate` coerced to the expected types, the `validate_call` decorator is used, which will automatically coerce the arguments to the expected types, and raise an error if the types are not correct. This is specially useful when providing a value for an argument of `generate` or `agenerate` from the CLI, since the CLI will always provide the arguments as strings. +!!! Warning + Additional LLMs created in `distilabel` will have to take into account how the `statistics` are generated to properly include them in the LLM output. + ## Available LLMs [Our LLM gallery](../../../../components-gallery/llms/index.md) shows a list of the available LLMs that can be used within the `distilabel` library. diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md index 7f1d8260e0..dd5de6f837 100644 --- a/docs/sections/how_to_guides/basic/task/index.md +++ b/docs/sections/how_to_guides/basic/task/index.md @@ -21,26 +21,35 @@ task.load() next(task.process([{"instruction": "What's the capital of Spain?"}])) # [ -# { -# 'instruction': "What's the capital of Spain?", -# 'generation': 'The capital of Spain is Madrid.', -# 'distilabel_metadata': { -# 'raw_output_text-generation': 'The capital of Spain is Madrid.', -# 'raw_input_text-generation': [ -# {'role': 'user', 'content': "What's the capital of Spain?"} -# ] -# }, -# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct' -# } +# { +# "instruction": "What's the capital of Spain?", +# "generation": "The capital of Spain is Madrid.", +# "distilabel_metadata": { +# "raw_output_text-generation": "The capital of Spain is Madrid.", +# "raw_input_text-generation": [ +# { +# "role": "user", +# "content": "What's the capital of Spain?" +# } +# ], +# "statistics_text-generation": { # (1) +# "input_tokens": 18, +# "output_tokens": 8 +# } +# }, +# "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct" +# } # ] ``` -!!! NOTE +1. The `LLMs` will not only return the text but also a `statistics_{STEP_NAME}` field that will contain statistics related to the generation. If available, at least the input and output tokens will be returned. + +!!! Note The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution. As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`. -!!! Tip +!!! Tip "New in version 1.2.0" Since version `1.2.0`, we provide some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task. Additionally, since version `1.4.0`, the formatted input can also be included, which can be helpful when testing @@ -57,9 +66,12 @@ As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] ta ) ``` +!!! Tip "New in version 1.5.0" + Since version `1.5.0` `distilabel_metadata` includes a new `statistics` field out of the box. The generation from the LLM will not only contain the text, but also statistics associated with the text if available, like the input and output tokens. This field will be generated with `statistic_{STEP_NAME}` to avoid collisions between different steps in the pipeline, similar to how `raw_output_{STEP_NAME}` works. + ### Task.print -!!! Info +!!! Info "New in version 1.4.0" New since version `1.4.0`, [`Task.print`][distilabel.steps.tasks.base._Task.print] `Task.print` method. The `Tasks` include a handy method to show what the prompt formatted for an `LLM` would look like, let's see an example with [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback], but it applies to any other `Task`. @@ -271,3 +283,6 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe # Format the `LLM` output here return {"output_field": output} ``` + +!!! Warning + Most `Tasks` reuse the `Task.process` method to process the generations, but if a new `Task` defines a custom `process` method, like happens for example with [`Magpie`][distilabel.steps.tasks.magpie.base.Magpie], one hast to deal with the `statistics` returned by the `LLM`. diff --git a/src/distilabel/models/llms/anthropic.py b/src/distilabel/models/llms/anthropic.py index 7cd3cbcd3f..c6c79a9141 100644 --- a/src/distilabel/models/llms/anthropic.py +++ b/src/distilabel/models/llms/anthropic.py @@ -30,13 +30,19 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.typing import ( FormattedInput, InstructorStructuredOutputType, ) if TYPE_CHECKING: + from typing import BaseModel + from anthropic import AsyncAnthropic + from anthropic.types import Message + + from distilabel.llms.typing import LLMStatistics _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY" @@ -260,17 +266,26 @@ async def agenerate( # type: ignore if structured_output: kwargs = self._prepare_kwargs(kwargs, structured_output) - generations = [] - - completion = await self._aclient.messages.create(**kwargs) # type: ignore + completion: Union["Message", "BaseModel"] = await self._aclient.messages.create( + **kwargs + ) # type: ignore if structured_output: - generations.append(completion.model_dump_json()) - return generations + # raw_response = completion._raw_response + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) if (content := completion.content[0].text) is None: self._logger.warning( f"Received no response using Anthropic client (model: '{self.model}')." f" Finish reason was: {completion.stop_reason}" ) - generations.append(content) - return generations + return prepare_output([content], **self._get_llm_statistics(completion)) + + @staticmethod + def _get_llm_statistics(completion: "Message") -> "LLMStatistics": + return { + "input_tokens": [completion.usage.input_tokens], + "output_tokens": [completion.usage.output_tokens], + } diff --git a/src/distilabel/models/llms/base.py b/src/distilabel/models/llms/base.py index 58ca3b5f62..4657360afb 100644 --- a/src/distilabel/models/llms/base.py +++ b/src/distilabel/models/llms/base.py @@ -21,6 +21,7 @@ import time from abc import ABC, abstractmethod from functools import cached_property +from itertools import islice from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, ConfigDict, Field, PrivateAttr @@ -33,7 +34,6 @@ RuntimeParametersMixin, ) from distilabel.utils.docstring import parse_google_docstring -from distilabel.utils.itertools import grouper from distilabel.utils.notebook import in_notebook from distilabel.utils.serialization import _Serializable @@ -459,18 +459,16 @@ async def _agenerate( ) for input in inputs ] - return await asyncio.gather(*tasks) + result = await asyncio.gather(*tasks) + return result tasks = [ asyncio.create_task(self.agenerate(input=input, **kwargs)) for input in inputs for _ in range(num_generations) ] - outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)] - return [ - list(group) - for group in grouper(outputs, n=num_generations, incomplete="ignore") - ] + outputs = await asyncio.gather(*tasks) + return merge_responses(outputs, n=num_generations) def generate( self, @@ -590,3 +588,43 @@ def _prepare_kwargs( }, ) return arguments + + +def merge_responses( + responses: List[Dict[str, Any]], n: int = 1 +) -> List[Dict[str, Any]]: + """Helper function to group the responses from `LLM.agenerate` method according + to the number of generations requested. + + Args: + responses: the responses from the `LLM.agenerate` method. + n: number of responses to group together. Defaults to 1. + + Returns: + List of merged responses, where each merged response contains n generations + and their corresponding statistics. + """ + if not responses: + return [] + + def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield list(islice(lst, i, i + n)) + + # Split responses into groups of size n + grouped_responses = list(chunks(responses, n)) + + result = [] + for group in grouped_responses: + first = group[0] + merged = { + "generations": sum((r["generations"] for r in group), []), + "statistics": { + key: sum((r["statistics"][key] for r in group), []) + for key in first["statistics"] + }, + } + result.append(merged) + + return result diff --git a/src/distilabel/models/llms/cohere.py b/src/distilabel/models/llms/cohere.py index 80fbddf4f7..043ac4214c 100644 --- a/src/distilabel/models/llms/cohere.py +++ b/src/distilabel/models/llms/cohere.py @@ -23,18 +23,24 @@ Union, ) +import orjson from pydantic import Field, PrivateAttr, SecretStr, validate_call from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.steps.tasks.typing import ( FormattedInput, InstructorStructuredOutputType, ) if TYPE_CHECKING: - from cohere import AsyncClient, ChatMessage + from cohere import AsyncClient, ChatMessage, Message + from pydantic import BaseModel + from tokenizers import Tokenizer + + from distilabel.llms.typing import LLMStatistics _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" @@ -135,6 +141,7 @@ class User(BaseModel): _ChatMessage: Type["ChatMessage"] = PrivateAttr(...) _aclient: "AsyncClient" = PrivateAttr(...) + _tokenizer: "Tokenizer" = PrivateAttr(...) @property def model_name(self) -> str: @@ -172,6 +179,10 @@ def load(self) -> None: if structured_output := result.get("structured_output"): self.structured_output = structured_output + from cohere.manually_maintained.tokenizers import get_hf_tokenizer + + self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model) + def _format_chat_to_cohere( self, input: "FormattedInput" ) -> Tuple[Union[str, None], List["ChatMessage"], str]: @@ -278,16 +289,35 @@ async def agenerate( # type: ignore if structured_output: kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore - response = await self._aclient.chat(**kwargs) # type: ignore + response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs) # type: ignore if structured_output: - return [response.model_dump_json()] + return prepare_output( + [response.model_dump_json()], + **self._get_llm_statistics( + input, orjson.dumps(response.model_dump_json()).decode("utf-8") + ), # type: ignore + ) if (text := response.text) == "": self._logger.warning( # type: ignore f"Received no response using Cohere client (model: '{self.model}')." f" Finish reason was: {response.finish_reason}" ) - return [None] + return prepare_output( + [None], + **self._get_llm_statistics(input, ""), + ) + + return prepare_output( + [text], + **self._get_llm_statistics(input, text), + ) - return [text] + def _get_llm_statistics( + self, input: FormattedInput, output: str + ) -> "LLMStatistics": + return { + "input_tokens": [compute_tokens(input, self._tokenizer.encode)], + "output_tokens": [compute_tokens(output, self._tokenizer.encode)], + } diff --git a/src/distilabel/models/llms/groq.py b/src/distilabel/models/llms/groq.py index 92ff9b8b35..2977c513f3 100644 --- a/src/distilabel/models/llms/groq.py +++ b/src/distilabel/models/llms/groq.py @@ -19,6 +19,7 @@ from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.base import RuntimeParameter from distilabel.steps.tasks.typing import ( FormattedInput, @@ -27,6 +28,9 @@ if TYPE_CHECKING: from groq import AsyncGroq + from groq.types.chat.chat_completion import ChatCompletion + + from distilabel.llms.typing import LLMStatistics _GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL" @@ -225,12 +229,14 @@ async def agenerate( # type: ignore if structured_output: kwargs = self._prepare_kwargs(kwargs, structured_output) - generations = [] completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore if structured_output: - generations.append(completion.model_dump_json()) - return generations + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) + generations = [] for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( # type: ignore @@ -238,4 +244,11 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + return prepare_output(generations, **self._get_llm_statistics(completion)) + + @staticmethod + def _get_llm_statistics(completion: "ChatCompletion") -> "LLMStatistics": + return { + "input_tokens": [completion.usage.prompt_tokens if completion else 0], + "output_tokens": [completion.usage.completion_tokens if completion else 0], + } diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py index 3f4bc1856b..c60199452b 100644 --- a/src/distilabel/models/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py @@ -32,6 +32,7 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import ( FormattedInput, @@ -42,6 +43,12 @@ if TYPE_CHECKING: from huggingface_hub import AsyncInferenceClient + from huggingface_hub.inference._generated.types.chat_completion import ( + ChatCompletionOutput, + ) + from huggingface_hub.inference._generated.types.text_generation import ( + TextGenerationOutput, + ) from transformers import PreTrainedTokenizer @@ -387,12 +394,12 @@ async def _generate_with_text_generation( return_full_text: bool = False, seed: Optional[int] = None, watermark: bool = False, - ) -> Union[str, None]: + ) -> GenerateOutput: structured_output = self._get_structured_output(input) completion = None try: - completion = await self._aclient.text_generation( # type: ignore + completion: "TextGenerationOutput" = await self._aclient.text_generation( # type: ignore prompt=self.prepare_input(input), # type: ignore max_new_tokens=max_new_tokens, do_sample=do_sample, @@ -409,13 +416,25 @@ async def _generate_with_text_generation( seed=seed or random.randint(0, sys.maxsize), watermark=watermark, grammar=structured_output, # type: ignore + details=True, ) except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return completion + + return prepare_output( + [completion.generated_text], + input_tokens=[ + compute_tokens(self.prepare_input(input), self._tokenizer.encode) + if self._tokenizer + else 0 + ], + output_tokens=[ + completion.details.generated_tokens if completion.details else 0 + ], + ) async def _generate_with_chat_completion( self, @@ -431,10 +450,10 @@ async def _generate_with_chat_completion( tool_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, top_p: Optional[float] = None, - ) -> Union[str, None]: + ) -> GenerateOutput: message = None try: - completion = await self._aclient.chat_completion( # type: ignore + completion: "ChatCompletionOutput" = await self._aclient.chat_completion( # type: ignore messages=input, # type: ignore max_tokens=max_new_tokens, frequency_penalty=frequency_penalty, @@ -461,7 +480,11 @@ async def _generate_with_chat_completion( f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return message + return prepare_output( + [message], + input_tokens=[completion.usage.prompt_tokens], + output_tokens=[completion.usage.completion_tokens], + ) def _check_stop_sequences( self, @@ -574,37 +597,33 @@ async def agenerate( # type: ignore stop_sequences = self._check_stop_sequences(stop_sequences) if self.tokenizer_id is None: - return [ - await self._generate_with_chat_completion( - input=input, # type: ignore - max_new_tokens=max_new_tokens, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - presence_penalty=presence_penalty, - seed=seed, - stop_sequences=stop_sequences, - temperature=temperature, - tool_choice=tool_choice, - tool_prompt=tool_prompt, - tools=tools, - top_p=top_p, - ) - ] - - return [ - await self._generate_with_text_generation( - input=input, + return await self._generate_with_chat_completion( + input=input, # type: ignore max_new_tokens=max_new_tokens, - do_sample=do_sample, - typical_p=typical_p, - repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + presence_penalty=presence_penalty, + seed=seed, + stop_sequences=stop_sequences, temperature=temperature, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, top_p=top_p, - top_k=top_k, - stop_sequences=stop_sequences, - return_full_text=return_full_text, - seed=seed, - watermark=watermark, ) - ] + + return await self._generate_with_text_generation( + input=input, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + typical_p=typical_p, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=stop_sequences, + return_full_text=return_full_text, + seed=seed, + watermark=watermark, + ) diff --git a/src/distilabel/models/llms/huggingface/transformers.py b/src/distilabel/models/llms/huggingface/transformers.py index e34731a21b..eafd514123 100644 --- a/src/distilabel/models/llms/huggingface/transformers.py +++ b/src/distilabel/models/llms/huggingface/transformers.py @@ -20,6 +20,7 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput @@ -233,11 +234,28 @@ def generate( # type: ignore prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn, pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore ) - return [ + llm_output = [ [generation["generated_text"] for generation in output] for output in outputs ] + result = [] + for input, output in zip(inputs, llm_output): + result.append( + prepare_output( + output, + input_tokens=[ + compute_tokens(input, self._pipeline.tokenizer.encode) + ], + output_tokens=[ + compute_tokens(row, self._pipeline.tokenizer.encode) + for row in output + ], + ) + ) + + return result + def get_last_hidden_states( self, inputs: List["StandardInput"] ) -> List["HiddenState"]: diff --git a/src/distilabel/models/llms/litellm.py b/src/distilabel/models/llms/litellm.py index 1852d76775..d2471f2991 100644 --- a/src/distilabel/models/llms/litellm.py +++ b/src/distilabel/models/llms/litellm.py @@ -15,11 +15,13 @@ import logging from typing import TYPE_CHECKING, Callable, List, Optional, Union +import orjson from pydantic import Field, PrivateAttr, validate_call from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType if TYPE_CHECKING: @@ -194,6 +196,7 @@ async def agenerate( # type: ignore # noqa: C901 A list of lists of strings containing the generated responses for each input. """ import litellm + from litellm import token_counter structured_output = None if isinstance(input, tuple): @@ -256,10 +259,25 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: raise e generations = [] + input_tokens = [ + token_counter(model=self.model, messages=input) + ] * num_generations + output_tokens = [] if self.structured_output: - generations.append([choice.model_dump_json() for choice in choices]) - return generations + for choice in choices: + generations.append(choice.model_dump_json()) + output_tokens.append( + token_counter( + model=self.model, + text=orjson.dumps(choice.model_dump_json()).decode("utf-8"), + ) + ) + return prepare_output( + generations, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) for choice in choices: if (content := choice.message.content) is None: @@ -268,4 +286,8 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + output_tokens.append(token_counter(model=self.model, text=content)) + + return prepare_output( + generations, input_tokens=input_tokens, output_tokens=output_tokens + ) diff --git a/src/distilabel/models/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py index 20b66f8cfe..77e2707c1c 100644 --- a/src/distilabel/models/llms/llamacpp.py +++ b/src/distilabel/models/llms/llamacpp.py @@ -19,6 +19,7 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType if TYPE_CHECKING: @@ -219,6 +220,7 @@ def generate( # type: ignore structured_output = self.structured_output outputs = [] + output_tokens = [] for _ in range(num_generations): # NOTE(plaguss): There seems to be a bug in how the logits processor # is used. Basically it consumes the FSM internally, and it isn't reinitialized @@ -241,7 +243,16 @@ def generate( # type: ignore ) ) outputs.append(chat_completions["choices"][0]["message"]["content"]) - batch_outputs.append(outputs) + output_tokens.append(chat_completions["usage"]["completion_tokens"]) + batch_outputs.append( + prepare_output( + outputs, + input_tokens=[chat_completions["usage"]["prompt_tokens"]] + * num_generations, + output_tokens=output_tokens, + ) + ) + return batch_outputs def _prepare_structured_output( diff --git a/src/distilabel/models/llms/mistral.py b/src/distilabel/models/llms/mistral.py index 5848402757..873565091b 100644 --- a/src/distilabel/models/llms/mistral.py +++ b/src/distilabel/models/llms/mistral.py @@ -20,6 +20,7 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.typing import ( FormattedInput, InstructorStructuredOutputType, @@ -27,6 +28,9 @@ if TYPE_CHECKING: from mistralai import Mistral + from mistralai.models.chatcompletionresponse import ChatCompletionResponse + + from distilabel.llms.typing import LLMStatistics _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY" @@ -221,8 +225,10 @@ async def agenerate( # type: ignore completion = await self._aclient.chat.complete_async(**kwargs) # type: ignore if structured_output: - generations.append(completion.model_dump_json()) - return generations + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) for choice in completion.choices: if (content := choice.message.content) is None: @@ -231,4 +237,12 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + + return prepare_output(generations, **self._get_llm_statistics(completion)) + + @staticmethod + def _get_llm_statistics(completion: "ChatCompletionResponse") -> "LLMStatistics": + return { + "input_tokens": [completion.usage.prompt_tokens], + "output_tokens": [completion.usage.completion_tokens], + } diff --git a/src/distilabel/models/llms/ollama.py b/src/distilabel/models/llms/ollama.py index 009d336aed..f704627487 100644 --- a/src/distilabel/models/llms/ollama.py +++ b/src/distilabel/models/llms/ollama.py @@ -20,11 +20,14 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput if TYPE_CHECKING: from ollama import AsyncClient + from distilabel.llms.typing import LLMStatistics + # Copied from `ollama._types.Options` class Options(TypedDict, total=False): @@ -175,4 +178,11 @@ async def agenerate( # type: ignore f" Finish reason was: {e}" ) - return [text] + return prepare_output([text], **self._get_llm_statistics(completion)) + + @staticmethod + def _get_llm_statistics(completion: Dict[str, Any]) -> "LLMStatistics": + return { + "input_tokens": [completion["prompt_eval_count"]], + "output_tokens": [completion["eval_count"]], + } diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index 71c7941f3e..a9ccc90dfa 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -24,6 +24,7 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType if TYPE_CHECKING: @@ -31,7 +32,8 @@ from openai.types import Batch as OpenAIBatch from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion - from pydantic import BaseModel + + from distilabel.llms.typing import LLMStatistics _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" @@ -299,26 +301,14 @@ async def agenerate( # type: ignore kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore - if structured_output: - return self._generations_from_structured_output(completion) + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) return self._generations_from_openai_completion(completion) - def _generations_from_structured_output( - self, completion: "BaseModel" - ) -> "GenerateOutput": - """Get the generations from the structured output object. - - Args: - completion: an instance of `pydantic.BaseModel` with the content of the structuted - output. - - Returns: - A list with the content of the structured output. - """ - return [completion.model_dump_json()] - def _generations_from_openai_completion( self, completion: "OpenAIChatCompletion" ) -> "GenerateOutput": @@ -338,7 +328,8 @@ def _generations_from_openai_completion( f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + + return prepare_output(generations, **self._get_llm_statistics(completion)) def offline_batch_generate( self, @@ -685,3 +676,10 @@ def _name_for_openai_files(self, file_no: int) -> str: return f"distilabel-pipeline-fileno-{file_no}.jsonl" return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl" + + @staticmethod + def _get_llm_statistics(completion: "OpenAIChatCompletion") -> "LLMStatistics": + return { + "input_tokens": [completion.usage.prompt_tokens if completion else 0], + "output_tokens": [completion.usage.completion_tokens if completion else 0], + } diff --git a/src/distilabel/models/llms/typing.py b/src/distilabel/models/llms/typing.py index a19d30cb00..512c76b471 100644 --- a/src/distilabel/models/llms/typing.py +++ b/src/distilabel/models/llms/typing.py @@ -12,9 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, List, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, TypeVar, Union + +LLMOutput = List[Union[str, None]] + + +class TokenCount(TypedDict): + input_tokens: List[int] + output_tokens: List[int] + + +LLMStatistics = Union[TokenCount, Dict[str, Any]] +"""Initially the LLMStatistics will contain the token count, but can have more variables. +They can be added once we have them defined for every LLM. +""" + + +class GenerateOutput(TypedDict): + generations: LLMOutput + statistics: LLMStatistics -GenerateOutput = List[Union[str, None]] if TYPE_CHECKING: from numpy import floating diff --git a/src/distilabel/models/llms/utils.py b/src/distilabel/models/llms/utils.py new file mode 100644 index 0000000000..6a5ae78a1e --- /dev/null +++ b/src/distilabel/models/llms/utils.py @@ -0,0 +1,62 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, List, Optional, Union + +from distilabel.steps.tasks.typing import ChatType + +if TYPE_CHECKING: + from distilabel.models.llms.typing import GenerateOutput, LLMOutput + + +def compute_tokens( + text_or_messages: Union[str, ChatType], tokenizer: Callable[[str], List[int]] +) -> int: + """Helper function to count the number of tokens in a text or list of messages. + + Args: + text_or_messages: Either a string response or a list of messages. + tokenizer: A callable function that take str and returns the tokenized version of the text. + + Returns: + The number of tokens. + """ + if isinstance(text_or_messages, list): + return sum([len(tokenizer(message["content"])) for message in text_or_messages]) + else: + return len(tokenizer(text_or_messages)) + + +def prepare_output( + generations: "LLMOutput", + input_tokens: Optional[List[int]] = None, + output_tokens: Optional[List[int]] = None, +) -> "GenerateOutput": + """Helper function to prepare the output of the LLM. + + Args: + generations: The outputs from an LLM. + input_tokens: The number of tokens of the inputs. Defaults to `None`. + output_tokens: The number of tokens of the LLM response. Defaults to `None`. + + Returns: + Output generation from an LLM. + """ + return { + "generations": generations, + "statistics": { + "input_tokens": input_tokens or [], + "output_tokens": output_tokens or [], + }, + } diff --git a/src/distilabel/models/llms/vertexai.py b/src/distilabel/models/llms/vertexai.py index 357a3817e4..c617b7bcf2 100644 --- a/src/distilabel/models/llms/vertexai.py +++ b/src/distilabel/models/llms/vertexai.py @@ -18,11 +18,14 @@ from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from vertexai.generative_models import Content, GenerationResponse, GenerativeModel + from distilabel.llms.typing import LLMStatistics + class VertexAILLM(AsyncLLM): """VertexAI LLM implementation running the async API clients for Gemini. @@ -167,8 +170,14 @@ async def agenerate( # type: ignore f"Received no response using VertexAI client (model: '{self.model}')." f" Finish reason was: '{content.candidates[0].finish_reason}'." ) - - return [text] + return prepare_output([text], **self._get_llm_statistics(content)) + + @staticmethod + def _get_llm_statistics(content: "GenerationResponse") -> "LLMStatistics": + return { + "input_tokens": [content.usage_metadata.prompt_token_count], + "output_tokens": [content.usage_metadata.candidates_token_count], + } def _is_gemini_model(model: str) -> bool: diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index 417aadabed..dd83d8489e 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -33,6 +33,7 @@ from distilabel.models.llms.base import LLM from distilabel.models.llms.openai import OpenAILLM from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType @@ -41,8 +42,11 @@ from openai import OpenAI # noqa from transformers import PreTrainedTokenizer from vllm import LLM as _vLLM + from vllm.outputs import RequestOutputs, CompletionOutput from distilabel.steps.tasks.typing import StandardInput + from distilabel.llms.typing import LLMStatistics + LogitsProcessorFn = Union[ Callable[[List[int], Any], Any], @@ -231,7 +235,7 @@ def prepare_input(self, input: "StandardInput") -> str: The prompt to send to the LLM. """ if self._tokenizer.chat_template is None: - return input[0]["content"] + return [item["content"] for item in input if item["role"] == "user"][0] prompt: str = ( self._tokenizer.apply_chat_template( @@ -267,7 +271,14 @@ def _prepare_batches( batches = {} for i, (instruction, structured_output) in enumerate(inputs): instruction = self.prepare_input(instruction) - instruction_order[instruction] = i + + # We need to convert the instruction to a string to make it hashable + str_instruction = instruction + if not isinstance(instruction, str): + str_instruction = json.dumps(instruction) + + instruction_order[str_instruction] = i + structured_output = json.dumps(structured_output) if structured_output not in batches: batches[structured_output] = [instruction] @@ -280,7 +291,7 @@ def _prepare_batches( ] # Generate the list of indices based on the original order sorted_indices = [ - instruction_order[instruction] for instruction in flat_instructions + instruction_order[str_instruction] for instruction in flat_instructions ] return [ (batch, json.loads(schema)) for schema, batch in batches.items() @@ -353,12 +364,12 @@ def generate( # type: ignore # Simulate a batch without the structured output content prepared_batches = [([self.prepare_input(input) for input in inputs], None)] sorted_indices = None - # Case in which we have a single structured output for the dataset if self._structured_output_logits_processor: logits_processors.append(self._structured_output_logits_processor) batched_outputs = [] + generations = [] for prepared_inputs, structured_output in prepared_batches: if structured_output: @@ -383,23 +394,37 @@ def generate( # type: ignore **extra_sampling_params, ) - batch_outputs = self._model.generate( + batch_outputs: List["RequestOutputs"] = self._model.generate( prepared_inputs, sampling_params, use_tqdm=False, # type: ignore ) + # TODO: This is repeated in prepare_output, but for simplicity we extract + # the batched_outputs as we did when there wasn't statistics and we just + # return the str generations batched_outputs += [ [output.text for output in outputs.outputs] for outputs in batch_outputs ] + for input, outputs in zip(prepared_inputs, batch_outputs): + generations.append( + prepare_output( + [output.text for output in outputs.outputs], + **self._get_llm_statistics(input, outputs), + ) + ) # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) if sorted_indices is not None: - batched_outputs = _sort_batches( - batched_outputs, sorted_indices, num_generations=num_generations + # Sort the batched outputs together with the statistics + generations = self._prepare_sorted_results( + batched_outputs, + sorted_indices, + generations, + num_generations=num_generations, ) - return batched_outputs + return generations def _prepare_structured_output( self, structured_output: Optional[OutlinesStructuredOutputType] = None @@ -421,6 +446,65 @@ def _prepare_structured_output( self.structured_output["schema"] = schema return result["processor"] + def _get_llm_statistics( + self, input: "FormattedInput", outputs: "CompletionOutput" + ) -> "LLMStatistics": + output_tokens = [len(output.token_ids) for output in outputs.outputs] + return { + "input_tokens": [compute_tokens(input, self._tokenizer.encode)] + * len(output_tokens), + "output_tokens": output_tokens, + } + + @staticmethod + def _prepare_sorted_results( + batched_outputs: List[List[FormattedInput]], + sorted_indices: List[int], + generations: List[GenerateOutput], + num_generations: int = 1, + ) -> List[GenerateOutput]: + """Helper method to sort the results in case of multiple structured outputs in the dataset. + + Args: + batched_outputs: The mini-batches generated by the model. + sorted_indices: The indices that would sort the mini-batches back to the original order. + generations: The prepared outputs that would be returned in the general case, + from which the statistics will be extracted and sorted. + num_generations: The number of generations requested to vLLM. Defaults to 1. + + Returns: + The list of GenerateOutput sorted back to the original order. + """ + + # This was the only required sort back with only the generations + batched_outputs = _sort_batches( + batched_outputs, sorted_indices, num_generations=num_generations + ) + # Prepare the statistics to be sorted + # Loop over all the variables in the statistics + # Get the keys from the LLMStatistics + statistic_fields = list(generations[0]["statistics"].keys()) + statistics = {} + for field in statistic_fields: + batched_field = _sort_batches( + [g["statistics"][field] for g in generations], + sorted_indices, + num_generations=num_generations, + ) + statistics[field] = batched_field + + # Regenerates the outputs as they are returned by `prepare_output` + sorted_results = [] + for i, batched_output in enumerate(batched_outputs): + generation = {"generations": batched_output} + statistics = { + field: batched_field[i] for field, batched_field in statistics.items() + } + generation.update({"statistics": statistics}) + sorted_results.append(generation) + + return sorted_results + class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin): """A client for the `vLLM` server implementing the OpenAI API specification. @@ -604,7 +688,8 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(text) - return generations + + return prepare_output(generations, **self._get_llm_statistics(completion)) def _sort_batches( diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index ee2dae790d..9be8f8ee1a 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -33,7 +33,7 @@ from distilabel.utils.dicts import group_dicts if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput + from distilabel.models.llms.typing import GenerateOutput, LLMStatistics from distilabel.steps.tasks.typing import ChatType, FormattedInput from distilabel.steps.typing import StepOutput @@ -129,7 +129,7 @@ def impute_step_outputs( data = row.copy() for output in self.get_outputs().keys(): data[output] = None - data = self._maybe_add_raw_input_output( + data = self._create_metadata( data, None, None, @@ -169,17 +169,24 @@ def _format_outputs( A list containing a dictionary with the outputs of the task for each input. """ inputs = [None] if input is None else [input] - formatted_outputs = [] - for output, input in zip(outputs, inputs * len(outputs)): # type: ignore + repeate_inputs = len(outputs.get("generations")) + outputs = normalize_statistics(outputs) + + for (output, stats), input in zip( + iterate_generations_with_stats(outputs), inputs * repeate_inputs + ): # type: ignore try: + # Extract the generations, and move the statistics to the distilabel_metadata, + # to keep everything clean formatted_output = self.format_output(output, input) - formatted_output = self._maybe_add_raw_input_output( + formatted_output = self._create_metadata( formatted_output, output, input, add_raw_output=self.add_raw_output, # type: ignore add_raw_input=self.add_raw_input, # type: ignore + statistics=stats, ) formatted_outputs.append(formatted_output) except Exception as e: @@ -198,7 +205,7 @@ def _output_on_failure( # Create a dictionary with the outputs of the task (every output set to None) outputs = {output: None for output in self.outputs} outputs["model_name"] = self.llm.model_name # type: ignore - outputs = self._maybe_add_raw_input_output( + outputs = self._create_metadata( outputs, output, input, @@ -207,16 +214,28 @@ def _output_on_failure( ) return outputs - def _maybe_add_raw_input_output( + def _create_metadata( self, output: Dict[str, Any], - raw_output: Union[str, None], + raw_output: List[Union[str, None]], input: Union[str, None], add_raw_output: bool = True, add_raw_input: bool = True, - ): + statistics: Optional["LLMStatistics"] = None, + ) -> Dict[str, Any]: """Adds the raw output and or the formatted input of the LLM to the output dictionary if `add_raw_output` is True or `add_raw_input` is True. + + Args: + output: + The output dictionary after formatting the output from the LLM, + to add the raw output and or raw input. + raw_output: The raw output of the LLM (the list of generations). + input: The raw input of the LLM. + add_raw_output: Whether to add the raw output to the output dictionary. + add_raw_input: Whether to add the raw input to the output dictionary. + statistics: The statistics generated by the LLM, which should contain at least + the number of input and output tokens. """ meta = output.get(DISTILABEL_METADATA_KEY, {}) @@ -224,6 +243,8 @@ def _maybe_add_raw_input_output( meta[f"raw_output_{self.name}"] = raw_output if add_raw_input: meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None + if statistics: + meta[f"statistics_{self.name}"] = statistics if meta: output[DISTILABEL_METADATA_KEY] = meta @@ -405,13 +426,13 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore formatted_inputs = self._format_inputs(inputs) - # `outputs` is a list containing a list of generations per input + # `outputs` is a dict containing the LLM outputs in the `generations` + # key and the statistics in the `statistics` key outputs = self.llm.generate_outputs( inputs=formatted_inputs, num_generations=self.num_generations, # type: ignore **self.llm.get_generation_kwargs(), # type: ignore ) - task_outputs = [] for input, input_outputs in zip(inputs, outputs): formatted_outputs = self._format_outputs(input_outputs, input) @@ -453,3 +474,64 @@ class GlobalTask(_Task, GlobalStep): """ pass + + +def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": + """Transforms the GenerateOutput statistics to have the same length as the generations. + + Args: + data: A generate output that possibly has different lengths of statistics + vs generations (due to num_generations=3 returning 3 generations, but + for example the tokens are only counted once). + + Returns: + Normalized statistics according to the generations length. + + Examples: + ```python + data = { + "generations": ["text1", "text2", "text3", "text4"], + "statistics": {"input_tokens": [1], "output_tokens": [1, 2, 3]} + } + normalize_statistics(data) + data = { + "generations": ["text1", "text2", "text3"], + "statistics": {"input_tokens": [1, 1, 1], "output_tokens": [1, 2, 3]} + } + ``` + """ + statistics = output.get("statistics") + if not statistics: + return output + gen_length = len(output["generations"]) + + for stat_key, stat_values in output["statistics"].items(): + current_length = len(stat_values) + + if current_length < gen_length: + # Calculate how many times to repeat the tokens + repeats = gen_length // current_length + remainder = gen_length % current_length + + # Create new list with repeated values + new_values = stat_values * repeats + stat_values[:remainder] + output["statistics"][stat_key] = new_values + + return output + + +def iterate_generations_with_stats(output: "GenerateOutput") -> "GenerateOutput": + """Helper function to iterate together generations and statistics while + processing them inside _format_outputs. + + Args: + output: Output from the LLM.generate_outputs method. + + Yields: + Iterator of generation and statistics paired. + """ + for i, generation in enumerate(output["generations"]): + # Create a new dictionary with the statistics for this index + stats = {key: values[i] for key, values in output["statistics"].items()} + + yield generation, stats diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py index 9bbf0de34b..3f2ba5da4f 100644 --- a/src/distilabel/steps/tasks/evol_instruct/base.py +++ b/src/distilabel/steps/tasks/evol_instruct/base.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np from pydantic import Field @@ -26,6 +27,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: + from distilabel.llms.typing import LLMStatistics from distilabel.steps.typing import StepOutput @@ -267,6 +269,7 @@ def _evolve_instructions(self, inputs: "StepInput") -> List[List[str]]: """ instructions: List[List[str]] = [[input["instruction"]] for input in inputs] + statistics: "LLMStatistics" = defaultdict(list) for iter_no in range(self.num_evolutions): formatted_prompts = [] @@ -276,12 +279,16 @@ def _evolve_instructions(self, inputs: "StepInput") -> List[List[str]]: formatted_prompts = [ self.format_input(prompt) for prompt in formatted_prompts ] + responses = self.llm.generate( + formatted_prompts, + **self.llm.generation_kwargs, # type: ignore + ) generated_prompts = flatten_responses( - self.llm.generate( - formatted_prompts, - **self.llm.generation_kwargs, # type: ignore - ) + [response["generations"] for response in responses] ) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) evolved_instructions = [] for generated_prompt in generated_prompts: @@ -304,12 +311,11 @@ def _evolve_instructions(self, inputs: "StepInput") -> List[List[str]]: self._logger.info( f"🔄 Ran iteration {iter_no} evolving {len(instructions)} instructions!" ) - - return instructions + return instructions, dict(statistics) def _generate_answers( self, evolved_instructions: List[List[str]] - ) -> List[List[str]]: + ) -> Tuple[List[List[str]], "LLMStatistics"]: """Generates the answer for the instructions in `instructions`. Args: @@ -331,16 +337,23 @@ def _generate_answers( num_generations=1, **self.llm.generation_kwargs, # type: ignore ) + generations = [response["generations"] for response in responses] + + statistics: Dict[str, Any] = defaultdict(list) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) step = ( self.num_evolutions if not self.include_original_instruction else self.num_evolutions + 1 ) + return [ - flatten_responses(responses[i : i + step]) + flatten_responses(generations[i : i + step]) for i in range(0, len(responses), step) - ] + ], dict(statistics) @override def process(self, inputs: StepInput) -> "StepOutput": # type: ignore @@ -353,7 +366,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore A list of Python dictionaries with the outputs of the task. """ - evolved_instructions = self._evolve_instructions(inputs) + evolved_instructions, statistics = self._evolve_instructions(inputs) if self.store_evolutions: # Remove the input instruction from the `evolved_instructions` list @@ -365,6 +378,13 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore if not self.generate_answers: for input, instruction in zip(inputs, evolved_instructions): input.update(self.format_output(instruction)) + input.update( + { + "distilabel_metadata": { + f"statistics_instruction_{self.name}": statistics + } + } + ) yield inputs self._logger.info( @@ -376,7 +396,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore f"🧠 Generating answers for the {len(evolved_instructions)} evolved instructions!" ) - answers = self._generate_answers(evolved_instructions) + answers, statistics = self._generate_answers(evolved_instructions) self._logger.info( f"🎉 Finished generating answers for the {len(evolved_instructions)} evolved" @@ -387,6 +407,13 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore zip(inputs, evolved_instructions) ): input.update(self.format_output(instruction, answers[idx])) + input.update( + { + "distilabel_metadata": { + f"statistics_answer_{self.name}": statistics + } + } + ) yield inputs @override diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index 335e9844f0..fa568d392e 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -19,8 +19,9 @@ else: import importlib.resources as importlib_resources +from collections import defaultdict from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import numpy as np from pydantic import Field, PrivateAttr @@ -32,6 +33,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: + from distilabel.llms.typing import LLMStatistics from distilabel.steps.tasks.typing import ChatType from distilabel.steps.typing import GeneratorStepOutput @@ -256,7 +258,9 @@ def _apply_random_mutation(self, iter_no: int) -> List["ChatType"]: prompts.append([{"role": "user", "content": prompt_with_template}]) return prompts - def _generate_answers(self, instructions: List[List[str]]) -> List[str]: + def _generate_answers( + self, instructions: List[List[str]] + ) -> Tuple[List[str], "LLMStatistics"]: """Generates the answer for the last instruction in `instructions`. Args: @@ -276,10 +280,17 @@ def _generate_answers(self, instructions: List[List[str]]) -> List[str]: _formatted_instructions, **self.llm.generation_kwargs, # type: ignore ) - return flatten_responses(responses) + statistics: Dict[str, Any] = defaultdict(list) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) + + return flatten_responses( + [response["generations"] for response in responses] + ), dict(statistics) @override - def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore + def process(self, offset: int = 0) -> "GeneratorStepOutput": # NOQA: C901, type: ignore """Processes the inputs of the task and generates the outputs using the LLM. Args: @@ -297,9 +308,17 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore while len(instructions) < self.num_instructions: prompts = self._apply_random_mutation(iter_no=iter_no) + # TODO: Update the function to extract from the dict + responses = self.llm.generate(prompts, **self.llm.generation_kwargs) # type: ignore + generated_prompts = flatten_responses( - self.llm.generate(prompts, **self.llm.generation_kwargs) # type: ignore + [response["generations"] for response in responses] ) + statistics: "LLMStatistics" = defaultdict(list) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) + for idx, generated_prompt in enumerate(generated_prompts): generated_prompt = generated_prompt.split("Prompt#:")[-1].strip() if self.max_length >= len(generated_prompt) >= self.min_length: # type: ignore @@ -319,11 +338,15 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore mutation_no = len(instructions) - mutation_no if not self.generate_answers and len(instructions[-mutation_no:]) > 0: + formatted_generations = [] + for mutated_instruction in instructions[-mutation_no:]: + mutated_instruction = self.format_output(mutated_instruction) + mutated_instruction["distilabel_metadata"] = { + f"statistics_instruction_{self.name}": dict(statistics) + } + formatted_generations.append(mutated_instruction) yield ( - [ - self.format_output(mutated_instruction) - for mutated_instruction in instructions[-mutation_no:] - ], + formatted_generations, len(instructions) >= self.num_instructions, ) @@ -334,17 +357,22 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore f"🧠 Generating answers for the {len(instructions)} evolved instructions!" ) - answers = self._generate_answers(instructions) + answers, statistics = self._generate_answers(instructions) self._logger.info( f"🎉 Finished generating answers for the {len(instructions)} evolved instructions!" ) + formatted_outputs = [] + for instruction, answer in zip(instructions, answers): + formatted_output = self.format_output(instruction, answer) + formatted_output["distilabel_metadata"] = { + f"statistics_answer_{self.name}": dict(statistics) + } + formatted_outputs.append(formatted_output) + yield ( - [ - self.format_output(instruction, answer) - for instruction, answer in zip(instructions, answers) - ], + formatted_outputs, True, ) diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py index b7d2690c35..8ea7061105 100644 --- a/src/distilabel/steps/tasks/evol_quality/base.py +++ b/src/distilabel/steps/tasks/evol_quality/base.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Union +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union import numpy as np from pydantic import Field @@ -200,7 +201,9 @@ def _apply_random_mutation(self, instruction: str, response: str) -> str: .replace("", response) ) - def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]: + def _evolve_reponses( + self, inputs: "StepInput" + ) -> Tuple[List[List[str]], Dict[str, Any]]: """Evolves the instructions provided as part of the inputs of the task. Args: @@ -213,6 +216,7 @@ def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]: np.random.seed(self.seed) instructions: List[List[str]] = [[input["instruction"]] for input in inputs] responses: List[List[str]] = [[input["response"]] for input in inputs] + statistics: Dict[str, Any] = defaultdict(list) for iter_no in range(self.num_evolutions): formatted_prompts = [] @@ -229,24 +233,28 @@ def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]: formatted_prompts, **self.llm.generation_kwargs, # type: ignore ) + for response in generated_responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) if self.store_evolutions: responses = [ - response + [evolved_response[0]] + response + [evolved_response["generations"][0]] for response, evolved_response in zip( responses, generated_responses ) ] else: responses = [ - [evolved_response[0]] for evolved_response in generated_responses + [evolved_response["generations"][0]] + for evolved_response in generated_responses ] self._logger.info( f"🔄 Ran iteration {iter_no} evolving {len(responses)} responses!" ) - return responses + return responses, dict(statistics) @override def process(self, inputs: StepInput) -> "StepOutput": # type: ignore @@ -259,7 +267,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore A list of Python dictionaries with the outputs of the task. """ - responses = self._evolve_reponses(inputs) + responses, statistics = self._evolve_reponses(inputs) if self.store_evolutions: # Remove the input instruction from the `evolved_responses` list @@ -268,6 +276,9 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore for input, response in zip(inputs, responses): input.update(self.format_output(response)) + input.update( + {"distilabel_metadata": {f"statistics_{self.name}": statistics}} + ) yield inputs self._logger.info(f"🎉 Finished evolving {len(responses)} instructions!") diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 5135e13ae0..265497409c 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -27,13 +27,16 @@ from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task +from distilabel.utils.dicts import merge_dicts if TYPE_CHECKING: + from distilabel.models.llms.typing import LLMStatistics from distilabel.steps.tasks.typing import ChatType from distilabel.steps.typing import StepColumns, StepOutput + MAGPIE_MULTI_TURN_SYSTEM_PROMPT = ( - "You are a helpful Al assistant. The user will engage in a multi−round conversation" + "You are a helpful Al assistant. The user will engage in a multi-round conversation" " with you, asking initial questions and following up with additional related questions." " Your goal is to provide thorough, relevant and insightful responses to help the user" " with their queries." @@ -192,15 +195,25 @@ def _generate_instruction( num_generations=1, **self.llm.generation_kwargs, # type: ignore ) + stats = [] rows = [] for output, system_prompt_key in zip_longest( outputs, system_prompt_keys, fillvalue=None ): - row = {"instruction": output[0]} # type: ignore + row = { + "instruction": output["generations"][0], + "distilabel_metadata": { + f"statistics_{self.name}": output["statistics"] + }, + } # type: ignore if system_prompt_key is not None: row["system_prompt_key"] = system_prompt_key rows.append(row) - return rows + stats.append( + {} + ) # Mimics the stats to keep _generate_with_pre_query_template + + return rows, stats def _prepare_conversation_outputs( self, conversations: List["ChatType"], system_prompt_keys: List[str] @@ -245,18 +258,22 @@ def _prepare_conversation_outputs( def _generate_conversation_turn( self, role: str, conversations: List["ChatType"], active_indices: List[int] - ) -> Tuple[List["ChatType"], List[int]]: + ) -> Tuple[List["ChatType"], List[int], "LLMStatistics"]: # Generate an output for the conversations that are still active (no previous `None`s) outputs = self.llm.generate( inputs=[conversations[idx] for idx in active_indices], num_generations=1, **self.llm.generation_kwargs, # type: ignore ) + # Extract the single message from the conversation and the statistics in separate lists + messages, statistics = zip( + *[(output["generations"][0], output["statistics"]) for output in outputs] + ) active_conversations = [conversations[idx] for idx in active_indices] updated_conversations = self._append_messages_to_conversations( role=role, - messages=[output[0] for output in outputs], + messages=messages, conversations=active_conversations, ) @@ -264,10 +281,10 @@ def _generate_conversation_turn( conversations[idx] = conv new_active_indices = [ - idx for idx, output in zip(active_indices, outputs) if output[0] is not None + idx for idx, output in zip(active_indices, outputs) if output is not None ] - return conversations, new_active_indices + return conversations, new_active_indices, statistics def _generate_multi_turn_conversation( self, inputs: List[Dict[str, Any]] @@ -278,30 +295,45 @@ def _generate_multi_turn_conversation( # Keep track of the active conversations, as it could happen that for some conversation # we can't generate the next turn because the `LLM` returned `None`. active_indices = list(range(len(conversations))) - + stats = [] for i in range(self.n_turns): # type: ignore if not active_indices: break # Generate user message - conversations, active_indices = self._generate_conversation_turn( - role="user", conversations=conversations, active_indices=active_indices + conversations, active_indices, statistics_user = ( + self._generate_conversation_turn( + role="user", + conversations=conversations, + active_indices=active_indices, + ) ) if i == self.n_turns - 1 and self.end_with_user: # type: ignore + statistics = merge_dicts(*[statistics_user]) + stats.append(statistics) break if not active_indices: break # Generate assistant message - conversations, active_indices = self._generate_conversation_turn( - role="assistant", - conversations=conversations, - active_indices=active_indices, + conversations, active_indices, statistics_assistant = ( + self._generate_conversation_turn( + role="assistant", + conversations=conversations, + active_indices=active_indices, + ) ) + # Merge the statistics of the user and assistant messages to have the same shape as the conversations + statistics = merge_dicts(*[statistics_user, statistics_assistant]) + stats.append(statistics) - return self._prepare_conversation_outputs(conversations, system_prompt_keys) + # Merge the dicts again at the conversation level + stats = merge_dicts(*stats) + return self._prepare_conversation_outputs( + conversations, system_prompt_keys + ), stats def _generate_with_pre_query_template( self, inputs: List[Dict[str, Any]] @@ -314,16 +346,22 @@ def _generate_with_pre_query_template( Returns: The list of generated conversations. """ - outputs = ( + outputs, statistics = ( self._generate_instruction(inputs) if self.only_instruction else self._generate_multi_turn_conversation(inputs) ) - - return [ - {**input, **output, "model_name": self.llm.model_name} - for input, output in zip(inputs, outputs) - ] + generations = [] + for input, output, stats in zip(inputs, outputs, statistics): + generation = { + **input, + **output, + "model_name": self.llm.model_name, + } + if not self.only_instruction: + generation["distilabel_metadata"] = {f"statistics_{self.name}": stats} + generations.append(generation) + return generations class Magpie(Task, MagpieBase): diff --git a/src/distilabel/steps/typing.py b/src/distilabel/steps/typing.py index 720037a74f..4f6f53d5d9 100644 --- a/src/distilabel/steps/typing.py +++ b/src/distilabel/steps/typing.py @@ -14,8 +14,24 @@ from typing import Any, Dict, Iterator, List, Tuple, Union -StepOutput = Iterator[List[Dict[str, Any]]] -"""`StepOutput` is an alias of the typing `Iterator[List[Dict[str, Any]]]`""" +StepData = List[Dict[str, Any]] +StepStatistics = Dict[str, Any] +StepOutput = Iterator[Dict[str, Union[StepData, StepStatistics]]] +r"""`StepOutput` is an alias of the typing. +A step output is a dict of the form: +{ + "outputs": [ + {"col1": "val1", "col2": "val2"}, + {"col1": "val1", "col2": "val2"}, + {"col1": "val1", "col2": "val2"}, + ], + "statistics": { + "llm": {}, + "time": 12341234, + ... + } +} +""" GeneratorStepOutput = Iterator[Tuple[List[Dict[str, Any]], bool]] """`GeneratorStepOutput` is an alias of the typing `Iterator[Tuple[List[Dict[str, Any]], bool]]`""" diff --git a/src/distilabel/utils/dicts.py b/src/distilabel/utils/dicts.py index 6c651ae32a..b2d0af9a55 100644 --- a/src/distilabel/utils/dicts.py +++ b/src/distilabel/utils/dicts.py @@ -14,17 +14,19 @@ import json from collections import defaultdict +from itertools import chain from typing import Any, Dict, List, TypeVar _K = TypeVar("_K") -def group_dicts(*dicts: Dict[_K, Any]) -> Dict[_K, List[Any]]: +def group_dicts(*dicts: Dict[_K, Any], flatten: bool = False) -> Dict[_K, List[Any]]: """Combines multiple dictionaries into a single dictionary joining the values as a list for each key. Args: *dicts: the dictionaries to be combined. + flatten: whether to flatten the list of values for each key. Returns: The combined dictionary. @@ -33,8 +35,66 @@ def group_dicts(*dicts: Dict[_K, Any]) -> Dict[_K, List[Any]]: for d in dicts: for key, value in d.items(): combined_dict[key].append(value) - return dict(combined_dict) + + combined_dict = dict(combined_dict) + if flatten: + combined_dict = { + k: list(chain.from_iterable(v)) for k, v in combined_dict.items() + } + return combined_dict def flatten_dict(x: Dict[Any, Any]) -> Dict[Any, Any]: return {k: json.dumps(v) if isinstance(v, dict) else v for k, v in x.items()} + + +def merge_dicts(*dict_lists: dict) -> list[dict]: + """ + Merge N lists of dictionaries with matching keys. + The keys can be any strings, but they must match across all dictionaries within each position. + + Args: + *dict_lists: Variable number of lists of dictionaries + + Returns: + list: Merged list of dictionaries with combined values + + Raises: + ValueError: If lists have different lengths or dictionaries have mismatched keys + """ + if not dict_lists: + return [] + + # Verify all lists have the same length + first_len = len(dict_lists[0]) + if not all(len(d) == first_len for d in dict_lists): + raise ValueError("All input lists must have the same length") + + # For each position, get keys from first list's dictionary + result = [] + for i in range(first_len): + # Get keys from the first dictionary at this position + keys = set(dict_lists[0][i].keys()) + + # Verify all dictionaries at this position have the same keys + for dict_list in dict_lists: + if set(dict_list[i].keys()) != keys: + raise ValueError( + f"All dictionaries at position {i} must have the same keys" + ) + + merged_dict = {key: [] for key in keys} + + # For each dictionary at position i in all lists + for dict_list in dict_lists: + current_dict = dict_list[i] + for key in keys: + # Ensure value is a list + value = current_dict[key] + if not isinstance(value, list): + value = [value] + merged_dict[key].extend(value) + + result.append(merged_dict) + + return result diff --git a/tests/integration/test_offline_batch_generation.py b/tests/integration/test_offline_batch_generation.py index e3dea4af56..ae34d04159 100644 --- a/tests/integration/test_offline_batch_generation.py +++ b/tests/integration/test_offline_batch_generation.py @@ -51,8 +51,15 @@ def offline_batch_generate( raise DistilabelOfflineBatchGenerationNotFinishedException( jobs_ids=self.jobs_ids # type: ignore ) - - return [["output" for _ in range(num_generations)]] + return [ + { + "generations": [f"output {i}" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) def test_offline_batch_generation() -> None: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b3ec2de908..9aa4ea3361 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union import pytest +from pydantic import PrivateAttr from distilabel.models.llms.base import LLM, AsyncLLM from distilabel.models.mixins.magpie import MagpieChatTemplateMixin @@ -28,9 +29,11 @@ # Defined here too, so that the serde still works class DummyAsyncLLM(AsyncLLM): structured_output: Any = None + n_generations_supported: bool = True # To work as OpenAI or an LLM that doesn't allow num_generations out of the box + _num_generations_param_supported: bool = PrivateAttr(default=True) def load(self) -> None: - pass + self._num_generations_param_supported = self.n_generations_supported @property def model_name(self) -> str: @@ -39,7 +42,13 @@ def model_name(self) -> str: async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": - return ["output" for _ in range(num_generations)] + return { + "generations": ["output" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } class DummyLLM(LLM): @@ -55,7 +64,15 @@ def model_name(self) -> str: def generate( # type: ignore self, inputs: "FormattedInput", num_generations: int = 1 ) -> List["GenerateOutput"]: - return [["output" for _ in range(num_generations)]] + return [ + { + "generations": [f"output {i}" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): @@ -70,7 +87,14 @@ def generate( self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any ) -> List["GenerateOutput"]: return [ - ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs)) + { + "generations": ["Hello Magpie"] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + for _ in range(len(inputs)) ] diff --git a/tests/unit/models/llms/huggingface/test_inference_endpoints.py b/tests/unit/models/llms/huggingface/test_inference_endpoints.py index f4054b6736..874cd9a595 100644 --- a/tests/unit/models/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/models/llms/huggingface/test_inference_endpoints.py @@ -14,7 +14,7 @@ import os import random -from typing import Generator +from typing import Any, Dict, Generator, List from unittest import mock from unittest.mock import AsyncMock, MagicMock, patch @@ -130,17 +130,29 @@ async def test_agenerate_with_text_generation( llm.load() llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + return_value=MagicMock( + generated_text="Aenean hendrerit aliquam velit...", + details=MagicMock( + generated_tokens=66, + ), + ) ) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] + ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": { + "input_tokens": [31], + "output_tokens": [66], + }, + } @pytest.mark.asyncio async def test_agenerate_with_chat_completion( @@ -173,14 +185,21 @@ async def test_agenerate_with_chat_completion( ) ) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] + ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": { + "input_tokens": [18], + "output_tokens": [66], + }, + } @pytest.mark.asyncio async def test_agenerate_with_chat_completion_fails( @@ -213,30 +232,85 @@ async def test_agenerate_with_chat_completion_fails( ) ) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [None] + ) + assert result == { + "generations": [None], + "statistics": { + "input_tokens": [18], + "output_tokens": [66], + }, + } + @pytest.mark.parametrize( + "num_generations, expected_result", + [ + ( + 1, + [ + { + "generations": ["text"], + "statistics": {"input_tokens": [18], "output_tokens": [66]}, + } + ], + ), + ( + 2, + [ + { + "generations": ["text"] * 2, + "statistics": { + "input_tokens": [18, 18], + "output_tokens": [66, 66], + }, + } + ], + ), + ], + ) @pytest.mark.asyncio - async def test_generate(self, mock_inference_client: MagicMock) -> None: + async def test_generate( + self, + mock_inference_client: MagicMock, + num_generations: int, + expected_result: List[Dict[str, Any]], + ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) llm.load() - llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content="text", + ), + ) + ] + * num_generations, + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) ) nest_asyncio.apply() - - assert llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -245,8 +319,10 @@ async def test_generate(self, mock_inference_client: MagicMock) -> None: "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ] - ) == [[" Aenean hendrerit aliquam velit. ..."]] + ], + num_generations=num_generations, + ) + assert result == expected_result @pytest.mark.asyncio async def test_agenerate_with_structured_output( @@ -260,39 +336,32 @@ async def test_agenerate_with_structured_output( llm.load() llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + return_value=MagicMock( + generated_text="Aenean hendrerit aliquam velit...", + details=MagicMock( + generated_tokens=66, + ), + ) ) - # Since there's a pseudo-random number within the generation kwargs, we set the seed # here first to ensure reproducibility within the tests random.seed(42) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] - - kwargs = { - "prompt": " [INST] Lorem ipsum dolor sit amet, consectetur adipiscing elit. [/INST]", - "max_new_tokens": 128, - "do_sample": False, - "typical_p": None, - "repetition_penalty": None, - "frequency_penalty": None, - "temperature": 1.0, - "top_p": None, - "top_k": None, - "stop_sequences": None, - "return_full_text": False, - "seed": 2053695854357871005, # pre-computed random value with `random.seed(42)` - "watermark": False, - "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, + ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": { + "input_tokens": [31], + "output_tokens": [66], + }, } - llm._aclient.text_generation.assert_called_with(**kwargs) # type: ignore def test_serialization(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( diff --git a/tests/unit/models/llms/huggingface/test_transformers.py b/tests/unit/models/llms/huggingface/test_transformers.py index a298ff737e..bbaaa787c2 100644 --- a/tests/unit/models/llms/huggingface/test_transformers.py +++ b/tests/unit/models/llms/huggingface/test_transformers.py @@ -53,9 +53,23 @@ def test_generate(self, transformers_llm: TransformersLLM) -> None: ], num_generations=3, ) - + # Note: It returns the following structure: + # [ + # { + # "generations": [text1, text2, text3], # As much as num_generations + # "statistics": { + # "input_tokens": [7], + # "output_tokens": [128, 128, 128], # The sum of the tokens of the generated texts + # }, + # }, + # {...} + # ] assert len(responses) == 2 - assert len(responses[0]) == 3 + generations = responses[0]["generations"] + statistics = responses[0]["statistics"] + assert len(generations) == 3 + assert "input_tokens" in statistics + assert "output_tokens" in statistics def test_get_last_hidden_states(self, transformers_llm: TransformersLLM) -> None: inputs = [ diff --git a/tests/unit/models/llms/test_anthropic.py b/tests/unit/models/llms/test_anthropic.py index 3051b99789..205e96b984 100644 --- a/tests/unit/models/llms/test_anthropic.py +++ b/tests/unit/models/llms/test_anthropic.py @@ -37,12 +37,14 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key") # type: ignore llm._aclient = mock_anthropic - mocked_completion = Mock() - mocked_completion.content = [Mock(text="Aenean hendrerit aliquam velit...")] + mocked_completion = Mock( + content=[Mock(text="Aenean hendrerit aliquam velit...")], + usage=Mock(input_tokens=100, output_tokens=100), + ) llm._aclient.messages.create = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -51,6 +53,10 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: }, ] ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } @pytest.mark.asyncio async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: @@ -65,8 +71,12 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: ) # type: ignore llm._aclient = mock_openai - sample_user = DummyUserDetail(name="John Doe", age=30) - + mocked_usage = MagicMock( + usage=MagicMock(input_tokens=100, output_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) llm._aclient.messages.create = AsyncMock(return_value=sample_user) generation = await llm.agenerate( @@ -78,7 +88,13 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": { + "input_tokens": [100], + "output_tokens": [100], + }, + } @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" @@ -88,14 +104,16 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None: llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore llm._aclient = mock_anthropic - mocked_completion = Mock() - mocked_completion.content = [Mock(text="Aenean hendrerit aliquam velit...")] + mocked_completion = Mock( + content=[Mock(text="Aenean hendrerit aliquam velit...")], + usage=Mock(input_tokens=100, output_tokens=100), + ) llm._aclient.messages.create = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -106,6 +124,12 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/models/llms/test_cohere.py b/tests/unit/models/llms/test_cohere.py index 4b0a83cbb3..7ce0f359af 100644 --- a/tests/unit/models/llms/test_cohere.py +++ b/tests/unit/models/llms/test_cohere.py @@ -19,6 +19,7 @@ import nest_asyncio import pytest +from tokenizers import Tokenizer from distilabel.models.llms.cohere import CohereLLM @@ -50,16 +51,12 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None: llm = CohereLLM(model="command-r") llm._aclient = mock_async_client # type: ignore - mocked_completion = mock.Mock( - choices=[ - mock.Mock( - message=mock.Mock(content=" Aenean hendrerit aliquam velit. ...") - ) - ] - ) + mocked_completion = mock.Mock(text="Aenean hendrerit aliquam velit...") llm._aclient.chat = mock.AsyncMock(return_value=mocked_completion) - await llm.agenerate( + llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased") + + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -68,9 +65,13 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None: }, ] ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": [25], "output_tokens": [16]}, + } @pytest.mark.skipif( - sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + sys.version_info < (3, 9), reason="`cohere` requires Python 3.9 or higher" ) @pytest.mark.asyncio async def test_agenerate_structured( @@ -89,6 +90,7 @@ async def test_agenerate_structured( sample_user = DummyUserDetail(name="John Doe", age=30) llm._aclient.chat = mock.AsyncMock(return_value=sample_user) + llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased") generation = await llm.agenerate( input=[ @@ -99,25 +101,23 @@ async def test_agenerate_structured( }, ] ) - assert generation == [sample_user.model_dump_json()] + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": {"input_tokens": [25], "output_tokens": [26]}, + } @pytest.mark.asyncio async def test_generate(self, mock_async_client: mock.MagicMock) -> None: llm = CohereLLM(model="command-r") llm._aclient = mock_async_client # type: ignore - mocked_completion = mock.Mock( - choices=[ - mock.Mock( - message=mock.Mock(content=" Aenean hendrerit aliquam velit. ...") - ) - ] - ) + mocked_completion = mock.Mock(text="Aenean hendrerit aliquam velit...") llm._aclient.chat = mock.AsyncMock(return_value=mocked_completion) + llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased") nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -128,6 +128,12 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": [25], "output_tokens": [16]}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/models/llms/test_groq.py b/tests/unit/models/llms/test_groq.py index ce80c02c8a..3d64c7ab24 100644 --- a/tests/unit/models/llms/test_groq.py +++ b/tests/unit/models/llms/test_groq.py @@ -38,7 +38,10 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None: llm._aclient = mock_groq mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) @@ -50,10 +53,13 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None: "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] + ) == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } @pytest.mark.skipif( - sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + sys.version_info < (3, 9), reason="`groq` requires Python 3.9 or higher" ) @pytest.mark.asyncio async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: @@ -68,8 +74,12 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: ) # type: ignore llm._aclient = mock_openai - sample_user = DummyUserDetail(name="John Doe", age=30) - + mocked_usage = MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) generation = await llm.agenerate( @@ -81,7 +91,10 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } @pytest.mark.asyncio async def test_generate(self, mock_groq: MagicMock) -> None: @@ -89,7 +102,8 @@ async def test_generate(self, mock_groq: MagicMock) -> None: llm._aclient = mock_groq mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[Mock(message=Mock(content="Aenean hendrerit aliquam velit..."))], + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) @@ -105,7 +119,12 @@ async def test_generate(self, mock_groq: MagicMock) -> None: }, ] ] - ) == [[" Aenean hendrerit aliquam velit. ..."]] + ) == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/models/llms/test_litellm.py b/tests/unit/models/llms/test_litellm.py index 60dfaacbb0..df7bed9a35 100644 --- a/tests/unit/models/llms/test_litellm.py +++ b/tests/unit/models/llms/test_litellm.py @@ -42,7 +42,7 @@ async def test_agenerate(self, mock_litellm: MagicMock, model: str) -> None: ) llm._aclient = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -51,6 +51,10 @@ async def test_agenerate(self, mock_litellm: MagicMock, model: str) -> None: }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [21], "output_tokens": [11]}, + } @pytest.mark.asyncio async def test_generate(self, mock_litellm: MagicMock, model: str) -> None: @@ -64,7 +68,7 @@ async def test_generate(self, mock_litellm: MagicMock, model: str) -> None: nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -75,6 +79,12 @@ async def test_generate(self, mock_litellm: MagicMock, model: str) -> None: ] ] ) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [21], "output_tokens": [11]}, + } + ] def test_serialization(self, _: MagicMock, model: str) -> None: llm = LiteLLM(model=model) # type: ignore diff --git a/tests/unit/models/llms/test_llamacpp.py b/tests/unit/models/llms/test_llamacpp.py index 19cdcd929b..94bf008f19 100644 --- a/tests/unit/models/llms/test_llamacpp.py +++ b/tests/unit/models/llms/test_llamacpp.py @@ -54,9 +54,12 @@ def test_generate(self, llm: LlamaCppLLM) -> None: ], num_generations=3, ) - assert len(responses) == 2 - assert len(responses[0]) == 3 + generations = responses[0]["generations"] + statistics = responses[0]["statistics"] + assert len(generations) == 3 + assert "input_tokens" in statistics + assert "output_tokens" in statistics @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/models/llms/test_mistral.py b/tests/unit/models/llms/test_mistral.py index a0095b3d73..f5ae5a4116 100644 --- a/tests/unit/models/llms/test_mistral.py +++ b/tests/unit/models/llms/test_mistral.py @@ -40,15 +40,18 @@ def test_mistral_llm(self, mock_mistral: MagicMock) -> None: @pytest.mark.asyncio async def test_agenerate(self, mock_mistral: MagicMock) -> None: - llm = MistralLLM(model="mistral-tiny", api_key="api.key") # type: ignore + llm = MistralLLM(model="mistral-small", api_key="api.key") # type: ignore llm._aclient = mock_mistral mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=10, completion_tokens=10, total_tokens=20), ) - llm._aclient.chat = AsyncMock(return_value=mocked_completion) + llm._aclient.chat.complete_async = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -57,11 +60,15 @@ async def test_agenerate(self, mock_mistral: MagicMock) -> None: }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [10], "output_tokens": [10]}, + } @pytest.mark.asyncio async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: llm = MistralLLM( - model="mistral-tiny", + model="mistral-small", api_key="api.key", structured_output={ "schema": DummyUserDetail, @@ -71,12 +78,16 @@ async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: ) # type: ignore llm._aclient = mock_mistral - sample_user = DummyUserDetail(name="John Doe", age=30) - + mocked_usage = MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) + # llm._aclient.chat.completions.create = AsyncMock(return_value=Mock(messages=sample_user)) llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) # This should work just with the _aclient.chat method once it's fixed in instructor, and # then in our code. - # llm._aclient.chat = AsyncMock(return_value=sample_user) generation = await llm.agenerate( input=[ @@ -87,7 +98,13 @@ async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": { + "input_tokens": [100], + "output_tokens": [100], + }, + } @pytest.mark.asyncio async def test_generate(self, mock_mistral: MagicMock) -> None: @@ -95,7 +112,10 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: llm._aclient = mock_mistral mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=10, completion_tokens=10, total_tokens=20), ) llm._aclient.chat = Mock( complete_async=AsyncMock(return_value=mocked_completion) @@ -103,7 +123,7 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -114,6 +134,12 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [10], "output_tokens": [10]}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/models/llms/test_ollama.py b/tests/unit/models/llms/test_ollama.py index 137ea8adf9..167ec6a1dc 100644 --- a/tests/unit/models/llms/test_ollama.py +++ b/tests/unit/models/llms/test_ollama.py @@ -33,11 +33,13 @@ async def test_agenerate(self, mock_ollama: MagicMock) -> None: llm._aclient = mock_ollama mocked_completion = { - "message": {"content": " Aenean hendrerit aliquam velit. ..."} + "message": {"content": "Aenean hendrerit aliquam velit..."}, + "prompt_eval_count": 10, + "eval_count": 10, } llm._aclient.chat = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -46,6 +48,10 @@ async def test_agenerate(self, mock_ollama: MagicMock) -> None: }, ] ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": [10], "output_tokens": [10]}, + } @pytest.mark.asyncio async def test_generate(self, mock_ollama: MagicMock) -> None: @@ -53,14 +59,16 @@ async def test_generate(self, mock_ollama: MagicMock) -> None: llm._aclient = mock_ollama mocked_completion = { - "message": {"content": " Aenean hendrerit aliquam velit. ..."} + "message": {"content": "Aenean hendrerit aliquam velit..."}, + "prompt_eval_count": 10, + "eval_count": 10, } llm._aclient.chat = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -71,6 +79,12 @@ async def test_generate(self, mock_ollama: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": [10], "output_tokens": [10]}, + } + ] def test_serialization(self, _: MagicMock) -> None: llm = OllamaLLM(model="notus") # type: ignore diff --git a/tests/unit/models/llms/test_openai.py b/tests/unit/models/llms/test_openai.py index 30caaa86ad..b0c242f690 100644 --- a/tests/unit/models/llms/test_openai.py +++ b/tests/unit/models/llms/test_openai.py @@ -15,7 +15,7 @@ import os import sys from textwrap import dedent -from typing import Any, Dict +from typing import Any, Dict, List from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -65,11 +65,14 @@ async def test_agenerate( llm._aclient = async_openai_mock mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -78,6 +81,10 @@ async def test_agenerate( }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } @pytest.mark.asyncio async def test_agenerate_structured( @@ -93,8 +100,16 @@ async def test_agenerate_structured( }, ) # type: ignore llm._aclient = async_openai_mock + import tiktoken + + llm._tokenizer = tiktoken.encoding_for_model(self.model_id) - sample_user = DummyUserDetail(name="John Doe", age=30) + mocked_usage = MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) @@ -107,26 +122,58 @@ async def test_agenerate_structured( }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" ) + @pytest.mark.parametrize( + "num_generations, expected_result", + [ + ( + 1, + [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } + ], + ), + ( + 2, + [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."] * 2, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } + ], + ), + ], + ) @pytest.mark.asyncio async def test_generate( - self, async_openai_mock: MagicMock, _openai_mock: MagicMock + self, + async_openai_mock: MagicMock, + _openai_mock: MagicMock, + num_generations: int, + expected_result: List[Dict[str, Any]], ) -> None: llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore llm._aclient = async_openai_mock mocked_completion = Mock( choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + * num_generations, + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -137,6 +184,7 @@ async def test_generate( ] ] ) + assert result == expected_result with pytest.raises(ValueError): llm.generate( @@ -206,6 +254,11 @@ def test_check_and_get_batch_results( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, }, }, @@ -228,6 +281,11 @@ def test_check_and_get_batch_results( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, }, }, @@ -236,7 +294,23 @@ def test_check_and_get_batch_results( llm.load() outputs = llm._check_and_get_batch_results() - assert outputs == [["output 1"], ["output 2"]] + + assert outputs == [ + { + "generations": ["output 1"], + "statistics": { + "input_tokens": [100], + "output_tokens": [100], + }, + }, + { + "generations": ["output 2"], + "statistics": { + "input_tokens": [100], + "output_tokens": [100], + }, + }, + ] def test_check_and_get_batch_results_raises_valueerror( self, _async_openai_mock: MagicMock, _openai_mock: MagicMock @@ -322,12 +396,23 @@ def test_parse_output( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, } } ) - assert result == [" Aenean hendrerit aliquam velit. ..."] + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": { + "input_tokens": [100], + "output_tokens": [100], + }, + } def test_retrieve_batch_results( self, _async_openai_mock: MagicMock, openai_mock: MagicMock diff --git a/tests/unit/models/llms/test_vertexai.py b/tests/unit/models/llms/test_vertexai.py index d32f773a3c..529fbf332a 100644 --- a/tests/unit/models/llms/test_vertexai.py +++ b/tests/unit/models/llms/test_vertexai.py @@ -41,9 +41,10 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: llm._generation_config_class = GenerationConfig mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + candidates=[Mock(text=" Aenean hendrerit aliquam velit. ...")], + usage_metadata=Mock(prompt_token_count=10, candidates_token_count=10), ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) + llm._aclient.generate_content_async = AsyncMock(return_value=mocked_completion) with pytest.raises( ValueError, match="`VertexAILLM only supports the roles 'user' or 'model'." @@ -58,7 +59,7 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: ] ) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "model", "content": ""}, { @@ -67,6 +68,10 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [10], "output_tokens": [10]}, + } @pytest.mark.asyncio async def test_generate(self, mock_generative_model: MagicMock) -> None: @@ -77,9 +82,10 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: llm._generation_config_class = GenerationConfig mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + candidates=[Mock(text=" Aenean hendrerit aliquam velit. ...")], + usage_metadata=Mock(prompt_token_count=10, candidates_token_count=10), ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) + llm._aclient.generate_content_async = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() @@ -98,7 +104,7 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: ] ) - llm.generate( + result = llm.generate( inputs=[ [ {"role": "model", "content": "I am a model."}, @@ -109,6 +115,12 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [10], "output_tokens": [10]}, + } + ] def test_serialization(self, _: MagicMock) -> None: llm = VertexAILLM(model="gemini-1.0-pro") diff --git a/tests/unit/models/llms/test_vllm.py b/tests/unit/models/llms/test_vllm.py index 07c561af86..dda129fd8b 100644 --- a/tests/unit/models/llms/test_vllm.py +++ b/tests/unit/models/llms/test_vllm.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Any, Dict, List from unittest import mock -import numpy as np import pytest from openai.pagination import SyncPage from openai.types import Model from openai.types.completion import Completion from openai.types.completion_choice import CompletionChoice +from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel +from transformers import AutoTokenizer from distilabel.models.llms import vLLM -from distilabel.models.llms.vllm import ClientvLLM, _sort_batches +from distilabel.models.llms.vllm import ClientvLLM class Character(BaseModel): @@ -101,9 +102,10 @@ class Animal(BaseModel): ] -# Just a mock to avoid loading the model class DummyTokenizer: - chat_template = None + # chat_template = None + chat_template = "template" + vocabulary = {"I'm": 1, "fine": 2, "thank": 3, "you": 4, "sir": 5} def __init__(self) -> None: pass @@ -111,83 +113,108 @@ def __init__(self) -> None: def apply_chat_template(self, input, **kwargs): return input + def encode(self, text: str): + return [1, 2, 3, 4, 5] + + def convert_token_to_string(self, token: str) -> str: + return "token" + + def get_vocab(self): + return self.vocabulary + class TestvLLM: + @pytest.mark.parametrize("multi_structured_output", (False, True)) @pytest.mark.parametrize( - "num_generations, expected_sorted_batches", + "num_generations, expected_result", [ ( 1, [ - "Generate a character from a RPG game.", - "Generate an animal from a zoo.", - "Repeated character", - "What's the weather like today in Seattle in Celsius degrees?", - "Other character", - "repeated regex", + { + "generations": ["I'm fine thank you"], + "statistics": {"input_tokens": [10], "output_tokens": [6]}, + } ], ), ( - 3, - np.repeat( - [ - "Generate a character from a RPG game.", - "Generate an animal from a zoo.", - "Repeated character", - "What's the weather like today in Seattle in Celsius degrees?", - "Other character", - "repeated regex", - ], - 3, - ).tolist(), + 2, + [ + { + "generations": ["I'm fine thank you"] * 2, + "statistics": { + "input_tokens": [10, 10], + "output_tokens": [6, 6], + }, + } + ], ), ], ) - def test_prepare_batches_and_sort_back( - self, num_generations: int, expected_sorted_batches: List[str] - ): - formatted_inputs = [ - (item["instruction"], item["structured_output"]) - for row in SAMPLE_DATA - for item in row - ] + def test_generate( + self, + multi_structured_output: bool, + num_generations: int, + expected_result: List[Dict[str, Any]], + ) -> None: llm = vLLM(model="dummy") + tokenizer = AutoTokenizer.from_pretrained( + "distilabel-internal-testing/tiny-random-mistral" + ) llm._tokenizer = DummyTokenizer() - batches, indices = llm._prepare_batches(formatted_inputs) - # NOTE: We have to simulate calling self._model.generate(n=num_generations) and then sorting the results - num_generations_batches = [] - for batch in batches: - num_generations_batches.append( - (np.repeat(batch[0], num_generations).tolist(), batch[1]) + vllm_mock = mock.MagicMock() + vllm_mock.get_tokenizer = mock.MagicMock(return_value=tokenizer) + # mock the import by hacking sys.modules + # https://stackoverflow.com/questions/60919705/how-to-mock-in-a-python-unittest-a-library-not-installed-locally + import sys + + if "vllm" not in sys.modules: + sys.modules["vllm"] = vllm_mock + llm._model = vllm_mock + + mocked_requests_output = [ + mock.Mock( # RequestOutput + outputs=[ + mock.Mock( # CompletionOutput + text="I'm fine thank you", + token_ids=[1, 2, 3, 4, 5, 7], + ) + ] + * num_generations, ) - batches = num_generations_batches - # Recreate as the output from batched_outputs += [[output.text for output in outputs.outputs] for outputs in batch_outputs] - batches = [batch for batch, _ in batches] - sorted_batches = _sort_batches( - batches, indices, num_generations=num_generations - ) + ] - assert sorted_batches == [ - np.repeat( - [ - "Generate a character from a RPG game.", - "Generate an animal from a zoo.", - "Repeated character", - ], - num_generations, - ).tolist(), - np.repeat( - ["What's the weather like today in Seattle in Celsius degrees?"], - num_generations, - ).tolist(), - np.repeat( + llm._model.generate = mock.MagicMock(return_value=mocked_requests_output) + if not multi_structured_output: + formatted_inputs = [ [ - "Other character", - "repeated regex", - ], - num_generations, - ).tolist(), - ] + {"role": "system", "content": "sysprompt"}, + { + "role": "user", + "content": "I'm fine thank you", + }, + ] + ] + else: + formatted_inputs = [ + ( + [ + {"role": "system", "content": "sysprompt"}, + { + "role": "user", + "content": "I'm fine thank you", + }, + ], + { + # "format": "json", + "format": "regex", + "schema": r".*", + # "schema": Character.model_json_schema(), + }, + ) + ] + result = llm.generate(inputs=formatted_inputs, num_generations=num_generations) + assert result == expected_result @mock.patch("openai.OpenAI") @@ -240,6 +267,11 @@ async def test_agenerate( text="I'm fine thank you sir", ), ], + usage=CompletionUsage( + completion_tokens=10, + prompt_tokens=10, + total_tokens=20, + ), ) ) @@ -247,4 +279,10 @@ async def test_agenerate( input=[{"role": "user", "content": "Hi, how are you?"}] ) - assert generations == ["I'm fine thank you", "I'm fine thank you sir"] + assert generations == { + "generations": ["I'm fine thank you", "I'm fine thank you sir"], + "statistics": { + "input_tokens": [10], + "output_tokens": [10], + }, + } diff --git a/tests/unit/models/llms/utils.py b/tests/unit/models/llms/utils.py index 7b899253bb..1888388f6e 100644 --- a/tests/unit/models/llms/utils.py +++ b/tests/unit/models/llms/utils.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pydantic import BaseModel +from typing import Any + +from pydantic import BaseModel, PrivateAttr class DummyUserDetail(BaseModel): name: str age: int + _raw_response: Any = PrivateAttr() + + def __init__(self, **data): + super().__init__(**data) + self._raw_response = data.get("_raw_response") diff --git a/tests/unit/steps/clustering/test_text_clustering.py b/tests/unit/steps/clustering/test_text_clustering.py index 0659da71ec..ddd473bb76 100644 --- a/tests/unit/steps/clustering/test_text_clustering.py +++ b/tests/unit/steps/clustering/test_text_clustering.py @@ -32,11 +32,16 @@ async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": if self.n == 1: - return [json.dumps({"labels": "label"}) for _ in range(num_generations)] - return [ - json.dumps({"labels": ["label" for _ in range(self.n)]}) - for _ in range(self.n) - ] + text = json.dumps({"labels": "label"}) + else: + text = json.dumps({"labels": ["label" for _ in range(self.n)]}) + return { + "generations": [text] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } class TestTextClustering: diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py index efe14ff12f..38580c2c42 100644 --- a/tests/unit/steps/tasks/apigen/test_generator.py +++ b/tests/unit/steps/tasks/apigen/test_generator.py @@ -49,9 +49,16 @@ def generate( if self.use_structured_output: query_answers = {"pairs": query_answers} return [ - [json.dumps(query_answers) for _ in range(num_generations)] - for _ in range(len(inputs)) - ] + { + "generations": [ + json.dumps(query_answers) for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) # Example of 3 rows from Salesforce/xlam-function-calling-60k diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py index 053bac0a4f..027c67809b 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_base.py +++ b/tests/unit/steps/tasks/evol_instruct/test_base.py @@ -69,6 +69,12 @@ def test_process(self, dummy_llm: LLM) -> None: "instruction": "test", "evolved_instruction": "output", "model_name": "test", + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -89,6 +95,12 @@ def test_process_store_evolutions(self, dummy_llm: LLM) -> None: "instruction": "test", "evolved_instructions": ["output", "output"], "model_name": "test", + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -110,6 +122,12 @@ def test_process_generate_answers(self, dummy_llm: LLM) -> None: "evolved_instruction": "output", "answer": "output", "model_name": "test", + "distilabel_metadata": { + "statistics_answer_task": { + "input_tokens": [12], + "output_tokens": [12], + } + }, } ] ] @@ -140,6 +158,7 @@ def test_serialization(self, dummy_llm: LLM) -> None: "jobs_ids": None, "offline_batch_generation_block_until_done": None, "use_offline_batch_generation": False, + "n_generations_supported": True, "type_info": { "module": task.llm.__module__, "name": task.llm.__class__.__name__, diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py index e87d09a9ce..41b70591ec 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_generator.py +++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py @@ -64,19 +64,36 @@ def test_process(self, dummy_llm: LLM) -> None: task = EvolInstructGenerator( name="task", llm=dummy_llm, - num_instructions=1, + num_instructions=2, min_length=1, max_length=10, pipeline=pipeline, ) task.load() + assert list(task.process()) == [ ( [ { "instruction": "output", "model_name": "test", - } + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, + }, + { + "instruction": "output", + "model_name": "test", + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, + }, ], True, ) @@ -101,6 +118,12 @@ def test_process_generate_answers(self, dummy_llm: LLM) -> None: "instruction": "output", "answer": "output", "model_name": "test", + "distilabel_metadata": { + "statistics_answer_task": { + "input_tokens": [12], + "output_tokens": [12], + } + }, } ], True, @@ -122,6 +145,7 @@ def test_serialization(self, dummy_llm: LLM) -> None: "jobs_ids": None, "offline_batch_generation_block_until_done": None, "use_offline_batch_generation": False, + "n_generations_supported": True, "type_info": { "module": task.llm.__class__.__module__, "name": task.llm.__class__.__name__, diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py index c77df8d8ad..7c84fe3587 100644 --- a/tests/unit/steps/tasks/evol_quality/test_base.py +++ b/tests/unit/steps/tasks/evol_quality/test_base.py @@ -60,6 +60,12 @@ def test_process(self, dummy_llm: LLM) -> None: "response": "mock", "evolved_response": "output", "model_name": "test", + "distilabel_metadata": { + "statistics_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -81,6 +87,12 @@ def test_process_store_evolutions(self, dummy_llm: LLM) -> None: "response": "mock", "evolved_responses": ["output", "output"], "model_name": "test", + "distilabel_metadata": { + "statistics_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -111,6 +123,7 @@ def test_serialization(self, dummy_llm: LLM) -> None: "jobs_ids": None, "offline_batch_generation_block_until_done": None, "use_offline_batch_generation": False, + "n_generations_supported": True, "type_info": { "module": task.llm.__module__, "name": task.llm.__class__.__name__, diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py index aac4e504f9..07de3982b7 100644 --- a/tests/unit/steps/tasks/magpie/test_base.py +++ b/tests/unit/steps/tasks/magpie/test_base.py @@ -86,16 +86,34 @@ def test_process(self) -> None: "instruction": "Hello Magpie", "response": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, ] @@ -119,6 +137,12 @@ def test_process_with_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -129,6 +153,12 @@ def test_process_with_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -139,6 +169,12 @@ def test_process_with_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -167,6 +203,12 @@ def test_process_with_several_system_prompts(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -177,6 +219,12 @@ def test_process_with_several_system_prompts(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -187,6 +235,12 @@ def test_process_with_several_system_prompts(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -194,10 +248,42 @@ def test_process_failing_generation_for_some_rows(self) -> None: with mock.patch( "tests.unit.conftest.DummyMagpieLLM.generate", side_effect=[ - [["Hello Magpie"], [None], ["Hello Magpie"]], - [["Hello Magpie"], ["Hello Magpie"]], - [["Hello Magpie"], [None]], - [["Hello Magpie"]], + [ + { + "generations": ["Hello Magpie user"], + "statistics": { + "input_tokens": [12], + "output_tokens": [12], + }, + } + ], + [ + { + "generations": [None], + "statistics": { + "input_tokens": [], + "output_tokens": [], + }, + } + ], + [ + { + "generations": [None], + "statistics": { + "input_tokens": [], + "output_tokens": [], + }, + } + ], + [ + { + "generations": ["Hello Magpie assistant"], + "statistics": { + "input_tokens": [12], + "output_tokens": [12], + }, + } + ], ], ): task = Magpie( @@ -206,26 +292,19 @@ def test_process_failing_generation_for_some_rows(self) -> None: task.load() - assert next(task.process(inputs=[{}, {}, {}])) == [ - { - "conversation": [ - {"role": "user", "content": "Hello Magpie"}, - {"role": "assistant", "content": "Hello Magpie"}, - {"role": "user", "content": "Hello Magpie"}, - {"role": "assistant", "content": "Hello Magpie"}, - ], - "model_name": "test", - }, - { - "conversation": [], - "model_name": "test", - }, + assert next(task.process(inputs=[{}])) == [ { "conversation": [ - {"role": "user", "content": "Hello Magpie"}, - {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie user"}, + {"role": "assistant", "content": "Hello Magpie assistant"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, ] @@ -243,6 +322,12 @@ def test_process_with_n_turns(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -252,6 +337,12 @@ def test_process_with_n_turns(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -261,6 +352,12 @@ def test_process_with_n_turns(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -281,6 +378,12 @@ def test_process_with_end_with_user(self) -> None: {"role": "user", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12], + "output_tokens": [12, 12, 12], + } + }, }, { "conversation": [ @@ -289,6 +392,12 @@ def test_process_with_end_with_user(self) -> None: {"role": "user", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12], + "output_tokens": [12, 12, 12], + } + }, }, { "conversation": [ @@ -297,6 +406,12 @@ def test_process_with_end_with_user(self) -> None: {"role": "user", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12], + "output_tokens": [12, 12, 12], + } + }, }, ] @@ -319,6 +434,12 @@ def test_process_with_include_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -329,6 +450,12 @@ def test_process_with_include_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -339,6 +466,12 @@ def test_process_with_include_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -370,6 +503,12 @@ def test_process_with_system_prompt_per_row(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "system_prompt": "You're a florist expert assistant.", @@ -381,6 +520,12 @@ def test_process_with_system_prompt_per_row(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "system_prompt": "You're a plumber expert assistant.", @@ -392,6 +537,12 @@ def test_process_with_system_prompt_per_row(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -420,18 +571,36 @@ def test_process_with_system_prompt_and_probabilities(self) -> None: "response": "Hello Magpie", "system_prompt_key": "system_prompt_1", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "system_prompt_key": "system_prompt_2", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "system_prompt_key": "system_prompt_1", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, ] @@ -447,14 +616,32 @@ def test_process_only_instruction(self) -> None: { "instruction": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, }, { "instruction": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, }, { "instruction": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, }, ] diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index a535081e65..fc6f9a2f7c 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -19,7 +19,6 @@ from distilabel.models.llms.huggingface.transformers import TransformersLLM from distilabel.steps.tasks.structured_outputs.outlines import ( - # StructuredOutputType, model_to_schema, ) from distilabel.steps.tasks.typing import OutlinesStructuredOutputType @@ -138,8 +137,8 @@ def test_generation( ] result = llm.generate(prompt, max_new_tokens=30) assert isinstance(result, list) - assert isinstance(result[0], list) - assert isinstance(result[0][0], str) + assert isinstance(result[0], dict) + assert "generations" in result[0] and "statistics" in result[0] @pytest.mark.parametrize( "format, schema, dump", diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index 29341052fb..ab48a79b09 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -86,6 +86,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: ): Task(name="task", llm=DummyAsyncLLM()) # type: ignore + @pytest.mark.parametrize( + "n_generations_supported", + [True, False], + ) @pytest.mark.parametrize( "input, group_generations, expected", [ @@ -109,6 +113,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -123,34 +131,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], - }, - }, - { - "instruction": "test_0", - "additional_info": "additional_info_0", - "output": "output", - "info_from_input": "additional_info_0", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_0", "role": "user"}, - ], - }, - }, - { - "instruction": "test_1", - "additional_info": "additional_info_1", - "output": "output", - "info_from_input": "additional_info_1", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_1", "role": "user"}, - ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -165,6 +149,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -179,6 +167,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -193,6 +185,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -207,20 +203,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], - }, - }, - { - "instruction": "test_2", - "additional_info": "additional_info_2", - "output": "output", - "info_from_input": "additional_info_2", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_2", "role": "user"}, - ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, ], @@ -236,11 +222,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: { "instruction": "test_0", "additional_info": "additional_info_0", - "output": ["output", "output", "output"], + "output": ["output", "output"], "info_from_input": [ "additional_info_0", "additional_info_0", - "additional_info_0", ], "model_name": "test", "distilabel_metadata": [ @@ -256,6 +241,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, { "raw_output_task": "output", @@ -269,33 +258,20 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_0", - "role": "user", - }, - ], - }, - # {"raw_output_task": "output"}, - # {"raw_output_task": "output"}, - # {"raw_output_task": "output"}, ], }, { "instruction": "test_1", "additional_info": "additional_info_1", - "output": ["output", "output", "output"], + "output": ["output", "output"], "info_from_input": [ "additional_info_1", "additional_info_1", - "additional_info_1", ], "model_name": "test", "distilabel_metadata": [ @@ -311,6 +287,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, { "raw_output_task": "output", @@ -324,30 +304,20 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_1", - "role": "user", - }, - ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, ], }, { "instruction": "test_2", "additional_info": "additional_info_2", - "output": ["output", "output", "output"], + "output": ["output", "output"], "info_from_input": [ "additional_info_2", "additional_info_2", - "additional_info_2", ], "model_name": "test", "distilabel_metadata": [ @@ -363,6 +333,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, { "raw_output_task": "output", @@ -376,19 +350,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_2", - "role": "user", - }, - ], + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, ], }, @@ -401,16 +366,19 @@ def test_process( input: List[Dict[str, str]], group_generations: bool, expected: List[Dict[str, Any]], + n_generations_supported: bool, ) -> None: pipeline = Pipeline(name="unit-test-pipeline") - llm = DummyAsyncLLM() + llm = DummyAsyncLLM(n_generations_supported=n_generations_supported) + llm.load() task = DummyTask( name="task", llm=llm, pipeline=pipeline, group_generations=group_generations, - num_generations=3, + num_generations=2, ) + task.load() result = next(task.process(input)) assert result == expected @@ -423,6 +391,7 @@ def test_process_overriding_inputs(self) -> None: num_generations=3, input_mappings={"instruction": "instruction_2"}, ) + task.load() result = next( task.process_applying_mappings( @@ -435,7 +404,6 @@ def test_process_overriding_inputs(self) -> None: ] ) ) - assert result == [ { "additional_info": "info", @@ -451,6 +419,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -472,6 +441,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -493,6 +463,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -560,131 +531,128 @@ def test_serialization(self) -> None: pipeline = Pipeline(name="unit-test-pipeline") llm = DummyAsyncLLM() task = DummyTask(name="task", llm=llm, pipeline=pipeline) - assert task.dump() == { - "name": "task", - "add_raw_output": True, - "add_raw_input": True, - "input_mappings": {}, - "output_mappings": {}, - "resources": { - "cpus": None, - "gpus": None, - "memory": None, - "replicas": 1, - "resources": None, - }, - "input_batch_size": 50, - "llm": { - "generation_kwargs": {}, - "structured_output": None, - "jobs_ids": None, - "offline_batch_generation_block_until_done": None, - "use_offline_batch_generation": False, - "type_info": { - "module": "tests.unit.conftest", - "name": "DummyAsyncLLM", - }, - }, - "group_generations": False, - "num_generations": 1, - "runtime_parameters_info": [ - { - "name": "resources", - "runtime_parameters_info": [ - { - "description": "The number of replicas for the step.", - "name": "replicas", - "optional": True, - }, - { - "description": "The number of CPUs assigned to each step replica.", - "name": "cpus", - "optional": True, - }, - { - "description": "The number of GPUs assigned to each step replica.", - "name": "gpus", - "optional": True, - }, - { - "description": "The memory in bytes required for each step replica.", - "name": "memory", - "optional": True, - }, - { - "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", - "name": "resources", - "optional": True, - }, - ], - }, - { - "description": "The number of rows that will contain the batches processed by the step.", - "name": "input_batch_size", - "optional": True, - }, - { - "name": "llm", - "runtime_parameters_info": [ - { - "description": "The kwargs to be propagated to either `generate` or " - "`agenerate` methods within each `LLM`.", - "keys": [], - "name": "generation_kwargs", - }, - { - "description": "Whether to use the `offline_batch_generate` method to " - "generate the responses.", - "name": "use_offline_batch_generation", - "optional": True, - }, - { - "description": "If provided, then polling will be done until the " - "`ofline_batch_generate` method is able to retrieve the " - "results. The value indicate the time to wait between each " - "polling.", - "name": "offline_batch_generation_block_until_done", - "optional": True, - }, - ], - }, - { - "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", - "name": "add_raw_output", - "optional": True, + assert ( + task.dump() + == { + "name": "task", + "add_raw_output": True, + "add_raw_input": True, + "input_mappings": {}, + "output_mappings": {}, + "resources": { + "cpus": None, + "gpus": None, + "memory": None, + "replicas": 1, + "resources": None, }, - { - "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column", - "name": "add_raw_input", - "optional": True, + "input_batch_size": 50, + "llm": { + "generation_kwargs": {}, + "structured_output": None, + "n_generations_supported": True, # Just a trick during testing, it won't appear otherwise + "jobs_ids": None, + "offline_batch_generation_block_until_done": None, + "use_offline_batch_generation": False, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyAsyncLLM", + }, }, - { - "name": "num_generations", - "description": "The number of generations to be produced per input.", - "optional": True, + "group_generations": False, + "num_generations": 1, + "runtime_parameters_info": [ + { + "name": "resources", + "runtime_parameters_info": [ + { + "description": "The number of replicas for the step.", + "name": "replicas", + "optional": True, + }, + { + "description": "The number of CPUs assigned to each step replica.", + "name": "cpus", + "optional": True, + }, + { + "description": "The number of GPUs assigned to each step replica.", + "name": "gpus", + "optional": True, + }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, + ], + }, + { + "description": "The number of rows that will contain the batches processed by the step.", + "name": "input_batch_size", + "optional": True, + }, + { + "name": "llm", + "runtime_parameters_info": [ + { + "description": "The kwargs to be propagated to either `generate` or " + "`agenerate` methods within each `LLM`.", + "keys": [], + "name": "generation_kwargs", + }, + { + "description": "Whether to use the `offline_batch_generate` method to " + "generate the responses.", + "name": "use_offline_batch_generation", + "optional": True, + }, + { + "description": "If provided, then polling will be done until the " + "`ofline_batch_generate` method is able to retrieve the " + "results. The value indicate the time to wait between each " + "polling.", + "name": "offline_batch_generation_block_until_done", + "optional": True, + }, + ], + }, + { + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + "name": "add_raw_output", + "optional": True, + }, + { + "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column", + "name": "add_raw_input", + "optional": True, + }, + { + "name": "num_generations", + "description": "The number of generations to be produced per input.", + "optional": True, + }, + ], + "use_cache": True, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyTask", }, - ], - "use_cache": True, - "type_info": { - "module": "tests.unit.conftest", - "name": "DummyTask", - }, - "use_default_structured_output": False, - } + "use_default_structured_output": False, + } + ) with Pipeline(name="unit-test-pipeline") as pipeline: new_task = DummyTask.from_dict(task.dump()) assert isinstance(new_task, DummyTask) - @pytest.mark.parametrize( - "add_raw_output, add_raw_input", - [ - (True, False), - (False, True), - (True, True), - (False, False), - ], - ) + @pytest.mark.parametrize("add_raw_output", [True, False]) + @pytest.mark.parametrize("add_raw_input", [True, False]) def test_add_raw_input_and_or_output( self, add_raw_output: bool, add_raw_input: bool ) -> None: @@ -707,7 +675,6 @@ def test_add_raw_input_and_or_output( pprint.pprint(result) if add_raw_output or add_raw_input: - assert "distilabel_metadata" in result[0].keys() if add_raw_output: assert ( "raw_output_dummy_task_0" in result[0]["distilabel_metadata"].keys() @@ -716,5 +683,4 @@ def test_add_raw_input_and_or_output( assert ( "raw_input_dummy_task_0" in result[0]["distilabel_metadata"].keys() ) - else: - assert "distilabel_metadata" not in result[0].keys() + assert "statistics_dummy_task_0" in result[0]["distilabel_metadata"].keys() diff --git a/tests/unit/steps/tasks/test_decorator.py b/tests/unit/steps/tasks/test_decorator.py index 085153c1f8..2280779799 100644 --- a/tests/unit/steps/tasks/test_decorator.py +++ b/tests/unit/steps/tasks/test_decorator.py @@ -181,7 +181,7 @@ def MyTask( { "task": "summarize", "instruction": "The cell...", - "response": "output", + "response": "output 0", "model_name": "test", "distilabel_metadata": { "raw_input_my_task_0": [ @@ -194,7 +194,11 @@ def MyTask( "role": "user", }, ], - "raw_output_my_task_0": "output", + "raw_output_my_task_0": "output 0", + "statistics_my_task_0": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ] diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py index 0a153034e9..1bc4128c7c 100644 --- a/tests/unit/steps/tasks/test_improving_text_embeddings.py +++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py @@ -45,7 +45,15 @@ def model_name(self) -> str: def generate( # type: ignore self, inputs: List[ChatType], num_generations: int = 1 ) -> List[GenerateOutput]: - return [[self.output] for _ in range(num_generations)] + return [ + { + "generations": [self.output for _ in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) class TestEmbeddingTaskGenerator: @@ -74,13 +82,54 @@ def test_process(self, category: str, flatten_tasks: bool) -> None: assert task.outputs == ["tasks" if not flatten_tasks else "task", "model_name"] result = ( - ([{"tasks": ["A", "B", "C"], "model_name": "test"}], True) + ( + [ + { + "tasks": ["A", "B", "C"], + "model_name": "test", + "distilabel_metadata": { + "statistics_embedding_task_generator": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } + ], + True, + ) if not flatten_tasks else ( [ - {"task": "A", "model_name": "test"}, - {"task": "B", "model_name": "test"}, - {"task": "C", "model_name": "test"}, + { + "task": "A", + "model_name": "test", + "distilabel_metadata": { + "statistics_embedding_task_generator": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + }, + { + "task": "B", + "model_name": "test", + "distilabel_metadata": { + "statistics_embedding_task_generator": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + }, + { + "task": "C", + "model_name": "test", + "distilabel_metadata": { + "statistics_embedding_task_generator": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + }, ], True, ) @@ -131,7 +180,20 @@ def test_process(self) -> None: assert task.outputs == ["S1", "S2", "S3", "model_name"] assert next(task.process()) == ( - [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + [ + { + "S1": "A", + "S2": "B", + "S3": "C", + "model_name": "test", + "distilabel_metadata": { + "statistics_bitext_retrieval_generator": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } + ], True, ) @@ -192,7 +254,20 @@ def test_process(self) -> None: task.load() assert task.outputs == ["S1", "S2", "S3", "model_name"] assert next(task.process()) == ( - [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + [ + { + "S1": "A", + "S2": "B", + "S3": "C", + "model_name": "test", + "distilabel_metadata": { + "statistics_monolingual_triplet_generator": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } + ], True, ) @@ -241,7 +316,18 @@ def test_process(self) -> None: assert task.outputs == ["input", "positive_document", "model_name"] assert next(task.process(inputs=[{"task": "A"}])) == [ - {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + { + "task": "A", + "input": "A", + "positive_document": "B", + "model_name": "test", + "distilabel_metadata": { + "statistics_generate_long_text_matching_data": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } ] @@ -271,7 +357,18 @@ def test_process(self) -> None: task.load() assert task.outputs == ["input", "positive_document", "model_name"] assert next(task.process(inputs=[{"task": "A"}])) == [ - {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + { + "task": "A", + "input": "A", + "positive_document": "B", + "model_name": "test", + "distilabel_metadata": { + "statistics_generate_short_text_matching_data": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } ] def test_reproducibility(self) -> None: @@ -333,6 +430,12 @@ def test_process(self) -> None: "label": "B", "misleading_label": "C", "model_name": "test", + "distilabel_metadata": { + "statistics_generate_text_classification_data": { + "input_tokens": 12, + "output_tokens": 12, + } + }, } ] @@ -410,5 +513,11 @@ def test_process(self) -> None: "positive_document": "B", "hard_negative_document": "C", "model_name": "test", + "distilabel_metadata": { + "statistics_generate_text_retrieval_data": { + "input_tokens": 12, + "output_tokens": 12, + } + }, } ] diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py index 405195ef02..5e54d94658 100644 --- a/tests/unit/steps/tasks/test_instruction_backtranslation.py +++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py @@ -35,7 +35,15 @@ def generate( self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any ) -> List[GenerateOutput]: return [ - ["This is the reason. Score: 1" for _ in range(num_generations)] + { + "generations": [ + "This is the reason. Score: 1" for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } for _ in inputs ] @@ -88,7 +96,11 @@ def test_process(self) -> None: "reason": "This is the reason.", "model_name": "instruction-backtranslation-model", "distilabel_metadata": { - "raw_output_instruction-backtranslation": "This is the reason. Score: 1" + "raw_output_instruction-backtranslation": "This is the reason. Score: 1", + "statistics_instruction-backtranslation": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ] diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py index 82b86ee93d..125b26ed37 100644 --- a/tests/unit/steps/tasks/test_structured_generation.py +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -37,7 +37,15 @@ def generate( # type: ignore self, inputs: List["StructuredInput"], num_generations: int = 1, **kwargs: Any ) -> List["GenerateOutput"]: return [ - [json.dumps({"test": "output"}) for _ in range(num_generations)] + { + "generations": [ + json.dumps({"test": "output"}) for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } for _ in inputs ] @@ -123,6 +131,9 @@ def test_process(self) -> None: }, "generation": '{"test": "output"}', "model_name": "test", - "distilabel_metadata": {"raw_output_task": '{"test": "output"}'}, + "distilabel_metadata": { + "raw_output_task": '{"test": "output"}', + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, + }, } ] diff --git a/tests/unit/steps/tasks/test_text_classification.py b/tests/unit/steps/tasks/test_text_classification.py index d9c36f58a5..c1bcf47e24 100644 --- a/tests/unit/steps/tasks/test_text_classification.py +++ b/tests/unit/steps/tasks/test_text_classification.py @@ -32,11 +32,18 @@ async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": if self.n == 1: - return [json.dumps({"labels": "label"}) for _ in range(num_generations)] - return [ - json.dumps({"labels": [f"label_{i}" for i in range(self.n)]}) - for _ in range(num_generations) - ] + labels = "label" + else: + labels = ["label_0", "label_1", "label_2"] + return { + "generations": [ + json.dumps({"labels": labels}) for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } class TestTextClassification: diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index 2a6abefb22..ad9b690430 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -103,6 +103,7 @@ def test_process(self) -> None: "model_name": "test", "distilabel_metadata": { "raw_output_task": "output", + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, } ] @@ -230,6 +231,9 @@ def test_process(self) -> None: "messages": [{"role": "user", "content": "Tell me a joke."}], "generation": "output", "model_name": "test", - "distilabel_metadata": {"raw_output_task": "output"}, + "distilabel_metadata": { + "raw_output_task": "output", + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, + }, } ] diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 46ed061838..3754c8803d 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -36,12 +36,17 @@ def generate( self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any ) -> List[GenerateOutput]: return [ - [ - "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" - for _ in range(num_generations) - ] - for _ in inputs - ] + { + "generations": [ + "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + for i in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) class TestUltraFeedback: @@ -65,7 +70,11 @@ def test_process_with_simple_aspect(self) -> None: "rationales": ["text", "text"], "model_name": "ultrafeedback-model", "distilabel_metadata": { - "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text", + "statistics_ultrafeedback": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ] @@ -92,7 +101,11 @@ def test_process_with_complex_aspect(self) -> None: "rationales-for-ratings": ["text", "text"], "model_name": "ultrafeedback-model", "distilabel_metadata": { - "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text", + "statistics_ultrafeedback": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ]