diff --git a/outrank/algorithms/importance_estimator.py b/outrank/algorithms/importance_estimator.py
index 56953c5..e19155c 100644
--- a/outrank/algorithms/importance_estimator.py
+++ b/outrank/algorithms/importance_estimator.py
@@ -18,6 +18,7 @@ from sklearn.preprocessing import OneHotEncoder
 from sklearn.svm import SVC
 
+from outrank.algorithms.neural.mlp_nn import NNClassifier
 from outrank.core_utils import is_prior_heuristic
 
 logger = logging.getLogger('syn-logger')
@@ -224,6 +225,8 @@ def initialize_classifier(surrogate_model: str):
         return LogisticRegression(max_iter=100000)
     elif 'surrogate-SVM' in surrogate_model:
         return SVC(gamma='auto', probability=True)
+    elif 'surrogate-NN' in surrogate_model:
+        return NNClassifier()
     elif 'surrogate-SGD' in surrogate_model:
         return SGDClassifier(max_iter=100000, loss='log_loss')
     else:
diff --git a/outrank/algorithms/neural/mlp_nn.py b/outrank/algorithms/neural/mlp_nn.py
new file mode 100644
index 0000000..b11d148
--- /dev/null
+++ b/outrank/algorithms/neural/mlp_nn.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+import logging
+from collections.abc import Sequence
+
+import jax
+import optax
+from flax import linen as nn
+from jax import numpy as jnp
+from jax import random
+
+logger = logging.getLogger('syn-logger')
+logger.setLevel(logging.DEBUG)
+
+key = random.PRNGKey(7235123)
+
+
+class GenericJaxNN(nn.Module):
+    """A stack of Dense + ReLU layers followed by a two-logit output layer."""
+
+    num_features: int  # Not used by the forward pass; kept for bookkeeping.
+    architecture: Sequence[int]
+
+    @nn.compact
+    def __call__(self, x):
+        for num_units in self.architecture:
+            x = nn.Dense(features=num_units)(x)
+            x = nn.relu(x)
+        # Binary classifier -> two output logits.
+        x = nn.Dense(features=2)(x)
+        return x
+
+
+class NNClassifier:
+
+    def __init__(self, learning_rate=0.001, architecture=[48, 48], epochs=100):
+        self.learning_rate = learning_rate
+        self.architecture = architecture
+        self.num_epochs = epochs
+        self.ncl = None
+
+        # Placeholder width; the real input dimension is only known once
+        # `fit` has one-hot encoded the data.
+        features = 5
+
+        self.mlp_model = GenericJaxNN(
+            num_features=features,
+            architecture=tuple(architecture),
+        )
+
+    def fit(self, X, Y, print_loss=True):
+        loss_grad_fn = jax.value_and_grad(self.forward_loss)
+
+        # One-hot encode the integer-coded features; the number of distinct
+        # values is stored so `predict` can apply the same encoding.
+        self.ncl = len(jnp.unique(X))
+        X = jax.nn.one_hot(X, num_classes=self.ncl).reshape(X.shape[0], -1)
+        sample_batch = jnp.ones((1, X.shape[1]))
+        self.init_internal_mlp(sample_batch)
+
+        # Full-batch training with Adam.
+        for i in range(self.num_epochs):
+            loss_val, grads = loss_grad_fn(self.parameters, X, Y)
+            updates, self.opt_state = self.tx.update(grads, self.opt_state)
+            self.parameters = optax.apply_updates(self.parameters, updates)
+            if print_loss:
+                print(f'Loss step {i + 1}: ', loss_val.item())
+
+    def forward_loss(self, parameters, X, Y):
+
+        def get_logits(x):
+            return self.mlp_model.apply(parameters, x)
+
+        batch_logits = jax.vmap(get_logits)(X)
+        one_hot_labels = jax.nn.one_hot(Y, num_classes=2)
+        loss = optax.softmax_cross_entropy(
+            logits=batch_logits,
+            labels=one_hot_labels,
+        ).mean()
+
+        return loss
+
+    def forward_pass(self, X):
+
+        def get_logits(x):
+            pred_logits = self.mlp_model.apply(self.parameters, x)
+            return jax.nn.softmax(pred_logits)
+
+        return jax.vmap(get_logits)(X)
+
+    def init_internal_mlp(self, sample_batch):
+        self.parameters = self.mlp_model.init(key, sample_batch)
+        self.tx = optax.adam(learning_rate=self.learning_rate)
+        self.opt_state = self.tx.init(self.parameters)
+
+    def selftest(self):
+        random_data = random.randint(
+            minval=0,
+            maxval=200,
+            key=key,
+            shape=(10, 1),
+        )
+        ncl = len(jnp.unique(random_data))
+        random_data = jax.nn.one_hot(random_data, num_classes=ncl).reshape(
+            random_data.shape[0], -1,
+        )
+        sample_batch = jnp.ones((1, random_data.shape[1]))
+        self.init_internal_mlp(sample_batch)
+
+        # Every one-hot row sums to exactly 1, so all examples land in
+        # class 1 and the untrained loss should sit near ln(2) ~= 0.69.
+        is_class1 = jnp.sum(random_data, axis=1) < 500
+        random_labels = (is_class1 + 0).astype(jnp.int32)
+        real_output = self.forward_pass(random_data)
+        # Softmax rows must sum to one (up to float tolerance).
+        assert jnp.allclose(jnp.sum(real_output, axis=1), 1.0)
+
+        forward_loss = self.forward_loss(
+            self.parameters, random_data,
+            random_labels,
+        )
+        print('Self-test loss:', forward_loss.item())
+        assert forward_loss.item() > 0.6
+        self.fit(random_data, random_labels)
+
+        preds = self.predict(random_data)
+        assert len(preds) == 10
+        # All labels are 1, so the fitted class-0 probability must be small.
+        assert preds[0, 0] < 0.1
+
+    def predict(self, X):
+        if self.ncl is not None:
+            X_ohe = jax.nn.one_hot(X, num_classes=self.ncl).reshape(
+                X.shape[0], -1,
+            )
+            return self.forward_pass(X_ohe)
+        else:
+            logger.error('number of classes unknown (NNClassifier)!')
+
+
+if __name__ == '__main__':
+    clf = NNClassifier(learning_rate=0.01, architecture=[48, 48], epochs=10)
+    clf.selftest()
diff --git a/requirements.txt b/requirements.txt
index 852a1ba..1470797 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,11 @@
 flake8>=6.1.0
+flax>=0.8.3
+jax>=0.4.28
+jaxlib>=0.4.28
 matplotlib>=3.7.2
 numba>=0.55.1
 numpy>=1.21.6
+optax==0.2.2
 pandas>=1.3.1
 pathos>=0.2.9
 pre-commit>=3.4.0
diff --git a/tests/nn_module_test.py b/tests/nn_module_test.py
new file mode 100644
index 0000000..5937eaf
--- /dev/null
+++ b/tests/nn_module_test.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+import sys
+import unittest
+
+import jax
+import optax
+from flax import linen as nn
+from jax import numpy as jnp
+from jax import random
+
+sys.path.append('./outrank')
+from outrank.algorithms.neural.mlp_nn import NNClassifier
+
+
+class NNClassifierTest(unittest.TestCase):
+
+    def setUp(self):
+        # Common setup, run before each test method.
+        self.learning_rate = 0.001
+        self.architecture = [48, 48]
+        self.epochs = 100
+        self.clf = NNClassifier(
+            self.learning_rate, self.architecture,
+            self.epochs,
+        )
+        self.key = random.PRNGKey(7235123)
+
+    def test_imports(self):
+        # Ensure the JAX stack is importable.
+        try:
+            import jax
+            import optax
+            import flax
+        except ImportError as e:
+            self.fail(f'Import failed: {e}')
+
+    def test_initialization(self):
+        # Check that NNClassifier stores its hyperparameters.
+        self.assertIsInstance(self.clf, NNClassifier)
+        self.assertEqual(self.clf.learning_rate, self.learning_rate)
+        self.assertEqual(self.clf.architecture, self.architecture)
+        self.assertEqual(self.clf.num_epochs, self.epochs)
+
+    def test_self(self):
+        # `selftest` should run without assertion errors.
+        try:
+            self.clf.selftest()
+        except AssertionError as e:
+            self.fail(f'selftest() failed with an AssertionError: {e}')
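
Usage sketch (illustrative only, not part of the patch): any surrogate-model string containing 'surrogate-NN' now resolves to the JAX-backed NNClassifier via initialize_classifier, which is then driven through the same sklearn-style fit/predict calls as the other surrogates. The integer-coded data below is synthetic, made up purely to show the round trip.

    # Illustrative usage only -- synthetic data; names match the diff above.
    import numpy as np

    from outrank.algorithms.importance_estimator import initialize_classifier

    # Any surrogate string containing 'surrogate-NN' selects the JAX MLP.
    clf = initialize_classifier('surrogate-NN')

    # Integer-coded categorical features; fit() one-hot encodes them
    # internally and stores the code-book size (ncl) for predict().
    X = np.random.randint(0, 5, size=(32, 3))
    y = np.random.randint(0, 2, size=(32,))

    clf.fit(X, y, print_loss=False)
    probs = clf.predict(X)  # softmax probabilities, shape (32, 2)

Note that fit() learns the one-hot code-book size from the training matrix, so predict() is only valid after fit(); otherwise NNClassifier logs an error and returns None.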