diff --git a/src/core/operators/cluster_embeddings.py b/src/core/operators/cluster_embeddings.py new file mode 100644 index 0000000..953193f --- /dev/null +++ b/src/core/operators/cluster_embeddings.py @@ -0,0 +1,98 @@ +""" +Operator to cluster embeddings using KMeans, Affinity Propagation, and Agglomerative clustering algorithms +""" + +def initialize(param): + """ + Initializes the operator. + + Args: + param (dict): Parameters for initialization + """ + global KMeans_clustering, Agglomerative_clustering, AffinityPropagation_clustering + global gen_data + + # Imports + from sklearn.cluster import KMeans, AffinityPropagation, AgglomerativeClustering + import numpy as np + + # Constants + RANDOM_STATE = 50 + + def gen_data(labels, input_data): + """ + Generates formatted output data. + + Args: + labels (np.ndarray): An array of cluster labels + input_data (list[dict]): Operator input + + Returns: + dict: A dictionary mapping cluster labels to corresponding array of payloads + """ + out = {} + for label, item in zip(labels, input_data): + key = f'cluster_{label}' + if key not in out: + out[key] = [] + out[key].append(item['payload']) + return out + + def KMeans_clustering(matrix, n_clusters): + """ + Clusters given embeddings using KMeans clustering algorithm. + + Args: + matrix (list[list]): list of embeddings + n_clusters (int): number of clusters + + Returns: + numpy.ndarray: An array of cluster labels for each embedding + """ + return KMeans(n_clusters=n_clusters, random_state=RANDOM_STATE).fit_predict(np.array(matrix)) + + def Agglomerative_clustering(matrix, n_clusters): + """ + Clusters given embeddings using Agglomerative clustering algorithm. + + Args: + matrix (list[list]): list of embeddings + n_clusters (int): number of clusters + + Returns: + numpy.ndarray: An array of cluster labels for each embedding + """ + return AgglomerativeClustering(n_clusters=n_clusters).fit_predict(np.array(matrix)) + + def AffinityPropagation_clustering(matrix): + """ + Clusters given embeddings using Affinity Propagation algorithm (used if the number of clusters is unknown). + + Args: + matrix (list[list]): list of embeddings + + Returns: + numpy.ndarray: An array of cluster labels for each embedding + """ + return AffinityPropagation(random_state=RANDOM_STATE).fit_predict(np.array(matrix)) + +def run(embeddings, n_clusters=None, modality='audio'): + """ + Runs the operator. + + Args: + embeddings (list[dict]): List of data with each dictionary containing `embedding` and `payload` properties + n_clusters (int, optional): Number of clusters. Defaults to None. + modality (str, optional): Source modality of embeddings. Defaults to 'audio'. + """ + matrix = [data['embedding'] for data in embeddings] # isolating list of embeddings + if n_clusters: + if modality == 'audio': + labels = KMeans_clustering(matrix=matrix, n_clusters=n_clusters) + elif modality == 'visual': + labels = Agglomerative_clustering(matrix=matrix, n_clusters=n_clusters) + else: + raise ValueError("Invalid modality. Modality can only be `audio` or `visual`.") + else: + labels = AffinityPropagation_clustering(matrix=matrix) + return gen_data(labels=labels, input_data=embeddings) diff --git a/src/core/operators/cluster_embeddings_requirements.in b/src/core/operators/cluster_embeddings_requirements.in new file mode 100644 index 0000000..8dbaebb --- /dev/null +++ b/src/core/operators/cluster_embeddings_requirements.in @@ -0,0 +1,2 @@ +scikit-learn==1.5.1 +numpy==2.1.0 \ No newline at end of file diff --git a/src/core/operators/cluster_embeddings_requirements.txt b/src/core/operators/cluster_embeddings_requirements.txt new file mode 100644 index 0000000..c25b78f --- /dev/null +++ b/src/core/operators/cluster_embeddings_requirements.txt @@ -0,0 +1,129 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --allow-unsafe --generate-hashes cluster_embeddings_requirements.in +# +joblib==1.4.2 \ + --hash=sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6 \ + --hash=sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e + # via scikit-learn +numpy==2.1.0 \ + --hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \ + --hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \ + --hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \ + --hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \ + --hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \ + --hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \ + --hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \ + --hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \ + --hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \ + --hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \ + --hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \ + --hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \ + --hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \ + --hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \ + --hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \ + --hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \ + --hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \ + --hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \ + --hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \ + --hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \ + --hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \ + --hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \ + --hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \ + --hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \ + --hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \ + --hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \ + --hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \ + --hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \ + --hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \ + --hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \ + --hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \ + --hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \ + --hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \ + --hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \ + --hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \ + --hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \ + --hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \ + --hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \ + --hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \ + --hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \ + --hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \ + --hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \ + --hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \ + --hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \ + --hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \ + --hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \ + --hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \ + --hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \ + --hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \ + --hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \ + --hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \ + --hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e + # via + # -r cluster_embeddings_requirements.in + # scikit-learn + # scipy +scikit-learn==1.5.1 \ + --hash=sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1 \ + --hash=sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414 \ + --hash=sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1 \ + --hash=sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2 \ + --hash=sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21 \ + --hash=sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d \ + --hash=sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac \ + --hash=sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915 \ + --hash=sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395 \ + --hash=sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d \ + --hash=sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf \ + --hash=sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745 \ + --hash=sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855 \ + --hash=sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956 \ + --hash=sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4 \ + --hash=sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b \ + --hash=sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b \ + --hash=sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1 \ + --hash=sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe \ + --hash=sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74 \ + --hash=sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7 + # via -r cluster_embeddings_requirements.in +scipy==1.14.1 \ + --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ + --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ + --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ + --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ + --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ + --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ + --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ + --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ + --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ + --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ + --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ + --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ + --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ + --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ + --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ + --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ + --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ + --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ + --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ + --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ + --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ + --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ + --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ + --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ + --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ + --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ + --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ + --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ + --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ + --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ + --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ + --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ + --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 + # via scikit-learn +threadpoolctl==3.5.0 \ + --hash=sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107 \ + --hash=sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467 + # via scikit-learn diff --git a/src/core/operators/test_cluster_embeddings.py b/src/core/operators/test_cluster_embeddings.py new file mode 100644 index 0000000..6be9c6a --- /dev/null +++ b/src/core/operators/test_cluster_embeddings.py @@ -0,0 +1,44 @@ +import unittest +from src.core.operators import cluster_embeddings + +# Test constants +EMBEDDINGS = [ + {"payload": "A", "embedding": [0, 1]}, + {"payload": "B", "embedding": [1, 0]}, + {"payload": "C", "embedding": [100, 101]}, + {"payload": "D", "embedding": [101, 100]} +] +EXPECTED_CLUSTERS = [["A", "B"], ["C", "D"]] + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + # initialize operator + param = {} + cluster_embeddings.initialize(param) + + @classmethod + def tearDownClass(cls): + # delete config files + pass + + def test_kmeans_clustering(self): + result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="audio") + self.assertIn("cluster_0", result) + self.assertIn("cluster_1", result) + self.assertEqual(len(result), 2) + self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS) + + def test_agglomerative_clustering(self): + result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=2, modality="visual") + self.assertIn("cluster_0", result) + self.assertIn("cluster_1", result) + self.assertEqual(len(result), 2) + self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS) + + def test_affinity_propagation(self): + result = cluster_embeddings.run(embeddings=EMBEDDINGS, n_clusters=None, modality="audio") + self.assertIn("cluster_0", result) + self.assertIn("cluster_1", result) + self.assertEqual(len(result), 2) + self.assertCountEqual([result["cluster_0"], result["cluster_1"]], EXPECTED_CLUSTERS)