Skip to content

Commit

Permalink
feat: improve exception handling and logic
Browse files Browse the repository at this point in the history
Co-authored-by: Chaithanya512 <[email protected]>
Signed-off-by: Snehil Shah <[email protected]>
  • Loading branch information
Snehil-Shah and Chaithanya512 committed Aug 26, 2024
1 parent 4e08e35 commit 5756a59
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
28 changes: 18 additions & 10 deletions src/core/operators/cluster_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@ def initialize(param):
# Constants
RANDOM_STATE = 50

def gen_data(labels, input_data):
def gen_data(payloads, labels):
"""
Generates formatted output data.
Args:
payloads (list): List of payloads
labels (np.ndarray): An array of cluster labels
input_data (list[dict]): Operator input
Returns:
dict: A dictionary mapping cluster labels to corresponding array of payloads
"""
out = {}
for label, item in zip(labels, input_data):
for label, payload in zip(labels, payloads):
key = f'cluster_{label}'
if key not in out:
out[key] = []
out[key].append(item['payload'])
out[key].append(payload)
return out

def KMeans_clustering(matrix, n_clusters):
Expand Down Expand Up @@ -76,29 +76,37 @@ def AffinityPropagation_clustering(matrix):
"""
return AffinityPropagation(random_state=RANDOM_STATE).fit_predict(np.array(matrix))

def run(embeddings, n_clusters=None, modality='audio'):
def run(input_data, n_clusters=None, modality='audio'):
"""
Runs the operator.
Args:
embeddings (list[dict]): List of data with each dictionary containing `embedding` and `payload` properties
input_data (list[dict]): List of data with each dictionary containing `embedding` and `payload` properties
n_clusters (int, optional): Number of clusters. Defaults to None
modality (str, optional): Source modality of embeddings. Defaults to 'audio'
Returns:
dict: A dictionary mapping cluster labels to corresponding array of payloads
Raises:
ValueError: If invalid modality is provided
ValueError: Modality should be either `audio` or `visual`
KeyError: Each data point in input must have `embedding` and `payload` properties
"""
matrix = [data['embedding'] for data in embeddings] # isolating list of embeddings
# Parse data:
try:
matrix, payloads = zip(*[(data['embedding'], data['payload']) for data in input_data])
except KeyError as e:
raise KeyError(f"Invalid data. Each data point in input must have `embedding` and `payload` properties. Missing key: {e}.")

# Delegate appropriate clustering algorithm for the given params:
if n_clusters:
n_clusters = int(n_clusters) # cast it to int
if modality == 'audio':
labels = KMeans_clustering(matrix=matrix, n_clusters=n_clusters)
elif modality == 'visual':
labels = Agglomerative_clustering(matrix=matrix, n_clusters=n_clusters)
else:
raise ValueError("Invalid modality. Modality can only be `audio` or `visual`.")
raise ValueError("Invalid modality. Modality should be either `audio` or `visual`.")
else:
labels = AffinityPropagation_clustering(matrix=matrix)
return gen_data(labels=labels, input_data=embeddings)
return gen_data(payloads=payloads, labels=labels) # format output
10 changes: 5 additions & 5 deletions src/core/operators/test_cluster_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import unittest
from src.core.operators import cluster_embeddings
from core.operators import cluster_embeddings

# Test constants
EMBEDDINGS = [
MOCK_DATA = [
{"payload": "A", "embedding": [0, 1]},
{"payload": "B", "embedding": [1, 0]},
{"payload": "C", "embedding": [100, 101]},
Expand All @@ -23,21 +23,21 @@ def tearDownClass(cls):
pass

def test_kmeans_clustering(self):
result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="audio")
result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=2, modality="audio")
self.assertIn("cluster_0", result)
self.assertIn("cluster_1", result)
self.assertEqual(len(result), 2)
self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS)

def test_agglomerative_clustering(self):
result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="visual")
result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=2, modality="visual")
self.assertIn("cluster_0", result)
self.assertIn("cluster_1", result)
self.assertEqual(len(result), 2)
self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS)

def test_affinity_propagation(self):
result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=None, modality="audio")
result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=None, modality="audio")
self.assertIn("cluster_0", result)
self.assertIn("cluster_1", result)
self.assertEqual(len(result), 2)
Expand Down

0 comments on commit 5756a59

Please sign in to comment.