diff --git a/src/core/operators/dimension_reduction.py b/src/core/operators/dimension_reduction.py
new file mode 100644
index 0000000..45ca3f0
--- /dev/null
+++ b/src/core/operators/dimension_reduction.py
@@ -0,0 +1,195 @@
+"""Operator to perform dimensionality reduction given the embeddings."""
+
+from abc import ABC, abstractmethod
+from sklearn.manifold import TSNE
+import numpy as np
+
+# Model shared by setup_reduction() and perform_reduction(); None until set up.
+reduction_model = None
+
+
+class DimensionReduction(ABC):
+    """Abstract base class for dimension reduction techniques."""
+
+    @abstractmethod
+    def initialize(self, params):
+        """Build the underlying model from a dictionary of parameters."""
+
+    @abstractmethod
+    def run(self, embeddings):
+        """Reduce the given embeddings and return the reduced array."""
+
+
+class TSNEReduction(DimensionReduction):
+    """t-SNE implementation of the DimensionReduction abstract class."""
+
+    def initialize(self, params):
+        """
+        Initialize the t-SNE model with parameters.
+
+        Args:
+            params (dict): A dictionary containing t-SNE parameters such as:
+                - n_components (int): Number of dimensions to reduce to. Default is 2.
+                - perplexity (float): Perplexity parameter for t-SNE. Default is 30.
+                - learning_rate (float): Learning rate for t-SNE. Default is 150.
+                - max_iter (int): Number of iterations for optimization. Default is 1000.
+                  The legacy key `n_iter` is used when `max_iter` is absent.
+                - random_state (int): Seed for random number generation. Default is 42.
+                - method (str): Gradient calculation algorithm. Default is 'barnes_hut'.
+
+        Raises:
+            ValueError: If the t-SNE model fails to initialize.
+        """
+        try:
+            self.model = TSNE(
+                n_components=params.get('n_components', 2),
+                perplexity=params.get('perplexity', 30),
+                learning_rate=params.get('learning_rate', 150),
+                # Earlier docs advertised `n_iter`; honor it as a fallback key.
+                max_iter=params.get('max_iter', params.get('n_iter', 1000)),
+                random_state=params.get('random_state', 42),
+                method=params.get('method', 'barnes_hut')
+            )
+            print("t-SNE model successfully initialized")
+        except Exception as e:
+            raise ValueError(f"Failed to initialize t-SNE model: {e}") from e
+
+    def run(self, embeddings_array):
+        """
+        Apply the t-SNE model to reduce the dimensionality of embeddings.
+
+        Args:
+            embeddings_array (list or numpy.ndarray): A list or array of embeddings to be reduced.
+
+        Returns:
+            numpy.ndarray: The reduced embeddings as a 2D array.
+
+        Raises:
+            ValueError: If the embeddings input is not a 2D array.
+            RuntimeError: If the t-SNE reduction fails.
+        """
+        # Convert first so plain lists are accepted, and validate outside the
+        # try block so the ValueError is not re-wrapped as a RuntimeError.
+        embeddings_array = np.asarray(embeddings_array)
+        if embeddings_array.ndim != 2:
+            raise ValueError("Embeddings should be a 2D array.")
+        try:
+            return self.model.fit_transform(embeddings_array)
+        except Exception as e:
+            raise RuntimeError(f"t-SNE reduction failed: {e}") from e
+
+
+class DimensionReductionFactory:
+    """Factory class for creating dimension reduction models."""
+
+    @staticmethod
+    def get_reduction_model(model_type):
+        """
+        Factory method to create a dimension reduction model based on type.
+
+        Args:
+            model_type (str): String indicating the type of model (e.g., 'tsne').
+
+        Returns:
+            DimensionReduction: An instance of the corresponding dimension reduction model.
+
+        Raises:
+            ValueError: If the specified model type is unsupported.
+        """
+        # Non-string values fall through to the documented ValueError.
+        if isinstance(model_type, str) and model_type.lower() == 'tsne':
+            return TSNEReduction()
+        raise ValueError(f"Unsupported model type: {model_type}")
+
+
+def gen_data(payloads, reduced_embeddings):
+    """
+    Generates the formatted output.
+
+    Args:
+        payloads (list): List of payloads.
+        reduced_embeddings (numpy.ndarray): An array of reduced embeddings.
+
+    Returns:
+        list: A list of dictionaries containing the payload and corresponding embedding.
+    """
+    return [
+        {'payload': payload, 'reduced_embedding': reduced_embedding}
+        for payload, reduced_embedding in zip(payloads, reduced_embeddings)
+    ]
+
+
+def setup_reduction(model_type, params):
+    """
+    Initialize the dimension reduction model with provided type and parameters.
+
+    Args:
+        model_type (str): String indicating the type of model (e.g., 'tsne').
+        params (dict): Dictionary of parameters for the model initialization.
+
+    Raises:
+        ValueError: If the model type is unsupported or the model fails to initialize.
+    """
+    global reduction_model
+    # Publish the model only after it initialized successfully.
+    model = DimensionReductionFactory.get_reduction_model(model_type)
+    model.initialize(params)
+    reduction_model = model
+
+
+def perform_reduction(input_data):
+    """
+    Reduce the dimensionality of the provided embeddings using the initialized model.
+
+    Args:
+        input_data (list): A list of dictionaries containing payload and embeddings to be reduced.
+            Example:
+
+            [
+                {
+                    "payload": "123",
+                    "embedding": [1, 2, 3]
+                },
+                {
+                    "payload": "124",
+                    "embedding": [1, 0, 1]
+                }
+            ]
+
+    Returns:
+        list: The reduced embeddings and the corresponding payload as a list of dictionaries.
+            Example:
+
+            [
+                {
+                    "payload": "123",
+                    "reduced_embedding": [1, 2]
+                },
+                {
+                    "payload": "124",
+                    "reduced_embedding": [1, 0]
+                }
+            ]
+
+    Raises:
+        ValueError: If the embeddings input is not a non-empty list.
+        KeyError: If the input data is invalid.
+        RuntimeError: If setup_reduction() has not been called, or the reduction fails.
+    """
+    if not isinstance(input_data, list) or len(input_data) == 0:
+        raise ValueError("Input should be a non-empty list.")
+
+    try:
+        embeddings, payloads = zip(*[(data['embedding'], data['payload']) for data in input_data])
+    except KeyError as e:
+        raise KeyError(
+            "Invalid data. Each data point in input must have `embedding` and `payload` properties. "
+            f"Missing key: {e}."
+        ) from e
+
+    if reduction_model is None:
+        raise RuntimeError("Reduction model is not initialized. Call setup_reduction() first.")
+
+    reduced_embeddings = reduction_model.run(np.array(embeddings))
+
+    return gen_data(payloads, reduced_embeddings)
diff --git a/src/core/operators/dimension_reduction_requirements.in b/src/core/operators/dimension_reduction_requirements.in
new file mode 100644
index 0000000..8dbaebb
--- /dev/null
+++ b/src/core/operators/dimension_reduction_requirements.in
@@ -0,0 +1,2 @@
+scikit-learn==1.5.1
+numpy==2.1.0
\ No newline at end of file
diff --git a/src/core/operators/dimension_reduction_requirements.txt b/src/core/operators/dimension_reduction_requirements.txt
new file mode 100644
index 0000000..a819217
--- /dev/null
+++ b/src/core/operators/dimension_reduction_requirements.txt
@@ -0,0 +1,129 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --allow-unsafe --generate-hashes core/operators/dimension_reduction_requirements.in
+#
+joblib==1.4.2 \
+    --hash=sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6 \
+    --hash=sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e
+    # via scikit-learn
+numpy==2.1.0 \
+    --hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \
+    --hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \
+    --hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \
+    --hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \
+    --hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \
+    --hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \
+    --hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \
+    --hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \
+    --hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \
+    --hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \
+ --hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \ + --hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \ + --hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \ + --hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \ + --hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \ + --hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \ + --hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \ + --hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \ + --hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \ + --hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \ + --hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \ + --hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \ + --hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \ + --hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \ + --hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \ + --hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \ + --hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \ + --hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \ + --hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \ + --hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \ + --hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \ + --hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \ + --hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \ + --hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \ + 
--hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \ + --hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \ + --hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \ + --hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \ + --hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \ + --hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \ + --hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \ + --hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \ + --hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \ + --hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \ + --hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \ + --hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \ + --hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \ + --hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \ + --hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \ + --hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \ + --hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \ + --hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e + # via + # -r core/operators/dimension_reduction_requirements.in + # scikit-learn + # scipy +scikit-learn==1.5.1 \ + --hash=sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1 \ + --hash=sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414 \ + --hash=sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1 \ + --hash=sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2 \ + 
--hash=sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21 \ + --hash=sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d \ + --hash=sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac \ + --hash=sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915 \ + --hash=sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395 \ + --hash=sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d \ + --hash=sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf \ + --hash=sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745 \ + --hash=sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855 \ + --hash=sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956 \ + --hash=sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4 \ + --hash=sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b \ + --hash=sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b \ + --hash=sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1 \ + --hash=sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe \ + --hash=sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74 \ + --hash=sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7 + # via -r core/operators/dimension_reduction_requirements.in +scipy==1.14.1 \ + --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ + --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ + --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ + --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ + --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ + --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ + 
--hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ + --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ + --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ + --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ + --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ + --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ + --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ + --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ + --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ + --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ + --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ + --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ + --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ + --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ + --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ + --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ + --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ + --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ + --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ + --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ + --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ + --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ + --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ + --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ + 
--hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \
+    --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \
+    --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2
+    # via scikit-learn
+threadpoolctl==3.5.0 \
+    --hash=sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107 \
+    --hash=sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467
+    # via scikit-learn
diff --git a/src/core/operators/test_dimension_reduction.py b/src/core/operators/test_dimension_reduction.py
new file mode 100644
index 0000000..b348062
--- /dev/null
+++ b/src/core/operators/test_dimension_reduction.py
@@ -0,0 +1,51 @@
+import unittest
+import numpy as np
+from core.operators.dimension_reduction import setup_reduction, perform_reduction
+
+
+class TestDimensionReductionOperator(unittest.TestCase):
+    """Tests for the setup_reduction / perform_reduction operator functions."""
+
+    @classmethod
+    def setUpClass(cls):
+        # Configure the t-SNE operator once for all tests in this class
+        cls.initial_params = dict(
+            n_components=2,
+            perplexity=30,
+            learning_rate=200,
+            max_iter=250,
+            random_state=42,
+            method='barnes_hut',
+        )
+        setup_reduction('tsne', cls.initial_params)
+
+    @classmethod
+    def tearDownClass(cls):
+        # No teardown work is needed at present
+        pass
+
+    def test_tsne_reduction(self):
+        # Build 100 random sample embeddings with 50 dimensions each
+        n_samples, n_features = 100, 50
+        sample_embeddings = np.random.rand(n_samples, n_features)
+
+        input_data = []
+        for index, vector in enumerate(sample_embeddings):
+            input_data.append({'payload': str(index), 'embedding': vector})
+
+        # Run the reduction and collect the projected vectors
+        reduced_data = perform_reduction(input_data)
+        reduced_embeddings = np.array([item['reduced_embedding'] for item in reduced_data])
+
+        # Every sample should now have exactly two components
+        self.assertEqual(reduced_embeddings.shape, (n_samples, 2))
+
+    def test_invalid_input(self):
+        # An empty list is rejected
+        self.assertRaises(ValueError, perform_reduction, [])
+
+        # A non-list input is rejected
+        self.assertRaises(ValueError, perform_reduction, "not a list")
+
+        # Entries missing required keys are rejected
+        self.assertRaises(KeyError, perform_reduction, [{'payload': '123'}])
\ No newline at end of file