From ee2b9fda6f2a8736cd2330b79c8a3807211f7f46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 19 Dec 2023 11:08:55 +0100 Subject: [PATCH 1/6] Update `LLM` with thread pool to return `Future[List[List[LLMOutput]]]` (#164) * Update `LLM` to return `Future[List[List[LLMOutput]]]` * Update example to read API key from env --- examples/pipeline-pool-llm.py | 2 +- src/distilabel/llm/base.py | 60 +++++----------- src/distilabel/pipeline.py | 119 ++++++++++++++------------------ src/distilabel/utils/futures.py | 48 +++++++++++++ src/distilabel/utils/types.py | 17 ++--- 5 files changed, 122 insertions(+), 124 deletions(-) create mode 100644 src/distilabel/utils/futures.py diff --git a/examples/pipeline-pool-llm.py b/examples/pipeline-pool-llm.py index 781dcff86a..868acf0cc8 100644 --- a/examples/pipeline-pool-llm.py +++ b/examples/pipeline-pool-llm.py @@ -39,7 +39,7 @@ def load_openai(task): return OpenAILLM( model="gpt-3.5-turbo", task=task, - openai_api_key="", + openai_api_key=os.getenv("OPENAI_API_KEY"), max_new_tokens=512, ) diff --git a/src/distilabel/llm/base.py b/src/distilabel/llm/base.py index 7a41c6ba9e..3be4f95db8 100644 --- a/src/distilabel/llm/base.py +++ b/src/distilabel/llm/base.py @@ -21,7 +21,6 @@ from abc import ABC, abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from ctypes import c_char -from enum import Enum, auto from functools import cached_property from threading import Thread from typing import ( @@ -37,7 +36,7 @@ from distilabel.logger import get_logger from distilabel.tasks.prompt import Prompt -from distilabel.utils.types import is_list_of_futures +from distilabel.utils.futures import when_all_complete if TYPE_CHECKING: from distilabel.llm.utils import LLMOutput @@ -47,22 +46,6 @@ logger = get_logger() -class LLMFutures(Enum): - """An enum used to indicate whether the `LLM` returns futures or not, and if so - what the futures contains. - - Attributes: - CONTAINS_ROWS: used to indicate that the `LLM` will return `Future`s that - contains a `List[List[LLMOutput]]`. The first list just contains 1 element - (the row) and the second list contains the generations for that row. - CONTAINS_BATCHES: used to indicate that the `LLM` will return `Future`s that - contains a `List[List[LLMOutput]]`. - """ - - CONTAINS_ROWS = auto() - CONTAINS_BATCHES = auto() - - class LLM(ABC): def __init__( self, @@ -198,7 +181,7 @@ def generate( inputs: List[Dict[str, Any]], num_generations: int = 1, progress_callback_func: Union[Callable, None] = None, - ) -> Union[List[Future[List["LLMOutput"]]], List[List["LLMOutput"]]]: + ) -> Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]: """Generates the outputs for the given inputs using the LLM. 
Args: @@ -214,12 +197,7 @@ def generate( def _progress(): if progress_callback_func is not None: - advance = ( - num_generations * len(inputs) - if self.return_futures is not None - else num_generations - ) - progress_callback_func(advance=advance) + progress_callback_func(advance=num_generations * len(inputs)) if self.thread_pool_executor is not None: futures = [] @@ -227,19 +205,19 @@ def _progress(): future = self.thread_pool_executor.submit( self._generate, [input], num_generations ) - future.add_done_callback(lambda _: _progress()) futures.append(future) - return futures + future = when_all_complete(futures) + future.add_done_callback(lambda _: _progress()) + return future generations = self._generate(inputs, num_generations) _progress() return generations @property - def return_futures(self) -> Union[LLMFutures, None]: - """Returns whether the LLM returns futures or not, and if so what the futures - contains.""" - return LLMFutures.CONTAINS_ROWS + def return_futures(self) -> bool: + """Whether the `LLM` returns futures""" + return True MAX_MODEL_NAME_LENGTH = 256 @@ -339,13 +317,7 @@ def run(self) -> None: inputs=request.inputs, num_generations=request.num_generations ) - # Generations are a list of `Future`s because the `LLM` is using a thread pool - if is_list_of_futures(results): - generations = [] - for future in results: - generations.extend(future.result()) - else: - generations = results + generations = results.result() if isinstance(results, Future) else results self._result_queue.put(_TextGenerationResult(generations)) @@ -591,10 +563,9 @@ def model_name(self) -> str: return "".join([c.decode() for c in self._model_name if c != b"\0"]) @property - def return_futures(self) -> Union[LLMFutures, None]: - """Returns whether the LLM returns futures or not, and if so what the futures - contains.""" - return LLMFutures.CONTAINS_BATCHES + def return_futures(self) -> bool: + """Whether the `LLM` returns futures""" + return True class LLMPool: @@ -719,5 +690,6 @@ def task(self) -> "Task": return self.llms[0].task @property - def return_futures(self) -> Union[LLMFutures, None]: - return None + def return_futures(self) -> bool: + """Whether the `LLM` returns futures""" + return False diff --git a/src/distilabel/pipeline.py b/src/distilabel/pipeline.py index 78e88d74a1..5cccb8b5e7 100644 --- a/src/distilabel/pipeline.py +++ b/src/distilabel/pipeline.py @@ -33,7 +33,7 @@ from datasets import Dataset, Split from distilabel.dataset import CustomDataset -from distilabel.llm.base import LLM, LLMFutures, LLMPool, ProcessLLM +from distilabel.llm.base import LLM, LLMPool, ProcessLLM from distilabel.llm.utils import LLMOutput from distilabel.logger import get_logger from distilabel.progress_bar import ( @@ -42,7 +42,7 @@ use_progress_bar, ) from distilabel.utils.dicts import combine_dicts -from distilabel.utils.types import is_list_of_futures +from distilabel.utils.types import is_future logger = get_logger() @@ -225,27 +225,6 @@ def _get_batch_generations( for _ in range(num_batches) ] ) - elif is_list_of_futures(outputs): - for future in outputs: - # Result of future is `List[List[LLMOutput]]` (first list contains 1 - # element, and the second list contains `num_generations` elements) - try: - batch_generations.extend(future.result()) - except Exception as e: - logger.error( - f"An error ocurred when getting the result of a future from the generator: {e}" - ) - batch_generations.extend( - [ - LLMOutput( - model_name=self.generator.model_name, - prompt_used=None, - raw_output=None, - 
parsed_output=None, - ) - for _ in range(num_generations) - ] - ) else: batch_generations = outputs return self._process_batch_generations(batch_generations=batch_generations) @@ -254,11 +233,7 @@ def _get_batch_labels( self, inputs: List[Dict[str, Any]], progress_callback_func: Union[Callable, None] = None, - ) -> Union[ - Future[List[List["LLMOutput"]]], - List[Future[List[List["LLMOutput"]]]], - List[List["LLMOutput"]], - ]: + ) -> Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]: """Gets the batch labels for the given inputs. Args: @@ -269,7 +244,7 @@ def _get_batch_labels( to `None`. Returns: - Union[Future[List["LLMOutput"]], List[Future], List["LLMOutput"]]: the batch + Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]: the batch labels. """ @@ -430,9 +405,8 @@ def _build_dataset( # noqa: C901 dataset: Dataset, generations: List[Dict[str, Any]], labels: Union[ - Future[List[List["LLMOutput"]]], - List[Future[List[List["LLMOutput"]]]], List[List["LLMOutput"]], + Future[List[List["LLMOutput"]]], ], batch_size: int, ) -> CustomDataset: @@ -442,8 +416,8 @@ def _build_dataset( # noqa: C901 Args: dataset (Dataset): the original dataset. generations (List[Dict[str, Any]]): the processed generations. - labels (Union[List[Future["LLMOutput"]], List["LLMOutput"]]): the processed - labels. + labels (Union[List[List[LLMOutput]], Future[List[List[LLMOutput]]]]): the + processed labels. Returns: CustomDataset: the final dataset. @@ -478,7 +452,7 @@ def _build_dataset( # noqa: C901 processed_labels = [{} for _ in range(len(dataset))] # type: ignore else: batch_labels = [] - if self.labeller.return_futures is not None: + if self.labeller.return_futures: for i, future in enumerate(labels, start=1): # type: ignore try: batch_labels.extend(future.result()) @@ -486,16 +460,8 @@ def _build_dataset( # noqa: C901 logger.error( f"An error occurred when getting the result from the labeller: {e}" ) - # If the LLM returned a list of futures (`LLM` with thread pool - # executor), and each future contains just the result for a single - # row, then we need to create an empty LLMOutput for each future. - # If the LLM returns a future containing `batch_size` rows in the - # result (`ProcessLLM`), then we need to create a list of empty - # LLMOutputs with the length of the `batch_size` num_outputs = ( - 1 - if self.labeller.return_futures == LLMFutures.CONTAINS_ROWS - else batch_size + batch_size if i * batch_size <= len(dataset) else len(dataset) % batch_size ) @@ -624,9 +590,8 @@ def _generate( # noqa: C901 generations: List[Dict[str, Any]] = [] labels: Union[ - Future[List[List["LLMOutput"]]], - List[Future[List[List["LLMOutput"]]]], List[List["LLMOutput"]], + Future[List[List["LLMOutput"]]], ] = [] ( @@ -677,8 +642,8 @@ def _generate( # noqa: C901 inputs=inputs, progress_callback_func=labelling_progress_func ) - if isinstance(batch_labels, Future): - labels.append(batch_labels) + if is_future(batch_labels): + labels.append(batch_labels) # type: ignore else: labels.extend(batch_labels) # type: ignore except Exception as e: @@ -702,6 +667,38 @@ def _generate( # noqa: C901 dataset, generations=generations, labels=labels, batch_size=batch_size ) + def dry_run(self, dataset: Dataset) -> CustomDataset: + """Performs a dry run over the provided dataset, which consists on generating the + outputs for the first row of the dataset, to ensure that the `Pipeline` will be + able to generate the outputs for the whole dataset. + + Args: + dataset (Dataset): the dataset to be used for generation. 
Just the first row + will be used for the dry run. + + Returns: + CustomDataset: the dataset containing the outputs for the first row. + """ + try: + # First we generate a `Dataset` only with the first row from the whole dataset + subset = Dataset.from_dict( + {key: [value] for key, value in dataset[0].items()} + ) + # Then we call the `_generate` method with it + return self._generate( + dataset=subset, + # Default kwargs to make the process as simple as possible + num_generations=1, + batch_size=1, + enable_checkpoints=False, + display_progress_bar=False, + ) + except Exception as e: + self._teardown() + raise RuntimeError( + f"`Pipeline.generate` failed during the dry run over {dataset[0]} with exception: {e}" + ) from e + def generate( self, dataset: Dataset, @@ -709,6 +706,7 @@ def generate( batch_size: int = 1, enable_checkpoints: bool = True, display_progress_bar: bool = False, + skip_dry_run: bool = False, ) -> CustomDataset: """Generates the outputs for the given dataset using the LLMs provided to the `Pipeline`. @@ -719,6 +717,7 @@ def generate( batch_size (int, optional): the batch size to be used for generation. Defaults to `1`. enable_checkpoints (bool, optional): whether to enable checkpoints or not. Defaults to `True`. display_progress_bar (bool, optional): whether to display the progress bar or not. Defaults to `False`. + skip_dry_run (bool, optional): whether to skip the dry run or not. Defaults to `False`. Returns: CustomDataset: the final dataset. @@ -747,30 +746,12 @@ def generate( >>> pipeline = Pipeline(generator=generator, labeller=labeller) >>> dataset = pipeline.generate(dataset=..., num_generations=1, batch_size=1) """ - try: + if not skip_dry_run: logger.info("Executing dry-run...") - # First we generate a `Dataset` only with the first row from the whole dataset - subset = Dataset.from_dict( - {key: [value] for key, value in dataset[0].items()} - ) - # Then we call the `_generate` method with it - _ = self._generate( - dataset=subset, - # Default kwargs to make the process as simple as possible - num_generations=1, - batch_size=1, - enable_checkpoints=False, - display_progress_bar=False, + self.dry_run(dataset) + logger.info( + "Dry-run executed with no issues. Starting the actual generation..." ) - except Exception as e: - self._teardown() - raise RuntimeError( - f"`Pipeline.generate` failed during the dry run over {dataset[0]} with exception: {e}" - ) from e - - logger.info( - "Dry-run executed with no issues. Starting the actual generation..." - ) dataset = use_progress_bar(self._generate)( dataset=dataset, diff --git a/src/distilabel/utils/futures.py b/src/distilabel/utils/futures.py new file mode 100644 index 0000000000..02c1e7fbc7 --- /dev/null +++ b/src/distilabel/utils/futures.py @@ -0,0 +1,48 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+from concurrent.futures import Future, wait
+from typing import List
+
+from typing_extensions import TypeVar
+
+T = TypeVar("T")
+
+
+def when_all_complete(futures: List[Future[T]]) -> Future[T]:
+    """Returns a `Future` that will be completed when all the provided `futures` are
+    completed, and it will contain the results of the `futures`.
+
+    Args:
+        futures (List[Future]): the `Future`s to wait for.
+
+    Returns:
+        Future: the `Future` that will be completed when all the provided `futures` are
+            completed, and it will contain the results of the `futures`.
+    """
+    all_done_future = Future()
+    results = []
+
+    def check_all_done(future: Future) -> None:
+        results.extend(future.result())
+        _, not_done = wait(futures, return_when="FIRST_COMPLETED")
+        if len(not_done) == 0:
+            all_done_future.set_result(results)
+
+    for future in futures:
+        future.add_done_callback(check_all_done)
+
+    return all_done_future
diff --git a/src/distilabel/utils/types.py b/src/distilabel/utils/types.py
index bf1753bc58..867517658a 100644
--- a/src/distilabel/utils/types.py
+++ b/src/distilabel/utils/types.py
@@ -15,23 +15,20 @@
 from __future__ import annotations
 
 from concurrent.futures import Future
-from typing import List, Union
+from typing import Any, Union
 
 from typing_extensions import TypeGuard, TypeVar
 
-T = TypeVar("FutureResult")  # type: ignore
+T = TypeVar("T")
 
 
-def is_list_of_futures(
-    results: Union[List[Future[T]], List[List[T]]],
-) -> TypeGuard[List[Future[T]]]:
-    """Check if results is a list of futures. This function narrows the type of
-    `results` to `List[Future[T]]` if it is a list of futures.
+def is_future(obj: Union[Future[T], Any]) -> TypeGuard[Future[T]]:
+    """Checks if an object is a future narrowing the type.
 
     Args:
-        results: A list of futures.
+        obj (Future[T]): Object to check
 
     Returns:
-        `True` if `results` is a list of futures, `False` otherwise.
+        TypeGuard[Future[T]]: True if it is a future
     """
-    return isinstance(results, list) and isinstance(results[0], Future)
+    return isinstance(obj, Future)

From 4047d69e1a17ff9e53a16ba119deb399a4a78f93 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome
Date: Tue, 19 Dec 2023 11:17:15 +0100
Subject: [PATCH 2/6] Add `PrometheusTask` (#165)

* Rename and move `UltraCMOutput` to `CritiqueTaskOutput`
* Add `PrometheusTask` and `prometheus.jinja2` template
* Fix `parse_output` in `PrometheusTask`
* Add `examples/pipeline-prometheus.py`
---
 examples/pipeline-prometheus.py              | 63 +++++++++++++++++
 .../tasks/_templates/prometheus.jinja2       | 25 +++++++
 src/distilabel/tasks/critique/base.py        |  9 +++
 src/distilabel/tasks/critique/prometheus.py  | 67 +++++++++++++++++++
 src/distilabel/tasks/critique/ultracm.py     | 15 ++---
 5 files changed, 168 insertions(+), 11 deletions(-)
 create mode 100644 examples/pipeline-prometheus.py
 create mode 100644 src/distilabel/tasks/_templates/prometheus.jinja2
 create mode 100644 src/distilabel/tasks/critique/prometheus.py

diff --git a/examples/pipeline-prometheus.py b/examples/pipeline-prometheus.py
new file mode 100644
index 0000000000..0a886f61ad
--- /dev/null
+++ b/examples/pipeline-prometheus.py
@@ -0,0 +1,63 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+from datasets import Dataset
+from distilabel.llm import TransformersLLM
+from distilabel.pipeline import Pipeline
+from distilabel.tasks.critique.prometheus import PrometheusTask
+from transformers import AutoTokenizer, LlamaForCausalLM
+
+if __name__ == "__main__":
+    model = LlamaForCausalLM.from_pretrained(
+        "kaist-ai/Prometheus-7b-v1.0", torch_dtype=torch.float16, device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        "meta-llama/Llama-2-7b-chat-hf", token=os.getenv("HF_TOKEN")
+    )
+    pipeline = Pipeline(
+        labeller=TransformersLLM(
+            model=model,  # type: ignore
+            tokenizer=tokenizer,
+            task=PrometheusTask(
+                scoring_criteria="Is the provided completion accurate based on the given instruction?",
+                score_descriptions={
+                    0: "Totally off-topic and inaccurate",
+                    1: "Incorrect and inaccurate",
+                    2: "Almost correct, but partially inaccurate",
+                    3: "Correct but badly phrased",
+                    4: "Correct and accurate",
+                },
+            ),
+            temperature=1.0,
+            top_p=1.0,
+            max_new_tokens=512,
+        ),
+    )
+
+    dataset = Dataset.from_dict(
+        {
+            "instruction": ["What's the capital of Spain?"],
+            "completion": ["Paris"],
+            "ref_completion": ["Madrid"],
+        }
+    )
+
+    dataset = pipeline.generate(
+        dataset,  # type: ignore
+        display_progress_bar=True,
+        skip_dry_run=True,
+    )
diff --git a/src/distilabel/tasks/_templates/prometheus.jinja2 b/src/distilabel/tasks/_templates/prometheus.jinja2
new file mode 100644
index 0000000000..f3f1ef80da
--- /dev/null
+++ b/src/distilabel/tasks/_templates/prometheus.jinja2
@@ -0,0 +1,25 @@
+###Task Description:
+An instruction (might include an Input inside it), a response to evaluate, a reference answer that gets a score of 5, and a score rubric representing a evaluation criteria are given.
+1. Write a detailed feedback that assess the quality of the response strictly based on the given score rubric, not evaluating in general.
+2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
+3. The output format should look as follows: \"Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)\"
+4. Please do not generate any other opening, closing, and explanations.
+ +###The instruction to evaluate: +{{ instruction }} + +###Response to evaluate: +{{ completion }} + +###Reference Answer (Score 5): +{{ ref_completion }} + +###Score Rubrics: +[{{ scoring_criteria }}] +Score 1: {{ score_descriptions[0] }} +Score 2: {{ score_descriptions[1] }} +Score 3: {{ score_descriptions[2] }} +Score 4: {{ score_descriptions[3] }} +Score 5: {{ score_descriptions[4] }} + +###Feedback: diff --git a/src/distilabel/tasks/critique/base.py b/src/distilabel/tasks/critique/base.py index 4231482dbd..7caabc73ba 100644 --- a/src/distilabel/tasks/critique/base.py +++ b/src/distilabel/tasks/critique/base.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import List +from typing_extensions import TypedDict + from distilabel.tasks.base import Task @@ -36,3 +38,10 @@ def input_args_names(self) -> List[str]: def output_args_names(self) -> List[str]: """Returns the names of the output arguments of the task.""" return ["critique", "score"] + + +class CritiqueTaskOutput(TypedDict): + """A `TypedDict` matching the output format of any `CritiqueTask`.""" + + score: float + critique: str diff --git a/src/distilabel/tasks/critique/prometheus.py b/src/distilabel/tasks/critique/prometheus.py new file mode 100644 index 0000000000..988c01fb1c --- /dev/null +++ b/src/distilabel/tasks/critique/prometheus.py @@ -0,0 +1,67 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import dataclass +from typing import ClassVar, Dict, List + +from distilabel.tasks.base import get_template +from distilabel.tasks.critique.base import CritiqueTask, CritiqueTaskOutput +from distilabel.tasks.prompt import Prompt + +_PROMETHEUS_TEMPLATE = get_template("prometheus.jinja2") + + +@dataclass +class PrometheusTask(CritiqueTask): + scoring_criteria: str + score_descriptions: Dict[int, str] + + system_prompt: str = "You are a fair evaluator language model." + + __jinja2_template__: ClassVar[str] = _PROMETHEUS_TEMPLATE + + @property + def input_args_names(self) -> List[str]: + return super().input_args_names + ["ref_completion"] + + def generate_prompt( + self, + instruction: str, + completion: str, + ref_completion: str, + ) -> str: + render_kwargs = { + "instruction": instruction, + "completion": completion, + "ref_completion": ref_completion, + "scoring_criteria": self.scoring_criteria, + "score_descriptions": self.score_descriptions, + } + return Prompt( + system_prompt=self.system_prompt, + formatted_prompt=self.template.render(**render_kwargs), + ).format_as(format="llama2") # type: ignore + + def parse_output(self, output: str) -> CritiqueTaskOutput: # type: ignore + """Parses the output of the model into the desired format.""" + # We use a regex instead of splitting by the delimiter because the + # critique may contain the delimiter, and using the regex is safer. + pattern = r"(.+?)\. 
\[RESULT\] (\d+)" + match = re.match(pattern, output) + if match: + return CritiqueTaskOutput( + score=float(match.group(2)), + critique=match.group(1).strip(), + ) diff --git a/src/distilabel/tasks/critique/ultracm.py b/src/distilabel/tasks/critique/ultracm.py index e231f5a36d..331e6de68b 100644 --- a/src/distilabel/tasks/critique/ultracm.py +++ b/src/distilabel/tasks/critique/ultracm.py @@ -14,21 +14,14 @@ import re from dataclasses import dataclass -from typing import ClassVar, TypedDict +from typing import ClassVar from distilabel.tasks.base import get_template -from distilabel.tasks.critique.base import CritiqueTask +from distilabel.tasks.critique.base import CritiqueTask, CritiqueTaskOutput _ULTRACM_TEMPLATE = get_template("ultracm.jinja2") -class UltraCMOutput(TypedDict): - """A `TypedDict` matching the output format of UltraCM.""" - - score: float - critique: str - - @dataclass class UltraCMTask(CritiqueTask): __jinja2_template__: ClassVar[str] = _ULTRACM_TEMPLATE @@ -46,12 +39,12 @@ def generate_prompt(self, instruction: str, completion: str) -> str: } return f"{self.system_prompt}\nUser: {self.template.render(**render_kwargs)}\nAssistant: ### Feedback\nOverall Score: " - def parse_output(self, output: str) -> UltraCMOutput: # type: ignore + def parse_output(self, output: str) -> CritiqueTaskOutput: # type: ignore """Parses the output of the model into the desired format.""" pattern = r"(\d+(?:\.\d+)?)\s*(.*)" match = re.match(pattern, output) if match: - return UltraCMOutput( + return CritiqueTaskOutput( score=float(match.group(1)), critique=match.group(2).strip(), ) From e06453d860765d5299fd9819a270bafa012b6a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 19 Dec 2023 12:55:32 +0100 Subject: [PATCH 3/6] Randomise generations order (#167) * Randomise generations order * Add `shuffle_before_labelling` parameter --- src/distilabel/pipeline.py | 45 ++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/src/distilabel/pipeline.py b/src/distilabel/pipeline.py index 5cccb8b5e7..93d31c1d33 100644 --- a/src/distilabel/pipeline.py +++ b/src/distilabel/pipeline.py @@ -16,6 +16,7 @@ import math import os +import random import warnings from concurrent.futures import Future from typing import ( @@ -180,6 +181,7 @@ def _get_batch_generations( inputs: List[Dict[str, Any]], num_generations: int, num_batches: int, + shuffle_before_labelling: bool = True, progress_callback_func: Union[Callable, None] = None, ) -> List[Dict[str, Any]]: """Gets the batch generations for the given inputs, capturing the futures if the @@ -189,6 +191,10 @@ def _get_batch_generations( inputs (List[Dict[str, Any]]): the inputs to be used for generation. num_generations (int): the number of generations to be performed for each input. + num_batches (int): the number of batches to be processed. + shuffle_before_labelling (bool, optional): whether to shuffle the generations + before labelling or not. This is useful to avoid the labelling LLM to be + biased by the order of the generations. Defaults to `True`. progress_callback_func (Union[Callable, None], optional): the callback function to be called when the progress of the generation process changes. Defaults to None. 
@@ -227,7 +233,10 @@ def _get_batch_generations( ) else: batch_generations = outputs - return self._process_batch_generations(batch_generations=batch_generations) + return self._process_batch_generations( + batch_generations=batch_generations, + shuffle_before_labelling=shuffle_before_labelling, + ) def _get_batch_labels( self, @@ -259,12 +268,16 @@ def _get_batch_labels( def _process_batch_generations( self, batch_generations: List[List["LLMOutput"]], + shuffle_before_labelling: bool = True, ) -> List[Dict[str, Any]]: """Processes the batch generations, combining the outputs of the LLMs into a single dictionary. Args: batch_generations (List[List["LLMOutput"]]): the batch generations to be processed. + shuffle_before_labelling (bool, optional): whether to shuffle the generations + before labelling or not. This is useful to avoid the labelling LLM to be + biased by the order of the generations. Defaults to `True`. Returns: List[Dict[str, Any]]: the processed batch generations. @@ -276,6 +289,8 @@ def _process_batch_generations( "generation_prompt": [], "raw_generation_responses": [], } + if shuffle_before_labelling: + random.shuffle(generations) for generation in generations: processed_generation["generation_model"].append( generation["model_name"] @@ -531,26 +546,34 @@ def _generate( # noqa: C901 dataset: Dataset, num_generations: int = 1, batch_size: int = 1, + shuffle_before_labelling: bool = True, enable_checkpoints: bool = True, display_progress_bar: bool = False, ) -> CustomDataset: - """Generates the outputs for the given dataset using the LLMs provided to the `Pipeline`. + """Generates the outputs for the given dataset using the LLMs provided to the + `Pipeline`. Args: dataset (Dataset): the dataset to be used for generation. - num_generations (int, optional): the number of generations to be performed for each - input. Defaults to `1`. - batch_size (int, optional): the batch size to be used for generation. Defaults to `1`. - enable_checkpoints (bool, optional): whether to enable checkpoints or not. Defaults to `True`. - display_progress_bar (bool, optional): whether to display the progress bar or not. Defaults to `False`. + num_generations (int, optional): the number of generations to be performed + for each input. Defaults to `1`. + batch_size (int, optional): the batch size to be used for generation. Defaults + to `1`. + shuffle_before_labelling (bool, optional): whether to shuffle the generations + before labelling or not. This is useful to avoid the labelling LLM to be + biased by the order of the generations. Defaults to `True`. + enable_checkpoints (bool, optional): whether to enable checkpoints or not. + Defaults to `True`. + display_progress_bar (bool, optional): whether to display the progress bar + or not. Defaults to `False`. Returns: CustomDataset: the final dataset. Raises: RuntimeError: if the `Pipeline` fails during the generation or labelling steps. - UserWarning: if the `Pipeline` fails during the generation or labelling steps and - `enable_checkpoints` is set to `False`. + UserWarning: if the `Pipeline` fails during the generation or labelling steps + and `enable_checkpoints` is set to `False`. 
Examples: >>> from distilabel.llm.huggingface import TransformersLLM @@ -704,6 +727,7 @@ def generate( dataset: Dataset, num_generations: int = 1, batch_size: int = 1, + shuffle_before_labelling: bool = True, enable_checkpoints: bool = True, display_progress_bar: bool = False, skip_dry_run: bool = False, @@ -715,6 +739,9 @@ def generate( num_generations (int, optional): the number of generations to be performed for each input. Defaults to `1`. batch_size (int, optional): the batch size to be used for generation. Defaults to `1`. + shuffle_before_labelling: whether to shuffle the generations before labelling + or not. This is useful to avoid the labelling LLM to be biased by the order + of the generations. Defaults to `True`. enable_checkpoints (bool, optional): whether to enable checkpoints or not. Defaults to `True`. display_progress_bar (bool, optional): whether to display the progress bar or not. Defaults to `False`. skip_dry_run (bool, optional): whether to skip the dry run or not. Defaults to `False`. From 62993ac187a8c7c47ceacff04b7d83ee06f14664 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome Date: Tue, 19 Dec 2023 16:00:57 +0100 Subject: [PATCH 4/6] Add custom `to_argilla_{dataset,record}` to `SelfInstructTask` (#169) * Update `to_argilla_record` return type-hint * Add TODO for future self * Add custom `to_argilla_{dataset,record}` in `SelfInstructTask` * Fix loop over instructions in `SelfInstructTask` * Remove `print` from `parse_output` --- src/distilabel/tasks/base.py | 2 +- src/distilabel/tasks/text_generation/base.py | 3 + .../tasks/text_generation/self_instruct.py | 105 +++++++++++++++++- 3 files changed, 107 insertions(+), 3 deletions(-) diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py index 1c9ec3f5a8..251c25d214 100644 --- a/src/distilabel/tasks/base.py +++ b/src/distilabel/tasks/base.py @@ -113,7 +113,7 @@ def to_argilla_dataset( def to_argilla_record( self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any - ) -> "FeedbackRecord": + ) -> Union["FeedbackRecord", List["FeedbackRecord"]]: raise NotImplementedError( "`to_argilla_record` is not implemented, if you want to export your dataset as an Argilla" " `FeedbackDataset` you will need to implement this method first." diff --git a/src/distilabel/tasks/text_generation/base.py b/src/distilabel/tasks/text_generation/base.py index 6a069773b7..2d4c8464c6 100644 --- a/src/distilabel/tasks/text_generation/base.py +++ b/src/distilabel/tasks/text_generation/base.py @@ -212,6 +212,9 @@ def to_argilla_record(self, dataset_row: Dict[str, Any]) -> "FeedbackRecord": arg_value = dataset_row[arg_name] if isinstance(arg_value, list): for idx, value in enumerate(arg_value, start=1): + # TODO: value formatting was included here due to some issues + # with `SelfInstructTask` but these list-parsing may not be needed + # anymore. value = ( value.strip() if isinstance(value, str) diff --git a/src/distilabel/tasks/text_generation/self_instruct.py b/src/distilabel/tasks/text_generation/self_instruct.py index 349682d59f..f787d46726 100644 --- a/src/distilabel/tasks/text_generation/self_instruct.py +++ b/src/distilabel/tasks/text_generation/self_instruct.py @@ -12,12 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re
+import warnings
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from distilabel.tasks.base import get_template
 from distilabel.tasks.prompt import Prompt
 from distilabel.tasks.text_generation.base import TextGenerationTask
+from distilabel.utils.argilla import infer_fields_from_dataset_row
+from distilabel.utils.imports import _ARGILLA_AVAILABLE
+
+if _ARGILLA_AVAILABLE:
+    import argilla as rg
+
+if TYPE_CHECKING:
+    from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
+    from argilla.client.feedback.schemas.records import FeedbackRecord
 
 _SELF_INSTRUCT_TEMPLATE = get_template("self-instruct.jinja2")
 
@@ -79,6 +90,96 @@ def generate_prompt(self, input: str) -> Prompt:
             formatted_prompt=self.template.render(**render_kwargs),
         )
 
+    @property
+    def output_args_names(self) -> List[str]:
+        return ["instructions"]
+
     def parse_output(self, output: str) -> Dict[str, List[str]]:
         """Parses the output of the model into the desired format."""
-        return {"generations": output.split("\n")}
+        pattern = re.compile(r"\d+\.\s+(.*?)\n")
+        return {"instructions": pattern.findall(output)}
+
+    def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
+        # First we infer the fields from the input_args_names, but we could also
+        # create those manually instead using `rg.TextField(...)`
+        fields = infer_fields_from_dataset_row(
+            field_names=self.input_args_names,
+            dataset_row=dataset_row,
+        )
+        # Once the input fields have been defined, then we also include the instruction
+        # field which will be fulfilled with each of the instructions generated.
+        fields.append(rg.TextField(name="instruction", title="instruction"))  # type: ignore
+        # Then we add a default `RatingQuestion` which asks the users to provide a
+        # rating for each of the generations, differing from the scenario where the inputs
+        # are the fields and the outputs the ones used to formulate the questions. So,
+        # in this scenario we won't have suggestions, as the questions will be related to the
+        # combination of inputs and outputs.
+        questions = [
+            rg.RatingQuestion(  # type: ignore
+                name="instruction-rating",
+                title="How would you rate the generated instruction?",
+                values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+            )
+        ]
+        # Finally, we define some metadata properties that can be potentially used
+        # while exploring the dataset within Argilla to get more insights on the data.
+        metadata_properties = []
+        for arg_name in self.input_args_names:
+            if isinstance(dataset_row[arg_name], list):
+                for idx in range(1, len(dataset_row[arg_name]) + 1):
+                    metadata_properties.append(
+                        rg.IntegerMetadataProperty(name=f"length-{arg_name}-{idx}")  # type: ignore
+                    )
+            elif isinstance(dataset_row[arg_name], str):
+                metadata_properties.append(
+                    rg.IntegerMetadataProperty(name=f"length-{arg_name}")  # type: ignore
+                )
+            else:
+                warnings.warn(
+                    f"Unsupported input type ({type(dataset_row[arg_name])}), skipping...",
+                    UserWarning,
+                    stacklevel=2,
+                )
+        metadata_properties.append(
+            rg.IntegerMetadataProperty(name="length-instruction")  # type: ignore
+        )  # type: ignore
+        # Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
+        # defined above.
+ return rg.FeedbackDataset( + fields=fields, + questions=questions, # type: ignore + metadata_properties=metadata_properties, # Note that these are always optional + ) + + def to_argilla_record( + self, + dataset_row: Dict[str, Any], + instructions_column: Optional[str] = "instructions", + ) -> List["FeedbackRecord"]: + """Converts a dataset row to a list of Argilla `FeedbackRecord`s.""" + records = [] + for instructions in dataset_row[instructions_column]: # type: ignore + for instruction in instructions: + fields, metadata = {}, {} + for arg_name in self.input_args_names: + arg_value = dataset_row[arg_name] + if isinstance(arg_value, list): + for idx, value in enumerate(arg_value, start=1): + value = value.strip() if isinstance(value, str) else "" + fields[f"{arg_name}-{idx}"] = value + if value is not None: + metadata[f"length-{arg_name}-{idx}"] = len(value) + elif isinstance(arg_value, str): + fields[arg_name] = arg_value.strip() if arg_value else "" + if arg_value is not None: + metadata[f"length-{arg_name}"] = len(arg_value.strip()) + else: + warnings.warn( + f"Unsupported input type ({type(arg_value)}), skipping...", + UserWarning, + stacklevel=2, + ) + fields["instruction"] = instruction + metadata["length-instruction"] = len(instruction) + records.append(rg.FeedbackRecord(fields=fields, metadata=metadata)) + return records From f91706c3eafaf41dc48aca26e43dba6c8500a9c0 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome Date: Tue, 19 Dec 2023 17:05:23 +0100 Subject: [PATCH 5/6] Add missing `shuffle_before_labelling` (#170) --- src/distilabel/pipeline.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/distilabel/pipeline.py b/src/distilabel/pipeline.py index 93d31c1d33..4309067807 100644 --- a/src/distilabel/pipeline.py +++ b/src/distilabel/pipeline.py @@ -636,7 +636,11 @@ def _generate( # noqa: C901 logger.info(f"Calling generator for batch {batch_i}...") try: batch_generations = self._get_batch_generations( - inputs, num_generations, num_batches, generation_progress_func + inputs=inputs, + num_generations=num_generations, + num_batches=num_batches, + shuffle_before_labelling=shuffle_before_labelling, + progress_callback_func=generation_progress_func, ) generations.extend(batch_generations) except Exception as e: @@ -655,7 +659,7 @@ def _generate( # noqa: C901 ) inputs = self._include_generator_outputs_as_inputs( - inputs, batch_generations + inputs=inputs, outputs=batch_generations ) if self.labeller is not None: @@ -785,6 +789,7 @@ def generate( num_generations=num_generations, batch_size=batch_size, enable_checkpoints=enable_checkpoints, + shuffle_before_labelling=shuffle_before_labelling, display_progress_bar=display_progress_bar, ) From 9a318e427c438c8b56c39c9c41adc2532ffb7bd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 19 Dec 2023 17:07:27 +0100 Subject: [PATCH 6/6] Replace `multiprocessing` with `multiprocess` (#171) * Use `multiprocess` instead of `multiprocessing` * Add warning about executing the example in Mac OS * Apply suggestions from code review Co-authored-by: Alvaro Bartolome --------- Co-authored-by: Alvaro Bartolome --- examples/pipeline-llamacpp-and-openai-process.py | 4 ++++ pyproject.toml | 1 + src/distilabel/llm/base.py | 3 ++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/pipeline-llamacpp-and-openai-process.py b/examples/pipeline-llamacpp-and-openai-process.py index bf3af9b660..df5f8cf11f 100644 --- a/examples/pipeline-llamacpp-and-openai-process.py +++ 
b/examples/pipeline-llamacpp-and-openai-process.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# WARNING: to run this example on macOS use:
+# no_proxy=* OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES python examples/pipeline-llamacpp-and-openai-process.py
+# Otherwise you will get an error when loading the llama.cpp model
+
 import os
 from typing import TYPE_CHECKING
 
diff --git a/pyproject.toml b/pyproject.toml
index b87a8e781b..7cad05ae18 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
     "rich >= 13.5.0",
     "tenacity >= 8",
    "importlib-resources >= 6.1.1; python_version < '3.9'",
+    "multiprocess",
 ]
 dynamic = ["version"]
 
diff --git a/src/distilabel/llm/base.py b/src/distilabel/llm/base.py
index 3be4f95db8..1306984e2b 100644
--- a/src/distilabel/llm/base.py
+++ b/src/distilabel/llm/base.py
@@ -14,7 +14,6 @@
 
 from __future__ import annotations
 
-import multiprocessing as mp
 import queue
 import random
 import warnings
@@ -34,6 +33,8 @@
     Union,
 )
 
+import multiprocess as mp
+
 from distilabel.logger import get_logger
 from distilabel.tasks.prompt import Prompt
 from distilabel.utils.futures import when_all_complete