Skip to content

Commit

Permalink
feat: add cluster_embeddings operator
Browse files Browse the repository at this point in the history
Signed-off-by: Snehil Shah <[email protected]>
  • Loading branch information
Snehil-Shah committed Aug 25, 2024
1 parent faa9727 commit 3f04b83
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 0 deletions.
98 changes: 98 additions & 0 deletions src/core/operators/cluster_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Operator to cluster embeddings using KMeans, Affinity Propagation, and Agglomerative clustering algorithms
"""

def initialize(param):
"""
Initializes the operator.
Args:
param (dict): Parameters for initialization
"""
global KMeans_clustering, Agglomerative_clustering, AffinityPropagation_clustering
global gen_data

# Imports
from sklearn.cluster import KMeans, AffinityPropagation, AgglomerativeClustering
import numpy as np

# Constants
RANDOM_STATE = 50

def gen_data(labels, input_data):
"""
Generates formatted output data.
Args:
labels (np.ndarray): An array of cluster labels
input_data (list[dict]): Operator input
Returns:
dict: A dictionary mapping cluster labels to corresponding array of payloads
"""
out = {}
for label, item in zip(labels, input_data):
key = f'cluster_{label}'
if key not in out:
out[key] = []
out[key].append(item['payload'])
return out

def KMeans_clustering(matrix, n_clusters):
"""
Clusters given embeddings using KMeans clustering algorithm.
Args:
matrix (list[list]): list of embeddings
n_clusters (int): number of clusters
Returns:
numpy.ndarray: An array of cluster labels for each embedding
"""
return KMeans(n_clusters=n_clusters, random_state=RANDOM_STATE).fit_predict(np.array(matrix))

def Agglomerative_clustering(matrix, n_clusters):
"""
Clusters given embeddings using Agglomerative clustering algorithm.
Args:
matrix (list[list]): list of embeddings
n_clusters (int): number of clusters
Returns:
numpy.ndarray: An array of cluster labels for each embedding
"""
return AgglomerativeClustering(n_clusters=n_clusters).fit_predict(np.array(matrix))

def AffinityPropagation_clustering(matrix):
"""
Clusters given embeddings using Affinity Propagation algorithm (used if the number of clusters is unknown).
Args:
matrix (list[list]): list of embeddings
Returns:
numpy.ndarray: An array of cluster labels for each embedding
"""
return AffinityPropagation(random_state=RANDOM_STATE).fit_predict(np.array(matrix))

def run(embeddings, n_clusters=None, modality='audio'):
"""
Runs the operator.
Args:
embeddings (list[dict]): List of data with each dictionary containing `embedding` and `payload` properties
n_clusters (int, optional): Number of clusters. Defaults to None.
modality (str, optional): Source modality of embeddings. Defaults to 'audio'.
"""
matrix = [data['embedding'] for data in embeddings] # isolating list of embeddings
if n_clusters:
if modality == 'audio':
labels = KMeans_clustering(matrix=matrix, n_clusters=n_clusters)
elif modality == 'visual':
labels = Agglomerative_clustering(matrix=matrix, n_clusters=n_clusters)
else:
raise ValueError("Invalid modality. Modality can only be `audio` or `visual`.")
else:
labels = AffinityPropagation_clustering(matrix=matrix)
return gen_data(labels=labels, input_data=embeddings)
2 changes: 2 additions & 0 deletions src/core/operators/cluster_embeddings_requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
scikit-learn==1.5.1
numpy==2.1.0
129 changes: 129 additions & 0 deletions src/core/operators/cluster_embeddings_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --allow-unsafe --generate-hashes cluster_embeddings_requirements.in
#
joblib==1.4.2 \
--hash=sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6 \
--hash=sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e
# via scikit-learn
numpy==2.1.0 \
--hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \
--hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \
--hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \
--hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \
--hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \
--hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \
--hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \
--hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \
--hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \
--hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \
--hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \
--hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \
--hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \
--hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \
--hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \
--hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \
--hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \
--hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \
--hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \
--hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \
--hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \
--hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \
--hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \
--hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \
--hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \
--hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \
--hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \
--hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \
--hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \
--hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \
--hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \
--hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \
--hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \
--hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \
--hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \
--hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \
--hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \
--hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \
--hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \
--hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \
--hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \
--hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \
--hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \
--hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \
--hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \
--hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \
--hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \
--hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \
--hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \
--hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \
--hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \
--hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e
# via
# -r cluster_embeddings_requirements.in
# scikit-learn
# scipy
scikit-learn==1.5.1 \
--hash=sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1 \
--hash=sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414 \
--hash=sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1 \
--hash=sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2 \
--hash=sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21 \
--hash=sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d \
--hash=sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac \
--hash=sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915 \
--hash=sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395 \
--hash=sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d \
--hash=sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf \
--hash=sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745 \
--hash=sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855 \
--hash=sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956 \
--hash=sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4 \
--hash=sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b \
--hash=sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b \
--hash=sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1 \
--hash=sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe \
--hash=sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74 \
--hash=sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7
# via -r cluster_embeddings_requirements.in
scipy==1.14.1 \
--hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \
--hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \
--hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \
--hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \
--hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \
--hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \
--hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \
--hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \
--hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \
--hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \
--hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \
--hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \
--hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \
--hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \
--hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \
--hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \
--hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \
--hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \
--hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \
--hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \
--hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \
--hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \
--hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \
--hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \
--hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \
--hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \
--hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \
--hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \
--hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \
--hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \
--hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \
--hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \
--hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2
# via scikit-learn
threadpoolctl==3.5.0 \
--hash=sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107 \
--hash=sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467
# via scikit-learn
44 changes: 44 additions & 0 deletions src/core/operators/test_cluster_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import unittest
from src.core.operators import cluster_embeddings

# Test constants
EMBEDDINGS = [
{"payload": "A", "embedding": [0, 1]},
{"payload": "B", "embedding": [1, 0]},
{"payload": "C", "embedding": [100, 101]},
{"payload": "D", "embedding": [101, 100]}
]
EXPECTED_CLUSTERS = [["A", "B"], ["C", "D"]]

class Test(unittest.TestCase):
@classmethod
def setUpClass(cls):
# initialize operator
param = {}
cluster_embeddings.initialize(param)

@classmethod
def tearDownClass(cls):
# delete config files
pass

def test_kmeans_clustering(self):
result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="audio")
self.assertIn("cluster_0", result)
self.assertIn("cluster_1", result)
self.assertEqual(len(result), 2)
self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS)

def test_agglomerative_clustering(self):
result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="visual")
self.assertIn("cluster_0", result)
self.assertIn("cluster_1", result)
self.assertEqual(len(result), 2)
self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS)

def test_affinity_propagation(self):
result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=None, modality="audio")
self.assertIn("cluster_0", result)
self.assertIn("cluster_1", result)
self.assertEqual(len(result), 2)
self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS)

0 comments on commit 3f04b83

Please sign in to comment.