Skip to content

Commit

Permalink
refactor: change modality to video
Browse files Browse the repository at this point in the history
Signed-off-by: Snehil Shah <[email protected]>
  • Loading branch information
Snehil-Shah committed Aug 27, 2024
1 parent 5756a59 commit a1506ba
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/core/operators/cluster_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def run(input_data, n_clusters=None, modality='audio'):
dict: A dictionary mapping cluster labels to corresponding array of payloads
Raises:
ValueError: Modality should be either `audio` or `visual`
ValueError: Modality should be either `audio` or `video`
KeyError: Each data point in input must have `embedding` and `payload` properties
"""
# Parse data:
Expand All @@ -103,10 +103,10 @@ def run(input_data, n_clusters=None, modality='audio'):
n_clusters = int(n_clusters) # cast it to int
if modality == 'audio':
labels = KMeans_clustering(matrix=matrix, n_clusters=n_clusters)
elif modality == 'visual':
elif modality == 'video':
labels = Agglomerative_clustering(matrix=matrix, n_clusters=n_clusters)
else:
raise ValueError("Invalid modality. Modality should be either `audio` or `visual`.")
raise ValueError("Invalid modality. Modality should be either `audio` or `video`.")
else:
labels = AffinityPropagation_clustering(matrix=matrix)
return gen_data(payloads=payloads, labels=labels) # format output
2 changes: 1 addition & 1 deletion src/core/operators/test_cluster_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_kmeans_clustering(self):
self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS)

def test_agglomerative_clustering(self):
result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=2, modality="visual")
result = cluster_embeddings.run(input_data=MOCK_DATA, n_clusters=2, modality="video")
self.assertIn("cluster_0", result)
self.assertIn("cluster_1", result)
self.assertEqual(len(result), 2)
Expand Down

0 comments on commit a1506ba

Please sign in to comment.