Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[82] - add dimensionality reduction operator #376

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions src/core/operators/dimension_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Operator to perform dimensionality reduction given the embedddings."""

from abc import ABC, abstractmethod
from sklearn.manifold import TSNE
import numpy as np


class DimensionReduction(ABC):
"""Abstract base class for dimension reduction techniques."""

@abstractmethod
def initialize(self, params):
pass

@abstractmethod
def run(self, embeddings):
pass


class TSNEReduction(DimensionReduction):
"""t-SNE implementation of the DimensionReduction abstract class."""

def initialize(self, params):
"""
Initialize the t-SNE model with parameters.

Args:
params (dict): A dictionary containing t-SNE parameters such as:
- n_components (int): Number of dimensions to reduce to. Default is 2.
- perplexity (float): Perplexity parameter for t-SNE. Default is 30.
- learning_rate (float): Learning rate for t-SNE. Default is 150.
- n_iter (int): Number of iterations for optimization. Default is 1000.
- random_state (int): Seed for random number generation. Default is 42.

Raises:
ValueError: If the t-SNE model fails to initialize.
"""
try:
self.model = TSNE(
n_components=params.get('n_components', 2),
perplexity=params.get('perplexity', 30),
learning_rate=params.get('learning_rate', 150),
max_iter=params.get('max_iter', 1000),
random_state=params.get('random_state', 42),
method=params.get('method', 'barnes_hut')
)
print("t-SNE model successfully initialized")
except Exception as e:
raise ValueError(f"Failed to initialize t-SNE model: {e}")

def run(self, embeddings_array):
"""
Apply the t-SNE model to reduce the dimensionality of embeddings.

Args:
embeddings (list or numpy.ndarray): A list or array of embeddings to be reduced.

Returns:
numpy.ndarray: The reduced embeddings as a 2D array.

Raises:
ValueError: If the embeddings input is not a 2D array.
RuntimeError: If the t-SNE reduction fails.
"""
try:
if embeddings_array.ndim != 2:
raise ValueError("Embeddings should be a 2D array.")
return self.model.fit_transform(embeddings_array)
except Exception as e:
raise RuntimeError(f"t-SNE reduction failed: {e}")


class DimensionReductionFactory:
"""Factory class for creating dimension reduction models."""

@staticmethod
def get_reduction_model(model_type):
"""
Factory method to create a dimension reduction model based on type.

Args:
model_type (str): String indicating the type of model (e.g., 'tsne').

Returns:
DimensionReduction: An instance of the corresponding dimension reduction model.

Raises:
ValueError: If the specified model type is unsupported.
"""
if model_type.lower() == 'tsne':
return TSNEReduction()
else:
raise ValueError(f"Unsupported model type: {model_type}")


def gen_data(payloads, reduced_embeddings):
"""
Generates the formatted output.

Args:
payloads (list): List of paylods.
reduced_embeddings (nd.array): An array of reduced embeddings.

Returns:
list: A list of dictionaries containing the payload and corresponding embedding.
"""
out = []

for payload, reduced_embedding in zip(payloads, reduced_embeddings):
tmp_dict = {}
tmp_dict['payload'] = payload
tmp_dict['reduced_embedding'] = reduced_embedding
out.append(tmp_dict)
return out


def setup_reduction(model_type, params):
"""
Initialize the dimension reduction model with provided type and parameters.

Args:
model_type (str): String indicating the type of model (e.g., 'tsne').
params (dict): Dictionary of parameters for the model initialization.

"""
global reduction_model
reduction_model = DimensionReductionFactory.get_reduction_model(model_type)
reduction_model.initialize(params)


def perform_reduction(input_data):
"""
Reduce the dimensionality of the provided embeddings using the initialized model.

Args:
input_data (list): A list of dictionaries containing payload and embeddings to be reduced.
Example:

[
{
"payload": "123",
"embedding": [1, 2, 3]
},
{
"payload": "124",
"embedding": [1, 0, 1]
}
]

Returns:
list: The reduced embeddings and the corresponding payload as a list of dictionaries.
Example:

[
{
"payload":"123",
"reduced_embedding": [1, 2]
},
{
"payload": "124",
"reduced_embedding": [1, 0]
}
]

