From 53ca00c13d8c737a8fb80f0071e5ccd3f1aa66ee Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome Date: Fri, 5 Jan 2024 14:12:12 +0100 Subject: [PATCH] Add `TogetherInferenceLLM` (#215) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add `TogetherInferenceLLM` and `_TOGETHER_AVAILABLE_FLAG` * Update docstrings of `TogetherInferenceLLM` * Add `TogetherInferenceLLM` in `distilabel.llm` init * Access `TogetherInferenceLLM` output via dict * Fix bug affecting `TextGenerationTask` in `_to_argilla_record` * Add `TogetherInferenceLLM` documentation * Add `examples/pipeline-together-inference.py` * Update `model` argument docstring Co-authored-by: Agus --------- Co-authored-by: Gabriel Martín Blázquez Co-authored-by: Agus --- README.md | 1 + docs/index.md | 1 + .../llm/together_inference_generate.py | 21 ++ docs/technical-reference/llms.md | 16 +- examples/pipeline-together-inference.py | 74 ++++++ src/distilabel/llm/__init__.py | 2 + src/distilabel/llm/together.py | 215 ++++++++++++++++++ src/distilabel/tasks/base.py | 76 ++++--- src/distilabel/utils/imports.py | 1 + 9 files changed, 378 insertions(+), 29 deletions(-) create mode 100644 docs/snippets/technical-reference/llm/together_inference_generate.py create mode 100644 examples/pipeline-together-inference.py create mode 100644 src/distilabel/llm/together.py diff --git a/README.md b/README.md index 54f9a2356e..cf10a3abfe 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ In addition, the following extras are available: - `openai`: for using OpenAI API models via the `OpenAILLM` integration. - `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration. - `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as Python bindings for `llama.cpp`. +- `together`: for using [Together Inference](https://www.together.ai/products) via their Python client. - `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/). ## Example diff --git a/docs/index.md b/docs/index.md index f4831d9058..d8302c83c9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -23,6 +23,7 @@ In addition, the following extras are available: - `openai`: for using OpenAI API models via the `OpenAILLM` integration. - `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration. - `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as Python bindings for `llama.cpp`. +- `together`: for using [Together Inference](https://www.together.ai/products) via their Python client. - `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/). ## Quick example diff --git a/docs/snippets/technical-reference/llm/together_inference_generate.py b/docs/snippets/technical-reference/llm/together_inference_generate.py new file mode 100644 index 0000000000..9dab6c1c47 --- /dev/null +++ b/docs/snippets/technical-reference/llm/together_inference_generate.py @@ -0,0 +1,21 @@ +from distilabel.tasks import TextGenerationTask +from distilabel.llm import TogetherInferenceLLM + +llm = TogetherInferenceLLM( + model="togethercomputer/llama-2-70b-chat", + task=TextGenerationTask(), + max_new_tokens=512, + temperature=0.3, + prompt_format="llama2", +) +output = llm.generate( + [{"input": "Explain me the theory of relativity as if you were a pirate."}] +) +# >>> print(result[0][0]["parsed_output"]["generations"]) +# Ahoy matey! Yer lookin' fer a tale of the theory of relativity, eh? Well, +# settle yerself down with a pint o' grog and listen close, for this be a story +# of the sea of time and space! +# Ye see, matey, the theory of relativity be tellin' us that time and space ain't +# fixed things, like the deck o' a ship or the stars in the sky. Nay, they be like +# the ocean itself, always changin' and flowin' like the tides. +# Now, imagine ... diff --git a/docs/technical-reference/llms.md b/docs/technical-reference/llms.md index fafca6e6cd..71803674dc 100644 --- a/docs/technical-reference/llms.md +++ b/docs/technical-reference/llms.md @@ -9,7 +9,7 @@ In this section we will see what's an `LLM` and the different `LLM`s implementat The [`LLM`][distilabel.llm.base.LLM] class encapsulates the functionality for interacting with a large language model. -It distinguishes between *task* specifications and configurable parameters that influence the LLM's behavior. +It distinguishes between *task* specifications and configurable parameters that influence the LLM behavior. For illustration purposes, we employ the [`TextGenerationTask`][distilabel.tasks.text_generation.base.TextGenerationTask] in this section and guide you to the dedicated [`Tasks`](../technical-reference/tasks.md) section for comprehensive details. @@ -28,7 +28,7 @@ Let's briefly introduce the general parameters we may find[^1]: - `top_k` and `top_p`: `top_k` limits the number of tokens the model is allowed to use to generate the following token sorted by probability, while `top_p` limits the number of tokens the model can use for the next token, but in terms of the sum of their probabilities. -- `frequency_penalty` and `presence_penalty`: the frequency penalty penalizes tokens that have already appeard in the generated text, limiting the possibility of those appearing again, and the `presence_penalty` penalizes regardless of hte frequency. +- `frequency_penalty` and `presence_penalty`: the frequency penalty penalizes tokens that have already appeared in the generated text, limiting the possibility of those appearing again, and the `presence_penalty` penalizes regardless of the frequency. - `prompt_format` and `prompt_formatting_fn`: these two parameters allow to tweak the prompt of our models, for example we can direct the `LLM` to format the prompt according to one of the defined formats, while `prompt_formatting_fn` allows to pass a function that will be applied to the prompt before the generation, for extra control of what we ingest to the model. @@ -160,6 +160,17 @@ Let's see how to interact with these LLMs: --8<-- "docs/snippets/technical-reference/llm/inference_endpoint_generate.py" ``` +### Together Inference + +Together offers a product named Together Inference, which exposes some models for diverse tasks such as chat, text generation, code, or image; exposing those via an endpoint within their API either as serverless endpoints or as dedicated instances. + +See their release post with more details at [Announcing Together Inference Engine – the fastest inference available](https://www.together.ai/blog/together-inference-engine-v1). + + +```python +--8<-- "docs/snippets/technical-reference/llm/together_inference_generate.py" +``` + ## `ProcessLLM` and `LLMPool` By default, `distilabel` uses a single process, so the generation loop is usually bottlenecked by the model inference time and Python GIL. To overcome this limitation, we provide the `ProcessLLM` class that allows to load an `LLM` in a different process, avoiding the GIL and allowing to parallelize the generation loop. Creating a `ProcessLLM` is easy as: @@ -176,4 +187,3 @@ You can directly use a `ProcessLLM` as the `generator` or `labeller` in a `Pipel ```python --8<-- "docs/snippets/technical-reference/llm/llmpool.py" ``` - diff --git a/examples/pipeline-together-inference.py b/examples/pipeline-together-inference.py new file mode 100644 index 0000000000..5874d2b1a1 --- /dev/null +++ b/examples/pipeline-together-inference.py @@ -0,0 +1,74 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +from datasets import Dataset +from distilabel.llm import TogetherInferenceLLM +from distilabel.pipeline import Pipeline +from distilabel.tasks import TextGenerationTask + +if __name__ == "__main__": + dataset = Dataset.from_dict( + { + "input": ["Explain me the theory of relativity as if you were a pirate."], + } + ) + + llm = TogetherInferenceLLM( + model="togethercomputer/llama-2-70b-chat", + api_key=os.getenv("TOGETHER_API_KEY", None), + task=TextGenerationTask(), + prompt_format="llama2", + ) + pipeline = Pipeline(generator=llm) + + start = time.time() + dataset = pipeline.generate( + dataset=dataset, + shuffle_before_labelling=False, + num_generations=2, + skip_dry_run=True, + display_progress_bar=False, + ) # type: ignore + end = time.time() + print("Elapsed", end - start) + + # Push to the HuggingFace Hub + dataset.push_to_hub( + os.getenv("HF_REPO_ID"), # type: ignore + split="train", + private=True, + token=os.getenv("HF_TOKEN", None), + ) + + try: + from uuid import uuid4 + + import argilla as rg + + rg.init( + api_url=os.getenv("ARGILLA_API_URL"), + api_key=os.getenv("ARGILLA_API_KEY"), + ) + + # Convert into an Argilla dataset and push it to Argilla + rg_dataset = dataset.to_argilla() + rg_dataset.push_to_argilla( + name=f"my-dataset-{uuid4()}", + workspace="admin", + ) + except ImportError: + pass diff --git a/src/distilabel/llm/__init__.py b/src/distilabel/llm/__init__.py index fc09d7a33d..a7e02e0e4f 100644 --- a/src/distilabel/llm/__init__.py +++ b/src/distilabel/llm/__init__.py @@ -18,6 +18,7 @@ from distilabel.llm.huggingface.transformers import TransformersLLM from distilabel.llm.llama_cpp import LlamaCppLLM from distilabel.llm.openai import OpenAILLM +from distilabel.llm.together import TogetherInferenceLLM from distilabel.llm.vllm import vLLM __all__ = [ @@ -29,6 +30,7 @@ "InferenceEndpointsLLM", "TransformersLLM", "LlamaCppLLM", + "TogetherInferenceLLM", "OpenAILLM", "vLLM", ] diff --git a/src/distilabel/llm/together.py b/src/distilabel/llm/together.py new file mode 100644 index 0000000000..43b3f26196 --- /dev/null +++ b/src/distilabel/llm/together.py @@ -0,0 +1,215 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import cached_property +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Union + +from distilabel.llm.base import LLM +from distilabel.llm.utils import LLMOutput +from distilabel.logger import get_logger +from distilabel.utils.imports import _TOGETHER_AVAILABLE + +if _TOGETHER_AVAILABLE: + import together + +if TYPE_CHECKING: + from distilabel.tasks.base import Task + from distilabel.tasks.prompt import SupportedFormats + + +logger = get_logger() + + +class TogetherInferenceLLM(LLM): + def __init__( + self, + task: "Task", + model: str, + api_key: Union[str, None] = None, + max_new_tokens: int = 128, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 1, + stop: Union[List[str], None] = None, + logprobs: int = 0, + num_threads: Union[int, None] = None, + prompt_format: Union["SupportedFormats", None] = None, + prompt_formatting_fn: Union[Callable[..., str], None] = None, + ) -> None: + """Initializes the OpenAILLM class. + + Args: + task (Task): the task to be performed by the LLM. + model (str): the model to be used for generation. + max_new_tokens (int, optional): the maximum number of tokens to be generated. + Defaults to 128. + temperature (float, optional): the temperature to be used for generation. From the Together + Inference docs: "A decimal number that determines the degree of randomness in the response. + A value of 0 will always yield the same output. A temperature much less than 1 favors more + correctness and is appropriate for question answering or summarization. A value approaching + 1 introduces more randomness in the output.". Defaults to 1.0. + repetition_penalty (float, optional): the repetition penalty to be used for generation. From the + Together Inference docs: "Controls the diversity of generated text by reducing the likelihood + of repeated sequences. Higher values decrease repetition.". Defaults to 1.0. + top_p (float, optional): the top-p value to be used for generation. From the Together + Inference docs: "used to dynamically adjust the number of choices for each predicted + token based on the cumulative probabilities. It specifies a probability threshold, + below which all less likely tokens are filtered out. This technique helps to maintain + diversity and generate more fluent and natural-sounding text.". Defaults to 1.0. + top_k (int, optional): the top-k value to be used for generation. From the Together Inference + docs: "used to limit the number of choices for the next predicted word or token. It specifies + the maximum number of tokens to consider at each step, based on their probability of occurrence. + This technique helps to speed up the generation process and can improve the quality of the + generated text by focusing on the most likely options.". Defaults to 1. + stop (List[str], optional): strings to delimitate the generation process, so that when the + model generates any of the provided characters, the generation process is considered completed. + Defaults to None. + logprobs (int, optional): the number of logprobs to be returned for each token. From the + Together Inference docs: "An integer that specifies how many top token log probabilities + are included in the response for each token generation step.". Defaults to None. + num_threads (Union[int, None], optional): the number of threads to be used + for parallel generation. If `None`, no parallel generation will be performed. + Defaults to `None`. + prompt_format (Union[SupportedFormats, None], optional): the format to be used + for the prompt. If `None`, the default format of the task will be used, available + formats are `openai`, `chatml`, `llama2`, `zephyr`, and `default`. Defaults to `None`, + but `default` (concatenation of `system_prompt` and `formatted_prompt` with a line-break) + will be used if no `prompt_formatting_fn` is provided. + prompt_formatting_fn (Union[Callable[..., str], None], optional): a function to be + applied to the prompt before generation. If `None`, no formatting will be applied. + Defaults to `None`. + + Raises: + AssertionError: if the provided `model` is not available in Together Inference. + + Examples: + >>> from distilabel.tasks.text_generation import TextGenerationTask as Task + >>> from distilabel.llm import TogetherInferenceLLM + >>> task = Task() + >>> llm = TogetherInferenceLLM(model="togethercomputer/llama-2-7b", task=task, prompt_format="llama2") + """ + if not _TOGETHER_AVAILABLE: + raise ImportError( + "`TogetherInferenceLLM` cannot be used as `together` is not installed, please " + " install it with `pip install together`." + ) + + together.api_key = api_key or os.getenv("TOGETHER_API_KEY", None) + if together.api_key is None: + raise ValueError( + "No `api_key` provided, please provide one or set the `TOGETHER_API_KEY` " + "environment variable." + ) + + super().__init__( + task=task, + num_threads=num_threads, + prompt_format=prompt_format, + prompt_formatting_fn=prompt_formatting_fn, + ) + + assert ( + model in self.available_models + ), f"Provided `model` is not available in Together Inference, available models are {self.available_models}" + self.model = model + + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.repetition_penalty = repetition_penalty + self.stop = stop + self.logprobs = logprobs + + def __rich_repr__(self) -> Generator[Any, None, None]: + yield from super().__rich_repr__() + yield ( + "parameters", + { + "max_new_tokens": self.max_new_tokens, + "temperature": self.temperature, + "repetition_penalty": self.repetition_penalty, + "top_p": self.top_p, + "top_k": self.top_k, + "stop": self.stop, + "logprobs": self.logprobs, + }, + ) + + @cached_property + def available_models(self) -> List[str]: + """Returns the list of available models in Together Inference.""" + # TODO: exclude the image models + return [model["name"] for model in together.Models.list()] + + @property + def model_name(self) -> str: + """Returns the name of the Together Inference model.""" + return self.model + + def _generate( + self, + inputs: List[Dict[str, Any]], + num_generations: int = 1, + ) -> List[List[LLMOutput]]: + """Generates `num_generations` for each input in `inputs`. + + Args: + inputs (List[Dict[str, Any]]): the inputs to be used for generation. + num_generations (int, optional): the number of generations to be performed for each + input. Defaults to 1. + + Returns: + List[List[LLMOutput]]: the generated outputs. + """ + prompts = self._generate_prompts(inputs, default_format=None) + outputs = [] + for prompt in prompts: + batch = [] + for _ in range(num_generations): + output = together.Complete.create( + prompt=prompt, + model=self.model, + max_tokens=self.max_new_tokens, + stop=self.stop, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p, + repetition_penalty=self.repetition_penalty, + logprobs=self.logprobs, + ) + if output["output"]["choices"] is not None: + for choice in output["output"]["choices"]: + try: + parsed_response = self.task.parse_output( + choice["text"].strip() + ) + except Exception as e: + logger.error( + f"Error parsing Together Inference response: {e}" + ) + parsed_response = None + batch.append( + LLMOutput( + model_name=self.model_name, + prompt_used=prompt, + raw_output=choice["text"], + parsed_output=parsed_response, + ) + ) + if len(batch) > 0: + outputs.append(batch) + return outputs diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py index c334ab91dc..80bc5c279c 100644 --- a/src/distilabel/tasks/base.py +++ b/src/distilabel/tasks/base.py @@ -144,40 +144,64 @@ def _to_argilla_record( # noqa: C901 else None ] if any( - generation_column in required_column_names - for generation_column in generation_columns + isinstance(nested, list) + for column_name in list( + set(generation_columns) + - { + "generation_model", + "generation_prompt", + "raw_generation_response", + } + ) + for nested in dataset_row[column_name] ): - unwrapped_dataset_rows = [] - for row in dataset_rows: - for idx in range(len(dataset_row["generation_model"])): - unwrapped_dataset_row = {} - for key, value in row.items(): - if key in generation_columns: - unwrapped_dataset_row[key] = value[idx] - else: - unwrapped_dataset_row[key] = value - unwrapped_dataset_rows.append(unwrapped_dataset_row) - dataset_rows = unwrapped_dataset_rows + if any( + generation_column in required_column_names + for generation_column in generation_columns + ): + unwrapped_dataset_rows = [] + for row in dataset_rows: + for idx in range(len(dataset_row["generation_model"])): + unwrapped_dataset_row = {} + for key, value in row.items(): + if key in generation_columns: + unwrapped_dataset_row[key] = value[idx] + else: + unwrapped_dataset_row[key] = value + unwrapped_dataset_rows.append(unwrapped_dataset_row) + dataset_rows = unwrapped_dataset_rows if "labelling_model" in dataset_row and isinstance( dataset_row["labelling_model"], list ): labelling_columns = column_names[column_names.index("labelling_model") :] if any( - labelling_column in required_column_names - for labelling_column in labelling_columns + isinstance(nested, list) + for column_name in list( + set(labelling_columns) + - { + "labelling_model", + "labelling_prompt", + "raw_labelling_response", + } + ) + for nested in dataset_row[column_name] ): - unwrapped_dataset_rows = [] - for row in dataset_rows: - for idx in range(len(dataset_row["labelling_model"])): - unwrapped_dataset_row = {} - for key, value in row.items(): - if key in labelling_columns: - unwrapped_dataset_row[key] = value[idx] - else: - unwrapped_dataset_row[key] = value - unwrapped_dataset_rows.append(unwrapped_dataset_row) - dataset_rows = unwrapped_dataset_rows + if any( + labelling_column in required_column_names + for labelling_column in labelling_columns + ): + unwrapped_dataset_rows = [] + for row in dataset_rows: + for idx in range(len(dataset_row["labelling_model"])): + unwrapped_dataset_row = {} + for key, value in row.items(): + if key in labelling_columns: + unwrapped_dataset_row[key] = value[idx] + else: + unwrapped_dataset_row[key] = value + unwrapped_dataset_rows.append(unwrapped_dataset_row) + dataset_rows = unwrapped_dataset_rows if len(dataset_rows) == 1: return self.to_argilla_record(dataset_rows[0], *args, **kwargs) diff --git a/src/distilabel/utils/imports.py b/src/distilabel/utils/imports.py index af50986bda..c044bf7e95 100644 --- a/src/distilabel/utils/imports.py +++ b/src/distilabel/utils/imports.py @@ -114,3 +114,4 @@ def _check_package_is_available( _AISTUDIO_AVAILABLE = _check_package_is_available( "google-generativeai", min_version="0.3.2", greater_or_equal=True ) +_TOGETHER_AVAILABLE = _check_package_is_available("together")