Skip to content

Commit

Permalink
feat: add dimensionality reduction operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaithanya512 committed Aug 27, 2024
1 parent faa9727 commit 2caf4a4
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 0 deletions.
179 changes: 179 additions & 0 deletions src/core/operators/dimension_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Operator to perform dimensionality reduction given the embedddings."""

from abc import ABC, abstractmethod
from sklearn.manifold import TSNE
import numpy as np


class DimensionReduction(ABC):
    """Interface shared by every dimension reduction backend.

    Subclasses must override both ``initialize`` and ``run``; a class that
    leaves either one abstract cannot be instantiated.
    """

    @abstractmethod
    def initialize(self, params):
        """Configure the backend from ``params``."""

    @abstractmethod
    def run(self, embeddings):
        """Reduce ``embeddings`` and return the result."""


class TSNEReduction(DimensionReduction):
    """t-SNE implementation of the DimensionReduction abstract class."""

    def initialize(self, params):
        """
        Initialize the t-SNE model with parameters.
        Args:
            params (dict): A dictionary containing t-SNE parameters such as:
                - n_components (int): Number of dimensions to reduce to. Default is 2.
                - perplexity (float): Perplexity parameter for t-SNE. Default is 30.
                - learning_rate (float): Learning rate for t-SNE. Default is 150.
                - max_iter (int): Number of iterations for optimization. Default is 1000.
                  The legacy key ``n_iter`` is used as a fallback when ``max_iter`` is absent.
                - random_state (int): Seed for random number generation. Default is 42.
                - method (str): Gradient calculation method. Default is 'barnes_hut'.
        Raises:
            ValueError: If the t-SNE model fails to initialize.
        """
        try:
            # `max_iter` wins when both keys are given; `n_iter` is only
            # consulted so that callers using the older key are not silently
            # given the default iteration count.
            max_iter = params.get('max_iter', params.get('n_iter', 1000))
            self.model = TSNE(
                n_components=params.get('n_components', 2),
                perplexity=params.get('perplexity', 30),
                learning_rate=params.get('learning_rate', 150),
                max_iter=max_iter,
                random_state=params.get('random_state', 42),
                method=params.get('method', 'barnes_hut'),
            )
            print("t-SNE model successfully initialized")
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"Failed to initialize t-SNE model: {e}") from e

    def run(self, embeddings_array):
        """
        Apply the t-SNE model to reduce the dimensionality of embeddings.
        Args:
            embeddings_array (list or numpy.ndarray): A list or array of embeddings to be reduced.
        Returns:
            numpy.ndarray: The reduced embeddings as a 2D array.
        Raises:
            ValueError: If the embeddings input is not a 2D array.
            RuntimeError: If the model has not been initialized or the t-SNE reduction fails.
        """
        if getattr(self, 'model', None) is None:
            raise RuntimeError(
                "t-SNE reduction failed: model is not initialized; call initialize() first."
            )

        # Lists are converted here so the `.ndim` check below works for both
        # documented input types; an ndarray is passed through as-is.
        try:
            data = np.asarray(embeddings_array)
        except (ValueError, TypeError) as e:
            raise ValueError("Embeddings should be a 2D array.") from e

        # Validate outside the try/except below so this ValueError reaches the
        # caller instead of being re-wrapped as RuntimeError.
        if data.ndim != 2:
            raise ValueError("Embeddings should be a 2D array.")

        try:
            return self.model.fit_transform(data)
        except Exception as e:
            raise RuntimeError(f"t-SNE reduction failed: {e}") from e


class DimensionReductionFactory:
    """Factory class for creating dimension reduction models."""

    @staticmethod
    def get_reduction_model(model_type):
        """
        Factory method to create a dimension reduction model based on type.
        Args:
            model_type (str): String indicating the type of model (e.g., 'tsne').
                Matching is case-insensitive.
        Returns:
            DimensionReduction: An instance of the corresponding dimension reduction model.
        Raises:
            ValueError: If the specified model type is unsupported, including
                when it is not a string.
        """
        # Non-string identifiers have no `.lower()`; report them as an
        # unsupported type instead of leaking an AttributeError.
        if not isinstance(model_type, str):
            raise ValueError(f"Unsupported model type: {model_type}")

        if model_type.lower() == 'tsne':
            return TSNEReduction()

        raise ValueError(f"Unsupported model type: {model_type}")


def gen_data(payloads, reduced_embeddings):
    """
    Pair each payload with its reduced embedding.
    Args:
        payloads (list): List of payloads.
        reduced_embeddings (nd.array): An array of reduced embeddings.
    Returns:
        list: One dictionary per pair, with keys `payload` and
            `reduced_embedding`. Pairing stops at the shorter input.
    """
    return [
        {'payload': item, 'reduced_embedding': vector}
        for item, vector in zip(payloads, reduced_embeddings)
    ]


def setup_reduction(model_type, params):
    """
    Initialize the dimension reduction model with provided type and parameters.

    The model is stored in the module-level global `reduction_model`. The
    global is only rebound after `initialize` has returned, so a failed setup
    leaves any previously configured model in place.
    Args:
        model_type (str): String indicating the type of model (e.g., 'tsne').
        params (dict): Dictionary of parameters for the model initialization.
    """
    global reduction_model
    # Build and initialize on a local first: if either step raises, the
    # global must not end up pointing at a half-initialized model.
    candidate = DimensionReductionFactory.get_reduction_model(model_type)
    candidate.initialize(params)
    reduction_model = candidate


def perform_reduction(input_data):
    """
    Reduce the dimensionality of the provided embeddings using the initialized model.
    Args:
        input_data (list): A list of dictionaries containing payload and embeddings to be reduced.
            Example:
            [
                {
                    "payload": "123",
                    "embedding": [1, 2, 3]
                },
                {
                    "payload": "124",
                    "embedding": [1, 0, 1]
                }
            ]
    Returns:
        list: The reduced embeddings and the corresponding payload as a list of dictionaries.
            Each `reduced_embedding` is one row of the model's output.
            Example:
            [
                {
                    "payload":"123",
                    "reduced_embedding": [1, 2]
                },
                {
                    "payload": "124",
                    "reduced_embedding": [1, 0]
                }
            ]
    Raises:
        ValueError: If the embeddings input is not a non-empty list.
        KeyError: If the input data is invalid.
        RuntimeError: If no model has been set up via `setup_reduction`.
    """
    if not isinstance(input_data, list) or not input_data:
        raise ValueError("Input should be a non-empty list.")

    try:
        embeddings, payloads = zip(*[(data['embedding'], data['payload']) for data in input_data])
    except KeyError as e:
        raise KeyError(f"Invalid data. Each data point in input must have `embedding` and `payload` properties. Missing key: {e}.") from e

    # The global only exists once setup_reduction() has run; look it up
    # defensively so callers get an actionable error rather than a NameError.
    model = globals().get('reduction_model')
    if model is None:
        raise RuntimeError(
            "Dimension reduction model is not initialized. Call setup_reduction() first."
        )

    reduced_embeddings = model.run(np.array(embeddings))

    return gen_data(payloads, reduced_embeddings)
2 changes: 2 additions & 0 deletions src/core/operators/dimension_reduction_requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
scikit-learn==1.5.1
numpy==2.1.0
129 changes: 129 additions & 0 deletions src/core/operators/dimension_reduction_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile --allow-unsafe --generate-hashes core/operators/dimension_reduction_requirements.in
#
joblib==1.4.2 \
--hash=sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6 \
--hash=sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e
# via scikit-learn
numpy==2.1.0 \
--hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \
--hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \
--hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \
--hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \
--hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \
--hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \
--hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \
--hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \
--hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \
--hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \
--hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \
--hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \
--hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \
--hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \
--hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \
--hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \
--hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \
--hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \
--hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \
--hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \
--hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \
--hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \
--hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \
--hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \
--hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \
--hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \
--hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \
--hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \
--hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \
--hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \
--hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \
--hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \
--hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \
--hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \
--hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \
--hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \
--hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \
--hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \
--hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \
--hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \
--hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \
--hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \
--hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \
--hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \
--hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \
--hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \
--hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \
--hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \
--hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \
--hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \
--hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \
--hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e
# via
# -r core/operators/dimension_reduction_requirements.in
# scikit-learn
# scipy
scikit-learn==1.5.1 \
--hash=sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1 \
--hash=sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414 \
--hash=sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1 \
--hash=sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2 \
--hash=sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21 \
--hash=sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d \
--hash=sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac \
--hash=sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915 \
--hash=sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395 \
--hash=sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d \
--hash=sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf \
--hash=sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745 \
--hash=sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855 \
--hash=sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956 \
--hash=sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4 \
--hash=sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b \
--hash=sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b \
--hash=sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1 \
--hash=sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe \
--hash=sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74 \
--hash=sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7
# via -r core/operators/dimension_reduction_requirements.in
scipy==1.14.1 \
--hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \
--hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \
--hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \
--hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \
--hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \
--hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \
--hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \
--hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \
--hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \
--hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \
--hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \
--hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \
--hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \
--hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \
--hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \
--hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \
--hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \
--hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \
--hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \
--hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \
--hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \
--hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \
--hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \
--hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \
--hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \
--hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \
--hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \
--hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \
--hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \
--hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \
--hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \
--hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \
--hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2
# via scikit-learn
threadpoolctl==3.5.0 \
--hash=sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107 \
--hash=sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467
# via scikit-learn
49 changes: 49 additions & 0 deletions src/core/operators/test_dimension_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest
import numpy as np
from core.operators.dimension_reduction import setup_reduction, perform_reduction


class TestDimensionReductionOperator(unittest.TestCase):
    """Tests for the module-level `setup_reduction` / `perform_reduction` API."""

    @classmethod
    def setUpClass(cls):
        # Configure the shared t-SNE model once for every test in the class.
        cls.initial_params = dict(
            n_components=2,
            perplexity=30,
            learning_rate=200,
            max_iter=250,
            random_state=42,
            method='barnes_hut',
        )
        setup_reduction('tsne', cls.initial_params)

    @classmethod
    def tearDownClass(cls):
        """No class-level resources to release."""

    def test_tsne_reduction(self):
        # 100 random samples with 50 dimensions each.
        vectors = np.random.rand(100, 50)
        records = [
            {'payload': str(idx), 'embedding': vector}
            for idx, vector in enumerate(vectors)
        ]

        result = perform_reduction(records)
        stacked = np.array([entry['reduced_embedding'] for entry in result])

        # One 2-D point is expected per input sample.
        self.assertEqual(stacked.shape, (100, 2))

    def test_invalid_input(self):
        # An empty list and a non-list value are both rejected with ValueError.
        for bad_input in ([], "not a list"):
            with self.assertRaises(ValueError):
                perform_reduction(bad_input)

        # A record without the `embedding` key is rejected with KeyError.
        with self.assertRaises(KeyError):
            perform_reduction([{'payload': '123'}])

0 comments on commit 2caf4a4

Please sign in to comment.