diff --git a/src/distilabel/steps/deita.py b/src/distilabel/steps/deita.py index 0278ff8418..67ac7e4252 100644 --- a/src/distilabel/steps/deita.py +++ b/src/distilabel/steps/deita.py @@ -12,17 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Type +from typing import TYPE_CHECKING, List, Literal import numpy as np -from pydantic import Field, PrivateAttr +from pydantic import Field from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.base import GlobalStep, StepInput if TYPE_CHECKING: - from sklearn.neighbors import NearestNeighbors - from distilabel.steps.typing import StepOutput @@ -73,18 +71,10 @@ class DeitaFiltering(GlobalStep): default=True, description="Whether to normalize the embeddings before computing the cosine distance.", ) - - _NearestNeighbors: Type["NearestNeighbors"] = PrivateAttr(...) - - def load(self) -> None: - try: - from sklearn.neighbors import NearestNeighbors - except ImportError as ie: - raise ImportError( - "`scikit-learn` is not installed. Please install it using `pip install huggingface-hub`." - ) from ie - - self._NearestNeighbors = NearestNeighbors + distance_metric: RuntimeParameter[Literal["cosine", "manhattan"]] = Field( + default="cosine", + description="The distance metric to use. Currently only 'cosine' is supported.", + ) @property def inputs(self) -> List[str]: @@ -94,6 +84,28 @@ def inputs(self) -> List[str]: def outputs(self) -> List[str]: return ["deita_score", "nearest_neighbor_distance"] + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + """Filter the dataset based on the DEITA score and the cosine distance between the + embeddings. + + Args: + inputs: The input data. + + Returns: + The filtered dataset. + """ + inputs = self._compute_deita_score(inputs) + inputs = self._compute_nearest_neighbor(inputs) + inputs.sort(key=lambda x: x["deita_score"], reverse=True) + + selected_rows = [] + for input in inputs: + if len(selected_rows) >= self.data_budget: # type: ignore + break + if input["nearest_neighbor_distance"] >= self.diversity_threshold: + selected_rows.append(input) + yield selected_rows + def _compute_deita_score(self, inputs: StepInput) -> StepInput: """Computes the DEITA score for each instruction-response pair. The DEITA score is the product of the instruction score and the response score. @@ -142,19 +154,23 @@ def _compute_nearest_neighbor(self, inputs: StepInput) -> StepInput: Returns: The input data with the cosine distance computed. """ - embeddings = [input["embedding"] for input in inputs] + embeddings = np.array([input["embedding"] for input in inputs]) if self.normalize_embeddings: embeddings = self._normalize_embeddings(embeddings) self._logger.info("📏 Computing nearest neighbor distance...") - nn = self._NearestNeighbors( - n_neighbors=2, metric="cosine", algorithm="brute" - ).fit(embeddings) - distances, _ = nn.kneighbors(embeddings, return_distance=True) + + if self.distance_metric == "cosine": + self._logger.info("📏 Using cosine distance.") + distances = self._cosine_distance(embeddings) + else: + self._logger.info("📏 Using manhattan distance.") + distances = self._manhattan_distance(embeddings) + for distance, input in zip(distances, inputs): - input["nearest_neighbor_distance"] = distance[-1] + input["nearest_neighbor_distance"] = distance return inputs - def _normalize_embeddings(self, embeddings: List[np.ndarray]) -> List[np.ndarray]: + def _normalize_embeddings(self, embeddings: np.ndarray) -> np.ndarray: """Normalize the embeddings. Args: @@ -164,29 +180,34 @@ def _normalize_embeddings(self, embeddings: List[np.ndarray]) -> List[np.ndarray The normalized embeddings. """ self._logger.info("⚖️ Normalizing embeddings...") - for i, embedding in enumerate(embeddings): - norm = np.linalg.norm(embedding) - embeddings[i] = embedding / norm - return embeddings + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + return embeddings / norms - def process(self, inputs: StepInput) -> "StepOutput": # type: ignore - """Filter the dataset based on the DEITA score and the cosine distance between the - embeddings. + def _cosine_distance(self, embeddings: np.array) -> np.array: + """Computes the cosine distance between the embeddings. Args: - inputs: The input data. + embeddings: The embeddings. Returns: - The filtered dataset. + The cosine distance between the embeddings. """ - inputs = self._compute_deita_score(inputs) - inputs = self._compute_nearest_neighbor(inputs) - inputs.sort(key=lambda x: x["deita_score"], reverse=True) + cosine_similarity = np.dot(embeddings, embeddings.T) + cosine_distance = 1 - cosine_similarity + # Ignore self-distance + np.fill_diagonal(cosine_distance, np.inf) + return np.min(cosine_distance, axis=1) - selected_rows = [] - for input in inputs: - if len(selected_rows) >= self.data_budget: # type: ignore - break - if input["nearest_neighbor_distance"] >= self.diversity_threshold: - selected_rows.append(input) - yield selected_rows + def _manhattan_distance(self, embeddings: np.array) -> np.array: + """Computes the manhattan distance between the embeddings. + + Args: + embeddings: The embeddings. + + Returns: + The manhattan distance between the embeddings. + """ + manhattan_distance = np.abs(embeddings[:, None] - embeddings).sum(-1) + # Ignore self-distance + np.fill_diagonal(manhattan_distance, np.inf) + return np.min(manhattan_distance, axis=1)