From 04b86f51c06f91f8d0054576855cf3864cc6fad8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?=
Date: Fri, 26 Jul 2024 19:03:05 +0200
Subject: [PATCH] Add `Embeddings` base class, `SentenceTransformerEmbeddings` class, `EmbeddingGeneration` and `FaissNearestNeighbour` steps (#830)

* Add `Embeddings` base class and `SentenceTransformers` class

* Add `EmbeddingGeneration` step

* Add `precision` attribute

* Add docstrings

* Add example to docstring

* Update component gallery to include `Embeddings` models

* Add `sentence-transformers` extra

* Add `FaissNearestNeighbour` step

* Add category and example

* Merge category to icons dictionaries

* Add missing unit tests

* Add `faiss-cpu` and `faiss-gpu` extras

* Update unit tests
---
 pyproject.toml                                |   3 +
 scripts/install_dependencies.sh               |   2 +-
 src/distilabel/embeddings/__init__.py         |  21 ++
 src/distilabel/embeddings/base.py             |  72 +++++++
 .../embeddings/sentence_transformers.py       | 157 ++++++++++++++
 src/distilabel/llms/base.py                   |   3 +-
 .../llms/huggingface/transformers.py          |   2 +
 src/distilabel/steps/__init__.py              |   4 +
 src/distilabel/steps/embeddings/__init__.py   |  14 ++
 .../steps/embeddings/embedding_generation.py  |  86 ++++++++
 .../steps/embeddings/nearest_neighbour.py     | 204 ++++++++++++++++++
 .../utils/export_components_info.py           |  60 +++---
 src/distilabel/utils/logging.py               |   2 +
 .../utils/mkdocs/components_gallery.py        |  53 ++++-
 .../templates/components-gallery/index.md     |   8 +
 tests/unit/embeddings/__init__.py             |  14 ++
 .../embeddings/test_sentence_transformers.py  |  42 ++++
 tests/unit/steps/embeddings/__init__.py       |  14 ++
 .../embeddings/test_embedding_generation.py   |  51 +++++
 .../embeddings/test_nearest_neighbour.py      |  86 ++++++++
 20 files changed, 865 insertions(+), 33 deletions(-)
 create mode 100644 src/distilabel/embeddings/__init__.py
 create mode 100644 src/distilabel/embeddings/base.py
 create mode 100644 src/distilabel/embeddings/sentence_transformers.py
 create mode 100644 src/distilabel/steps/embeddings/__init__.py
 create mode 100644 src/distilabel/steps/embeddings/embedding_generation.py
 create mode 100644 src/distilabel/steps/embeddings/nearest_neighbour.py
 create mode 100644 tests/unit/embeddings/__init__.py
 create mode 100644 tests/unit/embeddings/test_sentence_transformers.py
 create mode 100644 tests/unit/steps/embeddings/__init__.py
 create mode 100644 tests/unit/steps/embeddings/test_embedding_generation.py
 create mode 100644 tests/unit/steps/embeddings/test_nearest_neighbour.py

diff --git a/pyproject.toml b/pyproject.toml
index 7c15678ce1..96450c6d30 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,6 +91,9 @@ vllm = [
     # `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]`
     "setuptools",
 ]
+sentence-transformers = ["sentence-transformers >= 3.0.0"]
+faiss-cpu = ["faiss-cpu >= 1.8.0"]
+faiss-gpu = ["faiss-gpu >= 1.7.2"]
 
 [project.urls]
 Documentation = "https://distilabel.argilla.io/"
diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh
index 06f52c402e..7deb35b778 100755
--- a/scripts/install_dependencies.sh
+++ b/scripts/install_dependencies.sh
@@ -6,7 +6,7 @@ python_version=$(python -c "import sys; print(sys.version_info[:2])")
 
 python -m pip install uv
 
-uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor]"
+uv pip install --system -e 
".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu]" if [ "${python_version}" != "(3, 12)" ]; then uv pip install --system -e .[ray] diff --git a/src/distilabel/embeddings/__init__.py b/src/distilabel/embeddings/__init__.py new file mode 100644 index 0000000000..a7e5e63e2c --- /dev/null +++ b/src/distilabel/embeddings/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.embeddings.base import Embeddings +from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings + +__all__ = [ + "Embeddings", + "SentenceTransformerEmbeddings", +] diff --git a/src/distilabel/embeddings/base.py b/src/distilabel/embeddings/base.py new file mode 100644 index 0000000000..22a255540f --- /dev/null +++ b/src/distilabel/embeddings/base.py @@ -0,0 +1,72 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from typing import List, Union + +from pydantic import BaseModel, ConfigDict, PrivateAttr + +from distilabel.mixins.runtime_parameters import RuntimeParametersMixin +from distilabel.utils.serialization import _Serializable + + +class Embeddings(RuntimeParametersMixin, BaseModel, _Serializable, ABC): + """Base class for `Embeddings` models. + + To implement an `Embeddings` subclass, you need to subclass this class and implement: + - `load` method to load the `Embeddings` model. Don't forget to call `super().load()`, + so the `_logger` attribute is initialized. + - `model_name` property to return the model name used for the `Embeddings`. + - `encode` method to generate the sentence embeddings. + + Attributes: + _logger: the logger to be used for the `Embeddings` model. It will be initialized + when the `load` method is called. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + protected_namespaces=(), + validate_default=True, + validate_assignment=True, + extra="forbid", + ) + _logger: Union[logging.Logger, None] = PrivateAttr(...) 
+ + def load(self) -> None: + """Method to be called to initialize the `Embeddings`""" + self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}") + + def unload(self) -> None: + """Method to be called to unload the `Embeddings` and release any resources.""" + pass + + @property + @abstractmethod + def model_name(self) -> str: + """Returns the model name used for the `Embeddings`.""" + pass + + @abstractmethod + def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]: + """Generates embeddings for the provided inputs. + + Args: + inputs: a list of texts for which an embedding has to be generated. + + Returns: + The generated embeddings. + """ + pass diff --git a/src/distilabel/embeddings/sentence_transformers.py b/src/distilabel/embeddings/sentence_transformers.py new file mode 100644 index 0000000000..08b3465ad1 --- /dev/null +++ b/src/distilabel/embeddings/sentence_transformers.py @@ -0,0 +1,157 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + +from pydantic import Field, PrivateAttr + +from distilabel.embeddings.base import Embeddings +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.mixins.runtime_parameters import RuntimeParameter + +if TYPE_CHECKING: + from sentence_transformers import SentenceTransformer + + +class SentenceTransformerEmbeddings(Embeddings, CudaDevicePlacementMixin): + """`sentence-transformers` library implementation for embedding generation. + + Attributes: + model: the model Hugging Face Hub repo id or a path to a directory containing the + model weights and configuration files. + device: the name of the device used to load the model e.g. "cuda", "mps", etc. + Defaults to `None`. + prompts: a dictionary containing prompts to be used with the model. Defaults to + `None`. + default_prompt_name: the default prompt (in `prompts`) that will be applied to the + inputs. If not provided, then no prompt will be used. Defaults to `None`. + trust_remote_code: whether to allow fetching and executing remote code fetched + from the repository in the Hub. Defaults to `False`. + revision: if `model` refers to a Hugging Face Hub repository, then the revision + (e.g. a branch name or a commit id) to use. Defaults to `"main"`. + token: the Hugging Face Hub token that will be used to authenticate to the Hugging + Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package + local configuration will be used. Defaults to `None`. + truncate_dim: the dimension to truncate the sentence embeddings. Defaults to `None`. + model_kwargs: extra kwargs that will be passed to the Hugging Face `transformers` + model class. Defaults to `None`. + tokenizer_kwargs: extra kwargs that will be passed to the Hugging Face `transformers` + tokenizer class. Defaults to `None`. + config_kwargs: extra kwargs that will be passed to the Hugging Face `transformers` + configuration class. 
Defaults to `None`.
+        precision: the dtype that the resulting embeddings will have. Defaults to `"float32"`.
+        normalize_embeddings: whether to normalize the embeddings so they have a length
+            of 1. Defaults to `True`.
+
+    Examples:
+
+        Generating sentence embeddings:
+
+        ```python
+        from distilabel.embeddings import SentenceTransformerEmbeddings
+
+        embeddings = SentenceTransformerEmbeddings(model="mixedbread-ai/mxbai-embed-large-v1")
+
+        embeddings.load()
+
+        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
+        # [
+        #   [-0.05447685346007347, -0.01623094454407692, ...],
+        #   [4.4889533455716446e-05, 0.044016145169734955, ...],
+        # ]
+        ```
+    """
+
+    model: str
+    device: Optional[RuntimeParameter[str]] = Field(
+        default=None,
+        description="The device to be used to load the model. If `None`, then it"
+        " will check if a GPU can be used.",
+    )
+    prompts: Optional[Dict[str, str]] = None
+    default_prompt_name: Optional[str] = None
+    trust_remote_code: bool = False
+    revision: Optional[str] = None
+    token: Optional[str] = None
+    truncate_dim: Optional[int] = None
+    model_kwargs: Optional[Dict[str, Any]] = None
+    tokenizer_kwargs: Optional[Dict[str, Any]] = None
+    config_kwargs: Optional[Dict[str, Any]] = None
+    precision: Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]] = (
+        "float32"
+    )
+    normalize_embeddings: RuntimeParameter[bool] = Field(
+        default=True,
+        description="Whether to normalize the embeddings so the generated vectors"
+        " have a length of 1 or not.",
+    )
+
+    _model: Union["SentenceTransformer", None] = PrivateAttr(None)
+
+    def load(self) -> None:
+        """Loads the Sentence Transformer model"""
+        super().load()
+
+        if self.device == "cuda":
+            CudaDevicePlacementMixin.load(self)
+
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError as e:
+            raise ImportError(
+                "`sentence-transformers` package is not installed. Please install it using"
+                " `pip install sentence-transformers`."
+            ) from e
+
+        self._model = SentenceTransformer(
+            model_name_or_path=self.model,
+            device=self.device,
+            prompts=self.prompts,
+            default_prompt_name=self.default_prompt_name,
+            trust_remote_code=self.trust_remote_code,
+            revision=self.revision,
+            token=self.token,
+            truncate_dim=self.truncate_dim,
+            model_kwargs=self.model_kwargs,
+            tokenizer_kwargs=self.tokenizer_kwargs,
+            config_kwargs=self.config_kwargs,
+        )
+
+    @property
+    def model_name(self) -> str:
+        """Returns the name of the model."""
+        return self.model
+
+    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
+        """Generates embeddings for the provided inputs.
+
+        Args:
+            inputs: a list of texts for which an embedding has to be generated.
+
+        Returns:
+            The generated embeddings.
+        """
+        return self._model.encode(  # type: ignore
+            sentences=inputs,
+            batch_size=len(inputs),
+            convert_to_numpy=True,
+            precision=self.precision,  # type: ignore
+            normalize_embeddings=self.normalize_embeddings,  # type: ignore
+        ).tolist()  # type: ignore
+
+    def unload(self) -> None:
+        del self._model
+        if self.device == "cuda":
+            CudaDevicePlacementMixin.unload(self)
+        super().unload()
diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py
index 07fba6788d..daf1bd0813 100644
--- a/src/distilabel/llms/base.py
+++ b/src/distilabel/llms/base.py
@@ -85,7 +85,8 @@ class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
     _logger: Union[logging.Logger, None] = PrivateAttr(...)
def load(self) -> None: - """Method to be called to initialize the `LLM`, its logger and optionally the structured output generator.""" + """Method to be called to initialize the `LLM`, its logger and optionally the + structured output generator.""" self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}") def unload(self) -> None: diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 86754e8ef1..06d6b71422 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -44,6 +44,8 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): (e.g. a branch name or a commit id) to use. Defaults to `"main"`. torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc. Defaults to `"auto"`. + trust_remote_code: whether to allow fetching and executing remote code fetched + from the repository in the Hub. Defaults to `False`. trust_remote_code: whether to trust or not remote (code in the Hugging Face Hub repository) code to load the model. Defaults to `False`. model_kwargs: additional dictionary of keyword arguments that will be passed to diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index b3d84c15ea..39951a4643 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -27,6 +27,8 @@ from distilabel.steps.columns.merge import MergeColumns from distilabel.steps.decorator import step from distilabel.steps.deita import DeitaFiltering +from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration +from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour from distilabel.steps.formatting.conversation import ConversationTemplate from distilabel.steps.formatting.dpo import ( FormatChatGenerationDPO, @@ -54,6 +56,8 @@ "CombineColumns", "ConversationTemplate", "DeitaFiltering", + "EmbeddingGeneration", + "FaissNearestNeighbour", "ExpandColumns", "FormatChatGenerationDPO", "FormatChatGenerationSFT", diff --git a/src/distilabel/steps/embeddings/__init__.py b/src/distilabel/steps/embeddings/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/steps/embeddings/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/steps/embeddings/embedding_generation.py b/src/distilabel/steps/embeddings/embedding_generation.py new file mode 100644 index 0000000000..55e8838274 --- /dev/null +++ b/src/distilabel/steps/embeddings/embedding_generation.py @@ -0,0 +1,86 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List
+
+from distilabel.embeddings.base import Embeddings
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+    from distilabel.steps.typing import StepOutput
+
+
+class EmbeddingGeneration(Step):
+    """Generate embeddings using an `Embeddings` model.
+
+    `EmbeddingGeneration` is a `Step` that uses an `Embeddings` model to generate sentence
+    embeddings for the provided input texts.
+
+    Attributes:
+        embeddings: the `Embeddings` model used to generate the sentence embeddings.
+
+    Input columns:
+        - text (`str`): the text for which the sentence embedding has to be generated.
+
+    Output columns:
+        - embedding (`List[Union[float, int]]`): the generated sentence embedding.
+        - model_name (`str`): the name of the model used to generate the embedding.
+
+    Examples:
+
+        Generate sentence embeddings with Sentence Transformers:
+
+        ```python
+        from distilabel.embeddings import SentenceTransformerEmbeddings
+        from distilabel.steps import EmbeddingGeneration
+
+        embedding_generation = EmbeddingGeneration(
+            embeddings=SentenceTransformerEmbeddings(
+                model="mixedbread-ai/mxbai-embed-large-v1",
+            )
+        )
+
+        embedding_generation.load()
+
+        result = next(embedding_generation.process([{"text": "Hello, how are you?"}]))
+        # [{'text': 'Hello, how are you?', 'embedding': [0.06209656596183777, -0.015797119587659836, ...]}]
+        ```
+
+    """
+
+    embeddings: Embeddings
+
+    @property
+    def inputs(self) -> List[str]:
+        return ["text"]
+
+    @property
+    def outputs(self) -> List[str]:
+        return ["embedding", "model_name"]
+
+    def load(self) -> None:
+        """Loads the `Embeddings` model."""
+        super().load()
+
+        self.embeddings.load()
+
+    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
+        embeddings = self.embeddings.encode(inputs=[input["text"] for input in inputs])
+        for input, embedding in zip(inputs, embeddings):
+            input["embedding"] = embedding
+            input["model_name"] = self.embeddings.model_name
+        yield inputs
+
+    def unload(self) -> None:
+        super().unload()
+        self.embeddings.unload()
diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py
new file mode 100644
index 0000000000..4aafbe664a
--- /dev/null
+++ b/src/distilabel/steps/embeddings/nearest_neighbour.py
@@ -0,0 +1,204 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
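+
+# NOTE: this step relies on the optional `faiss` bindings; they can be installed through
+# the `faiss-cpu` or `faiss-gpu` extras added in this PR, e.g. `pip install "distilabel[faiss-cpu]"`.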
+
+import importlib.util
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import numpy as np
+from datasets import Dataset
+from pydantic import Field
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps import GlobalStep, StepInput
+
+if TYPE_CHECKING:
+    from distilabel.steps.typing import StepOutput
+
+
+class FaissNearestNeighbour(GlobalStep):
+    """Create a `faiss` index to get the nearest neighbours.
+
+    `FaissNearestNeighbour` is a `GlobalStep` that creates a `faiss` index using the Hugging
+    Face `datasets` library integration, and then gets the nearest neighbours and the scores
+    or distances of those neighbours for each input row.
+
+    Attributes:
+        device: the CUDA device ID or a list of IDs to be used. If negative integer, it
+            will use all the available GPUs. Defaults to `None`.
+        string_factory: the name of the factory to be used to build the `faiss` index.
+            Available string factories can be checked here: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes.
+            Defaults to `None`.
+        metric_type: the metric to be used to measure the distance between the points. It's
+            an integer and the recommended way to pass it is importing `faiss` and then passing
+            one of `faiss.METRIC_x` variables. Defaults to `None`.
+        k: the number of nearest neighbours to search for each input row. Defaults to `1`.
+        search_batch_size: the number of rows to include in a search batch. The value can
+            be adjusted to maximize resource usage or to avoid OOM issues. Defaults
+            to `50`.
+
+    Runtime parameters:
+        - `device`: the CUDA device ID or a list of IDs to be used. If negative integer,
+            it will use all the available GPUs. Defaults to `None`.
+        - `string_factory`: the name of the factory to be used to build the `faiss` index.
+            Available string factories can be checked here: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes.
+            Defaults to `None`.
+        - `metric_type`: the metric to be used to measure the distance between the points.
+            It's an integer and the recommended way to pass it is importing `faiss` and then
+            passing one of `faiss.METRIC_x` variables. Defaults to `None`.
+        - `k`: the number of nearest neighbours to search for each input row. Defaults to `1`.
+        - `search_batch_size`: the number of rows to include in a search batch. The value
+            can be adjusted to maximize resource usage or to avoid OOM issues. Defaults
+            to `50`.
+
+    Input columns:
+        - embedding (`List[Union[float, int]]`): a sentence embedding.
+
+    Output columns:
+        - nn_indices (`List[int]`): a list containing the indices of the `k` nearest neighbours
+            in the inputs for the row.
+        - nn_scores (`List[float]`): a list containing the score or distance to each `k`
+            nearest neighbour in the inputs.
+
+    Categories:
+        - embedding
+
+    References:
+        - [`The Faiss library`](https://arxiv.org/abs/2401.08281)
+
+    Examples:
+
+        Generating embeddings and getting the nearest neighbours:
+
+        ```python
+        from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings
+        from distilabel.pipeline import Pipeline
+        from distilabel.steps import EmbeddingGeneration, FaissNearestNeighbour, LoadDataFromHub
+
+        with Pipeline(name="hello") as pipeline:
+            load_data = LoadDataFromHub(output_mappings={"prompt": "text"})
+
+            embeddings = EmbeddingGeneration(
+                embeddings=SentenceTransformerEmbeddings(
+                    model="mixedbread-ai/mxbai-embed-large-v1"
+                )
+            )
+
+            nearest_neighbours = FaissNearestNeighbour()
+
+            load_data >> embeddings >> nearest_neighbours
+
+        if __name__ == "__main__":
+            distiset = pipeline.run(
+                parameters={
+                    load_data.name: {
+                        "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
+                        "split": "test",
+                    },
+                },
+                use_cache=False,
+            )
+        ```
+    """
+
+    device: Optional[RuntimeParameter[Union[int, List[int]]]] = Field(
+        default=None,
+        description="The CUDA device ID or a list of IDs to be used. If negative integer,"
+        " it will use all the available GPUs.",
+    )
+    string_factory: Optional[RuntimeParameter[str]] = Field(
+        default=None,
+        description="The name of the factory to be used to build the `faiss` index."
+        " Available string factories can be checked here: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes.",
+    )
+    metric_type: Optional[RuntimeParameter[int]] = Field(
+        default=None,
+        description="The metric to be used to measure the distance between the points. It's"
+        " an integer and the recommended way to pass it is importing `faiss` and then passing"
+        " one of `faiss.METRIC_x` variables.",
+    )
+    k: Optional[RuntimeParameter[int]] = Field(
+        default=1,
+        description="The number of nearest neighbours to search for each input row.",
+    )
+    search_batch_size: Optional[RuntimeParameter[int]] = Field(
+        default=50,
+        description="The number of rows to include in a search batch. The value can be adjusted"
+        " to maximize resource usage or to avoid OOM issues.",
+    )
+
+    def load(self) -> None:
+        super().load()
+
+        if importlib.util.find_spec("faiss") is None:
+            raise ImportError(
+                "`faiss` package is not installed. Please install it using `pip install"
+                " faiss-cpu` or `pip install faiss-gpu`."
+            )
+
+    @property
+    def inputs(self) -> List[str]:
+        return ["embedding"]
+
+    @property
+    def outputs(self) -> List[str]:
+        return ["nn_indices", "nn_scores"]
+
+    def _build_index(self, inputs: List[Dict[str, Any]]) -> Dataset:
+        """Builds a `faiss` index using `datasets` integration.
+
+        Args:
+            inputs: a list of dictionaries.
+
+        Returns:
+            The built `datasets.Dataset` with its `faiss` index.
+        """
+        dataset = Dataset.from_list(inputs)
+        dataset.add_faiss_index(
+            column="embedding",
+            device=self.device,  # type: ignore
+            string_factory=self.string_factory,
+            metric_type=self.metric_type,
+        )
+        return dataset
+
+    def _search(self, dataset: Dataset) -> Dataset:
+        """Search the top `k` nearest neighbours for each row in the dataset.
+
+        Args:
+            dataset: the dataset with the `faiss` index built.
+
+        Returns:
+            The updated dataset containing the top `k` nearest neighbours for each row,
+            as well as the score or distance.
+ """ + + def add_search_results(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + queries = np.array(examples["embedding"]) + results = dataset.search_batch( + index_name="embedding", + queries=queries, + k=self.k + 1, # type: ignore + ) + examples["nn_indices"] = [indices[1:] for indices in results.total_indices] + examples["nn_scores"] = [scores[1:] for scores in results.total_scores] + return examples + + return dataset.map( + add_search_results, batched=True, batch_size=self.search_batch_size + ) + + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + dataset = self._build_index(inputs) + dataset = self._search(dataset) + yield dataset.to_list() diff --git a/src/distilabel/utils/export_components_info.py b/src/distilabel/utils/export_components_info.py index 7376a475dd..fa1cd6556d 100644 --- a/src/distilabel/utils/export_components_info.py +++ b/src/distilabel/utils/export_components_info.py @@ -15,6 +15,7 @@ import inspect from typing import Generator, List, Type, TypedDict, TypeVar +from distilabel.embeddings.base import Embeddings from distilabel.llms.base import LLM from distilabel.steps.base import _Step from distilabel.steps.tasks.base import _Task @@ -29,6 +30,7 @@ class ComponentsInfo(TypedDict): llms: List steps: List tasks: List + embeddings: List def export_components_info() -> ComponentsInfo: @@ -40,34 +42,27 @@ def export_components_info() -> ComponentsInfo: A dictionary containing `distilabel` components information """ - steps = [] - for step_type in _get_steps(): - steps.append( + return { + "steps": [ + {"name": step_type.__name__, "docstring": parse_google_docstring(step_type)} + for step_type in _get_steps() + ], + "tasks": [ + {"name": task_type.__name__, "docstring": parse_google_docstring(task_type)} + for task_type in _get_tasks() + ], + "llms": [ + {"name": llm_type.__name__, "docstring": parse_google_docstring(llm_type)} + for llm_type in _get_llms() + ], + "embeddings": [ { - "name": step_type.__name__, - "docstring": parse_google_docstring(step_type), + "name": embeddings_type.__name__, + "docstring": parse_google_docstring(embeddings_type), } - ) - - tasks = [] - for task_type in _get_tasks(): - tasks.append( - { - "name": task_type.__name__, - "docstring": parse_google_docstring(task_type), - } - ) - - llms = [] - for llm_type in _get_llms(): - llms.append( - { - "name": llm_type.__name__, - "docstring": parse_google_docstring(llm_type), - } - ) - - return {"steps": steps, "tasks": tasks, "llms": llms} + for embeddings_type in _get_embeddings() + ], + } T = TypeVar("T", covariant=True) @@ -118,6 +113,19 @@ def _get_llms() -> List[Type["LLM"]]: ] +def _get_embeddings() -> List[Type["Embeddings"]]: + """Get all `Embeddings` subclasses, that are not abstract classes. + + Returns: + A list of `Embeddings` subclasses, except `AsyncLLM` subclass + """ + return [ + embeddings_type + for embeddings_type in _recursive_subclasses(Embeddings) + if not inspect.isabstract(embeddings_type) + ] + + # Reference: https://adamj.eu/tech/2024/05/10/python-all-subclasses/ def _recursive_subclasses(klass: Type[T]) -> Generator[Type[T], None, None]: """Recursively get all subclasses of a class. 
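A note on the `k=self.k + 1` search and the `[1:]` slicing in `FaissNearestNeighbour._search` above: every input row is also stored in the index it queries, so the first result returned for each row is typically the row itself and gets discarded before writing `nn_indices` and `nn_scores`. The snippet below is a minimal standalone sketch of that behaviour using raw `faiss`, assuming the `faiss-cpu` extra is installed; the toy vectors and variable names are illustrative only and not part of the patch:

```python
import faiss
import numpy as np

# Toy embeddings: three 4-dimensional vectors (faiss expects float32).
embeddings = np.array(
    [[0.1, -0.4, 0.7, 0.2], [-0.3, 0.9, 0.1, -0.5], [0.6, 0.2, -0.1, 0.8]],
    dtype="float32",
)

index = faiss.IndexFlatL2(embeddings.shape[1])  # plain L2 index, no string factory
index.add(embeddings)

k = 1
# Ask for k + 1 neighbours: with L2 distance the closest hit of each row is the row itself.
scores, indices = index.search(embeddings, k + 1)
nn_indices = indices[:, 1:].tolist()  # drop the self-match, keep the real neighbours
nn_scores = scores[:, 1:].tolist()
print(nn_indices, nn_scores)
```
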
diff --git a/src/distilabel/utils/logging.py b/src/distilabel/utils/logging.py index c69ebcda18..3d50300624 100644 --- a/src/distilabel/utils/logging.py +++ b/src/distilabel/utils/logging.py @@ -37,6 +37,8 @@ "filelock", "fsspec", "asyncio", + "sentence_transformers.SentenceTransformer", + "faiss.loader", ] queue_listener: Union[QueueListener, None] = None diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index ae43c586af..2d28c55fb6 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -74,7 +74,7 @@ ).read() ) -_TASKS_CATEGORY_TO_ICON = { +_STEPS_CATEGORY_TO_ICON = { "text-generation": ":material-text-box-edit:", "evol": ":material-dna:", "preference": ":material-poll:", @@ -82,9 +82,6 @@ "scorer": ":octicons-number-16:", "embedding": ":material-vector-line:", "format": ":material-format-list-bulleted:", -} - -_STEPS_CATEGORY_TO_ICON = { "filtering": ":material-filter:", "save": ":material-content-save:", "load": ":material-file-download:", @@ -145,6 +142,9 @@ def on_files( self.file_paths["llms"] = self._generate_llms_pages( src_dir=src_dir, llms=components_info["llms"] ) + self.file_paths["embeddings"] = self._generate_embeddings_pages( + src_dir=src_dir, embeddings=components_info["embeddings"] + ) # Add the new files to the files collections for relative_file_path in [ @@ -152,6 +152,7 @@ def on_files( *self.file_paths["steps"], *self.file_paths["tasks"], *self.file_paths["llms"], + *self.file_paths["embeddings"], ]: file = File( path=relative_file_path, @@ -266,7 +267,7 @@ def _generate_tasks_pages(self, src_dir: Path, tasks: list) -> List[str]: docstring = task["docstring"] if docstring["icon"] == "" and docstring["categories"]: first_category = docstring["categories"][0] - docstring["icon"] = _TASKS_CATEGORY_TO_ICON.get(first_category, "") + docstring["icon"] = _STEPS_CATEGORY_TO_ICON.get(first_category, "") name = task["name"] @@ -339,6 +340,48 @@ def _generate_llms_pages(self, src_dir: Path, llms: list) -> List[str]: return paths + def _generate_embeddings_pages(self, src_dir: Path, embeddings: list) -> List[str]: + """Generates the files for the `Embeddings` subsection of the components gallery. + + Args: + src_dir: The path to the source directory. + embeddings: The list of `Embeddings` components. + + Returns: + The relative paths to the generated files. 
+        """
+
+        paths = ["components-gallery/embeddings/index.md"]
+        steps_gallery_page_path = src_dir / paths[0]
+        steps_gallery_page_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Create a detail page for each `Embeddings` model
+        for embeddings_model in embeddings:
+            content = _LLM_DETAIL_TEMPLATE.render(llm=embeddings_model)
+
+            llm_path = (
+                f"components-gallery/embeddings/{embeddings_model['name'].lower()}.md"
+            )
+            path = src_dir / llm_path
+            with open(path, "w") as f:
+                f.write(content)
+
+            paths.append(llm_path)
+
+        # Create the `components-gallery/embeddings/index.md` file
+        content = _COMPONENTS_LIST_TEMPLATE.render(
+            title="Embeddings Gallery",
+            description="",
+            components=embeddings,
+            component_group="embeddings",
+            default_icon=":material-vector-line:",
+        )
+
+        with open(steps_gallery_page_path, "w") as f:
+            f.write(content)
+
+        return paths
+
     def on_nav(
         self, nav: "Navigation", *, config: "MkDocsConfig", files: "Files"
     ) -> Union["Navigation", None]:
diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md
index eb2914b6a6..cc3e44aecf 100644
--- a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md
+++ b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md
@@ -31,4 +31,12 @@ hide:
 
     [:octicons-arrow-right-24: LLMs](llms/index.md){ .bottom }
 
+- :material-vector-line:{ .lg .middle } __Embeddings__
+
+    ---
+
+    Explore all the available `Embeddings` models integrated with `distilabel`.
+
+    [:octicons-arrow-right-24: Embeddings](embeddings/index.md){ .bottom }
+
 
diff --git a/tests/unit/embeddings/__init__.py b/tests/unit/embeddings/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/tests/unit/embeddings/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py
new file mode 100644
index 0000000000..2efeabb807
--- /dev/null
+++ b/tests/unit/embeddings/test_sentence_transformers.py
@@ -0,0 +1,42 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings + + +class TestSentenceTransformersEmbeddings: + def test_model_name(self) -> None: + embeddings = SentenceTransformerEmbeddings( + model="sentence-transformers/all-MiniLM-L6-v2" + ) + + assert embeddings.model_name == "sentence-transformers/all-MiniLM-L6-v2" + + def test_encode(self) -> None: + embeddings = SentenceTransformerEmbeddings( + model="sentence-transformers/all-MiniLM-L6-v2" + ) + + embeddings.load() + + results = embeddings.encode( + inputs=[ + "Hello, how are you?", + "What a nice day!", + "I hear that llamas are very popular now.", + ] + ) + + for result in results: + assert len(result) == 384 diff --git a/tests/unit/steps/embeddings/__init__.py b/tests/unit/steps/embeddings/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/steps/embeddings/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/steps/embeddings/test_embedding_generation.py b/tests/unit/steps/embeddings/test_embedding_generation.py new file mode 100644 index 0000000000..66284e0ed9 --- /dev/null +++ b/tests/unit/steps/embeddings/test_embedding_generation.py @@ -0,0 +1,51 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings +from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration + + +class TestEmbeddingGeneration: + def test_process(self) -> None: + step = EmbeddingGeneration( + embeddings=SentenceTransformerEmbeddings( + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) + + step.load() + + results = next( + step.process( + inputs=[ + {"text": "Hello, how are you?"}, + {"text": "What a nice day!"}, + {"text": "I hear that llamas are very popular now."}, + ] + ) + ) + + step.unload() + + for result, text in zip( + results, + [ + "Hello, how are you?", + "What a nice day!", + "I hear that llamas are very popular now.", + ], + ): + assert len(result["embedding"]) == 384 + assert result["text"] == text + assert result["model_name"] == "sentence-transformers/all-MiniLM-L6-v2" diff --git a/tests/unit/steps/embeddings/test_nearest_neighbour.py b/tests/unit/steps/embeddings/test_nearest_neighbour.py new file mode 100644 index 0000000000..e48cbd0897 --- /dev/null +++ b/tests/unit/steps/embeddings/test_nearest_neighbour.py @@ -0,0 +1,86 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour + + +class TestFaissNearestNeighbour: + def test_process(self) -> None: + step = FaissNearestNeighbour() + + step.load() + + results = next( + step.process( + inputs=[ + {"embedding": [0.1, -0.4, 0.7, 0.2]}, + {"embedding": [-0.3, 0.9, 0.1, -0.5]}, + {"embedding": [0.6, 0.2, -0.1, 0.8]}, + {"embedding": [-0.2, -0.6, 0.4, 0.3]}, + {"embedding": [0.9, 0.1, -0.3, -0.2]}, + {"embedding": [0.4, -0.7, 0.6, 0.1]}, + {"embedding": [-0.5, 0.3, -0.2, 0.9]}, + {"embedding": [0.7, 0.5, -0.4, -0.1]}, + {"embedding": [-0.1, -0.9, 0.8, 0.6]}, + ] + ) + ) + + assert results == [ + { + "embedding": [0.1, -0.4, 0.7, 0.2], + "nn_indices": [5], + "nn_scores": [0.19999998807907104], + }, + { + "embedding": [-0.3, 0.9, 0.1, -0.5], + "nn_indices": [7], + "nn_scores": [1.5699999332427979], + }, + { + "embedding": [0.6, 0.2, -0.1, 0.8], + "nn_indices": [7], + "nn_scores": [1.0000001192092896], + }, + { + "embedding": [-0.2, -0.6, 0.4, 0.3], + "nn_indices": [0], + "nn_scores": [0.23000000417232513], + }, + { + "embedding": [0.9, 0.1, -0.3, -0.2], + "nn_indices": [7], + "nn_scores": [0.2200000137090683], + }, + { + "embedding": [0.4, -0.7, 0.6, 0.1], + "nn_indices": [0], + "nn_scores": [0.19999998807907104], + }, + { + "embedding": [-0.5, 0.3, -0.2, 0.9], + "nn_indices": [2], + "nn_scores": [1.2400000095367432], + }, + { + "embedding": [0.7, 0.5, -0.4, -0.1], + "nn_indices": [4], + "nn_scores": [0.2200000137090683], + }, + { + "embedding": [-0.1, -0.9, 0.8, 0.6], + "nn_indices": [3], + "nn_scores": [0.3499999940395355], + }, + ]
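
Beyond the unit tests above, a quick way to sanity-check the new `Embeddings` base class is to implement the contract it documents: `load` (calling `super().load()` so the logger is created), the `model_name` property and `encode`. The snippet below is a minimal sketch and not part of the patch; `ConstantEmbeddings` and its `dimension` field are made-up names used purely for illustration:

```python
from typing import List, Union

from distilabel.embeddings.base import Embeddings


class ConstantEmbeddings(Embeddings):
    """Toy `Embeddings` subclass that returns a fixed vector for every input."""

    dimension: int = 4

    def load(self) -> None:
        # `super().load()` initializes `self._logger`, as required by the base class.
        super().load()

    @property
    def model_name(self) -> str:
        return "constant-embeddings"

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        # Return the same zero vector for every input text.
        return [[0.0] * self.dimension for _ in inputs]


embeddings = ConstantEmbeddings()
embeddings.load()
print(embeddings.encode(inputs=["hello", "world"]))
# [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```

Once imported, any such non-abstract subclass is also picked up by the new `_get_embeddings()` helper in `export_components_info.py`, which is how it would end up listed in the components gallery.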