diff --git a/src/core/operators/cluster_embeddings.py b/src/core/operators/cluster_embeddings.py index 115256f..9f0d7de 100644 --- a/src/core/operators/cluster_embeddings.py +++ b/src/core/operators/cluster_embeddings.py @@ -19,23 +19,23 @@ def initialize(param): # Constants RANDOM_STATE = 50 - def gen_data(labels, input_data): + def gen_data(payloads, labels): """ Generates formatted output data. Args: + payloads (list): List of payloads labels (np.ndarray): An array of cluster labels - input_data (list[dict]): Operator input Returns: dict: A dictionary mapping cluster labels to corresponding array of payloads """ out = {} - for label, item in zip(labels, input_data): + for label, payload in zip(labels, payloads): key = f'cluster_{label}' if key not in out: out[key] = [] - out[key].append(item['payload']) + out[key].append(payload) return out def KMeans_clustering(matrix, n_clusters): @@ -76,12 +76,12 @@ def AffinityPropagation_clustering(matrix): """ return AffinityPropagation(random_state=RANDOM_STATE).fit_predict(np.array(matrix)) -def run(embeddings, n_clusters=None, modality='audio'): +def run(input_data, n_clusters=None, modality='audio'): """ Runs the operator. Args: - embeddings (list[dict]): List of data with each dictionary containing `embedding` and `payload` properties + input_data (list[dict]): List of data with each dictionary containing `embedding` and `payload` properties n_clusters (int, optional): Number of clusters. Defaults to None modality (str, optional): Source modality of embeddings. Defaults to 'audio' @@ -89,16 +89,24 @@ def run(embeddings, n_clusters=None, modality='audio'): dict: A dictionary mapping cluster labels to corresponding array of payloads Raises: - ValueError: If invalid modality is provided + ValueError: Modality should be either `audio` or `visual` + KeyError: Each data point in input must have `embedding` and `payload` properties """ - matrix = [data['embedding'] for data in embeddings] # isolating list of embeddings + # Parse data: + try: + matrix, payloads = zip(*[(data['embedding'], data['payload']) for data in input_data]) + except KeyError as e: + raise KeyError(f"Invalid data. Each data point in input must have `embedding` and `payload` properties. Missing key: {e}.") + + # Delegate appropriate clustering algorithm for the given params: if n_clusters: + n_clusters = int(n_clusters) # cast it to int if modality == 'audio': labels = KMeans_clustering(matrix=matrix, n_clusters=n_clusters) elif modality == 'visual': labels = Agglomerative_clustering(matrix=matrix, n_clusters=n_clusters) else: - raise ValueError("Invalid modality. Modality can only be `audio` or `visual`.") + raise ValueError("Invalid modality. Modality should be either `audio` or `visual`.") else: labels = AffinityPropagation_clustering(matrix=matrix) - return gen_data(labels=labels, input_data=embeddings) + return gen_data(payloads=payloads, labels=labels) # format output diff --git a/src/core/operators/test_cluster_embeddings.py b/src/core/operators/test_cluster_embeddings.py index 6be9c6a..c371eb7 100644 --- a/src/core/operators/test_cluster_embeddings.py +++ b/src/core/operators/test_cluster_embeddings.py @@ -1,8 +1,8 @@ import unittest -from src.core.operators import cluster_embeddings +from core.operators import cluster_embeddings # Test constants -EMBEDDINGS = [ +MOCK_DATA = [ {"payload": "A", "embedding": [0, 1]}, {"payload": "B", "embedding": [1, 0]}, {"payload": "C", "embedding": [100, 101]}, @@ -23,21 +23,21 @@ def tearDownClass(cls): pass def test_kmeans_clustering(self): - result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="audio") + result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=2, modality="audio") self.assertIn("cluster_0", result) self.assertIn("cluster_1", result) self.assertEqual(len(result), 2) self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS) def test_agglomerative_clustering(self): - result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="visual") + result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=2, modality="visual") self.assertIn("cluster_0", result) self.assertIn("cluster_1", result) self.assertEqual(len(result), 2) self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS) def test_affinity_propagation(self): - result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=None, modality="audio") + result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=None, modality="audio") self.assertIn("cluster_0", result) self.assertIn("cluster_1", result) self.assertEqual(len(result), 2)