Add Embeddings base class, SentenceTransformerEmbeddings class, `EmbeddingGeneration` and `FaissNearestNeighbour` steps (#830)

* Add `Embeddings` base class and `SentenceTransformerEmbeddings` class

* Add `EmbeddingGeneration` step

* Add `precision` attribute

* Add docstrings

* Add example to docstring

* Update component gallery to include `Embeddings` models

* Add `sentence-transformers` extra

* Add `FaissNearestNeighbour` step

* Add category and example

* Merge category to icons dictionaries

* Add missing unit tests

* Add `faiss-cpu` and `faiss-gpu` extras

* Update unit tests
gabrielmbmb authored Jul 26, 2024
1 parent 6ac0267 commit 04b86f5
Showing 20 changed files with 865 additions and 33 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -91,6 +91,9 @@ vllm = [
# `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]`
"setuptools",
]
sentence-transformers = ["sentence-transformers >= 3.0.0"]
faiss-cpu = ["faiss-cpu >= 1.8.0"]
faiss-gpu = ["faiss-gpu >= 1.7.2"]

[project.urls]
Documentation = "https://distilabel.argilla.io/"
2 changes: 1 addition & 1 deletion scripts/install_dependencies.sh
@@ -6,7 +6,7 @@ python_version=$(python -c "import sys; print(sys.version_info[:2])")

python -m pip install uv

uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor]"
uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu]"

if [ "${python_version}" != "(3, 12)" ]; then
uv pip install --system -e .[ray]
21 changes: 21 additions & 0 deletions src/distilabel/embeddings/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.embeddings.base import Embeddings
from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings

__all__ = [
"Embeddings",
"SentenceTransformerEmbeddings",
]
72 changes: 72 additions & 0 deletions src/distilabel/embeddings/base.py
@@ -0,0 +1,72 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import ABC, abstractmethod
from typing import List, Union

from pydantic import BaseModel, ConfigDict, PrivateAttr

from distilabel.mixins.runtime_parameters import RuntimeParametersMixin
from distilabel.utils.serialization import _Serializable


class Embeddings(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
"""Base class for `Embeddings` models.
To implement an `Embeddings` subclass, you need to subclass this class and implement:
- `load` method to load the `Embeddings` model. Don't forget to call `super().load()`,
so the `_logger` attribute is initialized.
- `model_name` property to return the model name used for the `Embeddings`.
- `encode` method to generate the sentence embeddings.
Attributes:
_logger: the logger to be used for the `Embeddings` model. It will be initialized
when the `load` method is called.
"""

model_config = ConfigDict(
arbitrary_types_allowed=True,
protected_namespaces=(),
validate_default=True,
validate_assignment=True,
extra="forbid",
)
_logger: Union[logging.Logger, None] = PrivateAttr(...)

def load(self) -> None:
"""Method to be called to initialize the `Embeddings`"""
self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")

def unload(self) -> None:
"""Method to be called to unload the `Embeddings` and release any resources."""
pass

@property
@abstractmethod
def model_name(self) -> str:
"""Returns the model name used for the `Embeddings`."""
pass

@abstractmethod
def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
"""Generates embeddings for the provided inputs.
Args:
inputs: a list of texts for which an embedding has to be generated.
Returns:
The generated embeddings.
"""
pass
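
To make the contract above concrete, here is a minimal, hypothetical `Embeddings` subclass (not part of this commit); the character-code pseudo-embeddings are just a stand-in for a real model:

```python
from typing import List, Union

from distilabel.embeddings.base import Embeddings


class DummyEmbeddings(Embeddings):
    """Toy subclass used only to illustrate the `Embeddings` contract."""

    dimension: int = 8  # illustrative pydantic field; `extra="forbid"` requires declaring it

    def load(self) -> None:
        # Always call `super().load()` so the `_logger` private attribute is initialized.
        super().load()

    @property
    def model_name(self) -> str:
        return "dummy-embeddings"

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        # Deterministic pseudo-embeddings built from character codes, one vector per input.
        return [
            [
                float(ord((text or " ")[i % max(len(text), 1)]))
                for i in range(self.dimension)
            ]
            for text in inputs
        ]


embeddings = DummyEmbeddings()
embeddings.load()
print(embeddings.encode(inputs=["hello", "world"]))  # two 8-dimensional vectors
```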
157 changes: 157 additions & 0 deletions src/distilabel/embeddings/sentence_transformers.py
@@ -0,0 +1,157 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

from pydantic import Field, PrivateAttr

from distilabel.embeddings.base import Embeddings
from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.mixins.runtime_parameters import RuntimeParameter

if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer


class SentenceTransformerEmbeddings(Embeddings, CudaDevicePlacementMixin):
"""`sentence-transformers` library implementation for embedding generation.
Attributes:
model: the model Hugging Face Hub repo id or a path to a directory containing the
model weights and configuration files.
device: the name of the device used to load the model e.g. "cuda", "mps", etc.
Defaults to `None`.
prompts: a dictionary containing prompts to be used with the model. Defaults to
`None`.
default_prompt_name: the default prompt (in `prompts`) that will be applied to the
inputs. If not provided, then no prompt will be used. Defaults to `None`.
trust_remote_code: whether to allow fetching and executing remote code fetched
from the repository in the Hub. Defaults to `False`.
revision: if `model` refers to a Hugging Face Hub repository, then the revision
(e.g. a branch name or a commit id) to use. Defaults to `"main"`.
token: the Hugging Face Hub token that will be used to authenticate to the Hugging
Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package
local configuration will be used. Defaults to `None`.
truncate_dim: the dimension to truncate the sentence embeddings. Defaults to `None`.
model_kwargs: extra kwargs that will be passed to the Hugging Face `transformers`
model class. Defaults to `None`.
tokenizer_kwargs: extra kwargs that will be passed to the Hugging Face `transformers`
tokenizer class. Defaults to `None`.
config_kwargs: extra kwargs that will be passed to the Hugging Face `transformers`
configuration class. Defaults to `None`.
precision: the dtype that the resulting embeddings will have. Defaults to `"float32"`.
normalize_embeddings: whether to normalize the embeddings so they have a length
of 1. Defaults to `True`.
Examples:
Generating sentence embeddings:
```python
from distilabel.embeddings import SentenceTransformerEmbeddings
embeddings = SentenceTransformerEmbeddings(model="mixedbread-ai/mxbai-embed-large-v1")
embeddings.load()
results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
# [
# [-0.05447685346007347, -0.01623094454407692, ...],
# [4.4889533455716446e-05, 0.044016145169734955, ...],
# ]
```
"""

model: str
device: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The device to be used to load the model. If `None`, then it"
" will check if a GPU can be used.",
)
prompts: Optional[Dict[str, str]] = None
default_prompt_name: Optional[str] = None
trust_remote_code: bool = False
revision: Optional[str] = None
token: Optional[str] = None
truncate_dim: Optional[int] = None
model_kwargs: Optional[Dict[str, Any]] = None
tokenizer_kwargs: Optional[Dict[str, Any]] = None
config_kwargs: Optional[Dict[str, Any]] = None
precision: Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]] = (
"float32"
)
normalize_embeddings: RuntimeParameter[bool] = Field(
default=True,
description="Whether to normalize the embeddings so the generated vectors"
" have a length of 1 or not.",
)

_model: Union["SentenceTransformer", None] = PrivateAttr(None)

def load(self) -> None:
"""Loads the Sentence Transformer model"""
super().load()

if self.device == "cuda":
CudaDevicePlacementMixin.load(self)

try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise ImportError(
"`sentence-transformers` package is not installed. Please install it using"
" `pip install sentence-transformers`."
) from e

self._model = SentenceTransformer(
model_name_or_path=self.model,
device=self.device,
prompts=self.prompts,
default_prompt_name=self.default_prompt_name,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
token=self.token,
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
)

@property
def model_name(self) -> str:
"""Returns the name of the model."""
return self.model

def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
"""Generates embeddings for the provided inputs.
Args:
inputs: a list of texts for which an embedding has to be generated.
Returns:
The generated embeddings.
"""
return self._model.encode( # type: ignore
sentences=inputs,
batch_size=len(inputs),
convert_to_numpy=True,
precision=self.precision, # type: ignore
normalize_embeddings=self.normalize_embeddings, # type: ignore
).tolist() # type: ignore

def unload(self) -> None:
del self._model
if self.device == "cuda":
CudaDevicePlacementMixin.unload(self)
super().unload()
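
Building on the docstring example above, a small usage sketch of the new `precision` and `normalize_embeddings` attributes (assuming the `mixedbread-ai/mxbai-embed-large-v1` model is reachable and that a CUDA device exists for the `device="cuda"` branch):

```python
from distilabel.embeddings import SentenceTransformerEmbeddings

embeddings = SentenceTransformerEmbeddings(
    model="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",              # triggers the `CudaDevicePlacementMixin` logic in `load`
    precision="int8",           # quantized embeddings instead of the default "float32"
    normalize_embeddings=True,  # unit-length vectors, useful for cosine/inner-product search
)
embeddings.load()

vectors = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
print(len(vectors), len(vectors[0]))  # 2 vectors, one int8 embedding per input

embeddings.unload()
```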
3 changes: 2 additions & 1 deletion src/distilabel/llms/base.py
@@ -85,7 +85,8 @@ class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
_logger: Union[logging.Logger, None] = PrivateAttr(...)

def load(self) -> None:
"""Method to be called to initialize the `LLM`, its logger and optionally the structured output generator."""
"""Method to be called to initialize the `LLM`, its logger and optionally the
structured output generator."""
self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")

def unload(self) -> None:
2 changes: 2 additions & 0 deletions src/distilabel/llms/huggingface/transformers.py
@@ -44,6 +44,8 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
(e.g. a branch name or a commit id) to use. Defaults to `"main"`.
torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc.
Defaults to `"auto"`.
trust_remote_code: whether to allow fetching and executing remote code fetched
from the repository in the Hub. Defaults to `False`.
trust_remote_code: whether to trust or not remote (code in the Hugging Face Hub
repository) code to load the model. Defaults to `False`.
model_kwargs: additional dictionary of keyword arguments that will be passed to
4 changes: 4 additions & 0 deletions src/distilabel/steps/__init__.py
@@ -27,6 +27,8 @@
from distilabel.steps.columns.merge import MergeColumns
from distilabel.steps.decorator import step
from distilabel.steps.deita import DeitaFiltering
from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration
from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour
from distilabel.steps.formatting.conversation import ConversationTemplate
from distilabel.steps.formatting.dpo import (
FormatChatGenerationDPO,
@@ -54,6 +56,8 @@
"CombineColumns",
"ConversationTemplate",
"DeitaFiltering",
"EmbeddingGeneration",
"FaissNearestNeighbour",
"ExpandColumns",
"FormatChatGenerationDPO",
"FormatChatGenerationSFT",
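
The two new steps are only imported and exported here (their implementations load further down the page), but a rough end-to-end sketch of how they are meant to compose could look as follows. Treat it as an assumption-laden illustration: the `embeddings=` argument of `EmbeddingGeneration`, the default `text` input / `embedding` output columns, and the `k` parameter of `FaissNearestNeighbour` are inferred from the commit description, not from code shown above.

```python
from distilabel.embeddings import SentenceTransformerEmbeddings
from distilabel.pipeline import Pipeline
from distilabel.steps import EmbeddingGeneration, FaissNearestNeighbour, LoadDataFromDicts

with Pipeline(name="embed-and-index") as pipeline:
    load_data = LoadDataFromDicts(
        data=[
            {"text": "distilabel is a framework for synthetic data generation."},
            {"text": "Argilla is a collaboration platform for AI datasets."},
            {"text": "FAISS builds indexes for fast similarity search."},
        ]
    )
    embed = EmbeddingGeneration(
        embeddings=SentenceTransformerEmbeddings(
            model="mixedbread-ai/mxbai-embed-large-v1"
        ),
    )
    neighbours = FaissNearestNeighbour(k=1)  # assumed runtime parameter for the number of neighbours

    load_data >> embed >> neighbours

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
```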
14 changes: 14 additions & 0 deletions src/distilabel/steps/embeddings/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
