Remove scikit-learn dep
gabrielmbmb committed Mar 27, 2024
1 parent a00c616 commit a84199b
Showing 1 changed file: src/distilabel/steps/deita.py (63 additions, 42 deletions)
@@ -12,17 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List, Type
+from typing import TYPE_CHECKING, List, Literal
 
 import numpy as np
-from pydantic import Field, PrivateAttr
+from pydantic import Field
 
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.steps.base import GlobalStep, StepInput
 
 if TYPE_CHECKING:
-    from sklearn.neighbors import NearestNeighbors
-
     from distilabel.steps.typing import StepOutput
 
 
@@ -73,18 +71,10 @@ class DeitaFiltering(GlobalStep):
         default=True,
         description="Whether to normalize the embeddings before computing the cosine distance.",
     )
-
-    _NearestNeighbors: Type["NearestNeighbors"] = PrivateAttr(...)
-
-    def load(self) -> None:
-        try:
-            from sklearn.neighbors import NearestNeighbors
-        except ImportError as ie:
-            raise ImportError(
-                "`scikit-learn` is not installed. Please install it using `pip install scikit-learn`."
-            ) from ie
-
-        self._NearestNeighbors = NearestNeighbors
+    distance_metric: RuntimeParameter[Literal["cosine", "manhattan"]] = Field(
+        default="cosine",
+        description="The distance metric to use. One of 'cosine' or 'manhattan'.",
+    )
 
     @property
     def inputs(self) -> List[str]:
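
For reference, this hunk turns the hard scikit-learn dependency into a plain runtime parameter. A minimal usage sketch, under the assumption that `DeitaFiltering` is instantiated inside a `Pipeline` context as distilabel steps usually are, and with purely illustrative values for `data_budget` and `diversity_threshold` (both fields are defined elsewhere in this file):

```python
from distilabel.pipeline import Pipeline
from distilabel.steps.deita import DeitaFiltering

# Hypothetical values, for illustration only.
with Pipeline(name="deita-filtering") as pipeline:
    deita = DeitaFiltering(
        name="deita_filtering",
        data_budget=500,              # keep at most 500 rows
        diversity_threshold=0.9,      # minimum nearest-neighbor distance to keep a row
        normalize_embeddings=True,
        distance_metric="manhattan",  # new in this commit; defaults to "cosine"
    )
```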
@@ -94,6 +84,28 @@ def inputs(self) -> List[str]:
     def outputs(self) -> List[str]:
         return ["deita_score", "nearest_neighbor_distance"]
 
+    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
+        """Filter the dataset based on the DEITA score and the cosine distance between the
+        embeddings.
+
+        Args:
+            inputs: The input data.
+
+        Returns:
+            The filtered dataset.
+        """
+        inputs = self._compute_deita_score(inputs)
+        inputs = self._compute_nearest_neighbor(inputs)
+        inputs.sort(key=lambda x: x["deita_score"], reverse=True)
+
+        selected_rows = []
+        for input in inputs:
+            if len(selected_rows) >= self.data_budget:  # type: ignore
+                break
+            if input["nearest_neighbor_distance"] >= self.diversity_threshold:
+                selected_rows.append(input)
+        yield selected_rows
+
     def _compute_deita_score(self, inputs: StepInput) -> StepInput:
         """Computes the DEITA score for each instruction-response pair. The DEITA score is
         the product of the instruction score and the response score.
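
The relocated `process` method above implements the whole filter: score, measure diversity once over the full dataset, sort, then greedily keep rows. A self-contained sketch of that selection logic, assuming `deita_score` and `nearest_neighbor_distance` have already been computed for every row:

```python
from typing import Any, Dict, List

def deita_select(
    rows: List[Dict[str, Any]], data_budget: int, diversity_threshold: float
) -> List[Dict[str, Any]]:
    """Greedy selection as in `process` above: consider rows from highest to
    lowest DEITA score, keep a row only if its precomputed nearest-neighbor
    distance clears the threshold, and stop once the budget is reached."""
    ranked = sorted(rows, key=lambda r: r["deita_score"], reverse=True)
    selected: List[Dict[str, Any]] = []
    for row in ranked:
        if len(selected) >= data_budget:
            break
        if row["nearest_neighbor_distance"] >= diversity_threshold:
            selected.append(row)
    return selected
```

Note this is a one-shot filter: `nearest_neighbor_distance` is computed once over the full dataset and is not re-evaluated against the rows selected so far.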
@@ -142,19 +154,23 @@ def _compute_nearest_neighbor(self, inputs: StepInput) -> StepInput:
         Returns:
             The input data with the cosine distance computed.
         """
-        embeddings = [input["embedding"] for input in inputs]
+        embeddings = np.array([input["embedding"] for input in inputs])
         if self.normalize_embeddings:
             embeddings = self._normalize_embeddings(embeddings)
         self._logger.info("📏 Computing nearest neighbor distance...")
-        nn = self._NearestNeighbors(
-            n_neighbors=2, metric="cosine", algorithm="brute"
-        ).fit(embeddings)
-        distances, _ = nn.kneighbors(embeddings, return_distance=True)
+
+        if self.distance_metric == "cosine":
+            self._logger.info("📏 Using cosine distance.")
+            distances = self._cosine_distance(embeddings)
+        else:
+            self._logger.info("📏 Using manhattan distance.")
+            distances = self._manhattan_distance(embeddings)
+
         for distance, input in zip(distances, inputs):
-            input["nearest_neighbor_distance"] = distance[-1]
+            input["nearest_neighbor_distance"] = distance
         return inputs
 
-    def _normalize_embeddings(self, embeddings: List[np.ndarray]) -> List[np.ndarray]:
+    def _normalize_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
         """Normalize the embeddings.
 
         Args:
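
The hunk above is the heart of the dependency removal: scikit-learn's `NearestNeighbors(n_neighbors=2, metric="cosine", algorithm="brute")`, whose second returned neighbor was the nearest non-self point, is replaced by a dense pairwise distance matrix with the diagonal masked out. A toy numpy-only sketch of the equivalence, assuming rows are L2-normalized as `_normalize_embeddings` guarantees:

```python
import numpy as np

rng = np.random.default_rng(seed=0)
emb = rng.normal(size=(4, 8))                      # 4 toy embeddings
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit rows: dot product == cosine similarity

dist = 1.0 - emb @ emb.T        # (4, 4) pairwise cosine distances
np.fill_diagonal(dist, np.inf)  # mask self-distance, as in _cosine_distance
nearest = dist.min(axis=1)      # per row: distance to the closest other row
print(nearest)                  # same values the old kneighbors distances[:, -1] produced
```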
@@ -164,29 +180,34 @@ def _normalize_embeddings(self, embeddings: List[np.ndarray]) -> List[np.ndarray]:
         The normalized embeddings.
         """
         self._logger.info("⚖️ Normalizing embeddings...")
-        for i, embedding in enumerate(embeddings):
-            norm = np.linalg.norm(embedding)
-            embeddings[i] = embedding / norm
-        return embeddings
+        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+        return embeddings / norms
 
-    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
-        """Filter the dataset based on the DEITA score and the cosine distance between the
-        embeddings.
+    def _cosine_distance(self, embeddings: np.array) -> np.array:
+        """Computes the cosine distance between the embeddings.
 
         Args:
-            inputs: The input data.
+            embeddings: The embeddings.
 
         Returns:
-            The filtered dataset.
+            The cosine distance between the embeddings.
         """
-        inputs = self._compute_deita_score(inputs)
-        inputs = self._compute_nearest_neighbor(inputs)
-        inputs.sort(key=lambda x: x["deita_score"], reverse=True)
+        cosine_similarity = np.dot(embeddings, embeddings.T)
+        cosine_distance = 1 - cosine_similarity
+        # Ignore self-distance
+        np.fill_diagonal(cosine_distance, np.inf)
+        return np.min(cosine_distance, axis=1)
 
-        selected_rows = []
-        for input in inputs:
-            if len(selected_rows) >= self.data_budget:  # type: ignore
-                break
-            if input["nearest_neighbor_distance"] >= self.diversity_threshold:
-                selected_rows.append(input)
-        yield selected_rows
+    def _manhattan_distance(self, embeddings: np.array) -> np.array:
+        """Computes the manhattan distance between the embeddings.
+
+        Args:
+            embeddings: The embeddings.
+
+        Returns:
+            The manhattan distance between the embeddings.
+        """
+        manhattan_distance = np.abs(embeddings[:, None] - embeddings).sum(-1)
+        # Ignore self-distance
+        np.fill_diagonal(manhattan_distance, np.inf)
+        return np.min(manhattan_distance, axis=1)
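
The `embeddings[:, None] - embeddings` broadcast in `_manhattan_distance` builds an `(n, n, d)` array of all pairwise coordinate differences, which `np.abs(...).sum(-1)` collapses into an `(n, n)` L1 distance matrix. A small worked example:

```python
import numpy as np

emb = np.array([[0.0, 0.0],
                [1.0, 1.0],
                [3.0, 0.0]])

diffs = emb[:, None] - emb    # shape (3, 3, 2): pairwise coordinate differences
l1 = np.abs(diffs).sum(-1)    # shape (3, 3): pairwise L1 (manhattan) distances
np.fill_diagonal(l1, np.inf)  # ignore self-distance, as in the method above
print(l1.min(axis=1))         # [2. 2. 3.] -> each row's nearest-neighbor distance
```

Both distance helpers materialize the full pairwise matrix (O(n²) memory, plus the (n, n, d) intermediate for manhattan), trading memory for dropping the scikit-learn brute-force search; very large datasets would call for a chunked computation.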