Raises:
ValueError: If the embeddings input is not a non-empty list.
KeyError: If the input data is invalid.
"""
if not isinstance(input_data, list) or len(input_data) == 0:
raise ValueError("Input should be a non-empty list.")

try:
embeddings, payloads = zip(*[(data['embedding'], data['payload']) for data in input_data])
except KeyError as e:
raise KeyError(f"Invalid data. Each data point in input must have `embedding` and `payload` properties. Missing key: {e}.")

reduced_embeddings = reduction_model.run(np.array(embeddings))

return gen_data(payloads, reduced_embeddings)
2 changes: 2 additions & 0 deletions src/core/operators/dimension_reduction_requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
scikit-learn==1.5.1
numpy==2.1.0
129 changes: 129 additions & 0 deletions src/core/operators/dimension_reduction_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile --allow-unsafe --generate-hashes core/operators/dimension_reduction_requirements.in
#
joblib==1.4.2 \
--hash=sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6 \
--hash=sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e
# via scikit-learn
numpy==2.1.0 \
--hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \
--hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \
--hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \
--hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \
--hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \
--hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \
--hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \
--hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \
--hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \
--hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \
--hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \
--hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \
--hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \
--hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \
--hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \
--hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \
--hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \
--hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \
--hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \
--hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \
--hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \
--hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \
--hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \
--hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \
--hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \
--hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \
--hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \
--hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \
--hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \
--hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \
--hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \
--hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \
--hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \
--hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \
--hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \
--hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \
--hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \
--hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \
--hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \
--hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \
--hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \
--hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \
--hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \
--hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \
--hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \
--hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \
--hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \
--hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \
--hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \
--hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \
--hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \
--hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e
# via
# -r core/operators/dimension_reduction_requirements.in
# scikit-learn
# scipy
scikit-learn==1.5.1 \
--hash=sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1 \
--hash=sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414 \
--hash=sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1 \
--hash=sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2 \
--hash=sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21 \
--hash=sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d \
--hash=sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac \
--hash=sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915 \
--hash=sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395 \
--hash=sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d \
--hash=sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf \
--hash=sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745 \
--hash=sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855 \
--hash=sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956 \
--hash=sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4 \
--hash=sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b \
--hash=sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b \
--hash=sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1 \
--hash=sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe \
--hash=sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74 \
--hash=sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7
# via -r core/operators/dimension_reduction_requirements.in
scipy==1.14.1 \
--hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \
--hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \
--hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \
--hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \
--hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \
--hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \
--hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \
--hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \
--hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \
--hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \
--hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \
--hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \
--hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \
--hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \
--hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \
--hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \
--hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \
--hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \
--hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \
--hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \
--hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \
--hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \
--hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \
--hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \
--hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \
--hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \
--hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \
--hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \
--hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \
--hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \
--hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \
--hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \
--hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2
# via scikit-learn
threadpoolctl==3.5.0 \
--hash=sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107 \
--hash=sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467
# via scikit-learn
49 changes: 49 additions & 0 deletions src/core/operators/test_dimension_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest
import numpy as np
from core.operators.dimension_reduction import setup_reduction, perform_reduction


class TestDimensionReductionOperator(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Initialize operator
cls.initial_params = {
'n_components': 2,
'perplexity': 30,
'learning_rate': 200,
'max_iter': 250,
'random_state': 42,
'method': 'barnes_hut'
}
setup_reduction('tsne', cls.initial_params)

@classmethod
def tearDownClass(cls):
# Clean up if necessary
pass

def test_tsne_reduction(self):
# Create sample embeddings
sample_embeddings = np.random.rand(100, 50) # 100 samples, 50 dimensions

input_data = [{'payload': str(i), 'embedding': embedding} for i, embedding in enumerate(sample_embeddings)]

# Perform reduction
reduced_data = perform_reduction(input_data)
reduced_embeddings = np.array([d['reduced_embedding'] for d in reduced_data])

# Check output shape
self.assertEqual(reduced_embeddings.shape, (100, 2)) # Should reduce to 2D

def test_invalid_input(self):
# Test with empty list
with self.assertRaises(ValueError):
perform_reduction([])

# Test with non-list input
with self.assertRaises(ValueError):
perform_reduction("not a list")

# Test with missing keys in input data
with self.assertRaises(KeyError):
perform_reduction([{'payload': '123'}])
Loading