Commit: tests

SkBlaz committed May 13, 2024
1 parent d6dc5d3 commit 5898aea
Showing 3 changed files with 190 additions and 0 deletions.
136 changes: 136 additions & 0 deletions outrank/algorithms/neural/mlp_nn.py
@@ -0,0 +1,136 @@
from __future__ import annotations

import logging

import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from jax import random

logger = logging.getLogger('syn-logger')
logger.setLevel(logging.DEBUG)

key = random.PRNGKey(7235123)


class GenericJaxNN(nn.Module):
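    """Feed-forward MLP: a Dense+ReLU block per entry of `architecture`,
    followed by a two-unit logit head."""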
num_features: int
    architecture: jnp.ndarray

@nn.compact
def __call__(self, x):
for num_units in self.architecture:
x = nn.Dense(features=num_units)(x)
x = nn.relu(x)
x = nn.Dense(features=2)(x)
return x
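
# A minimal usage sketch (illustrative only; `key` is the module-level PRNG key):
#   model = GenericJaxNN(num_features=5, architecture=jnp.array([48, 48]))
#   params = model.init(key, jnp.ones((1, 5)))      # shapes inferred from the sample
#   logits = model.apply(params, jnp.ones((1, 5)))  # -> shape (1, 2)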


class NNClassifier:
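    """Minimal JAX/Flax MLP classifier: one-hot encodes integer features and
    trains with full-batch Adam."""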

    def __init__(self, learning_rate=0.001, architecture=(48, 48), epochs=100):
self.learning_rate = learning_rate
self.architecture = architecture
self.num_epochs = epochs
self.ncl = None

        # `num_features` is a placeholder; the actual input width is inferred
        # when the model is initialised with a sample batch in `fit`.
        self.mlp_model = GenericJaxNN(
            num_features=5,
            architecture=jnp.array(architecture),
        )

def fit(self, X, Y, print_loss=True):
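        """One-hot encode the integer inputs, (re)initialise the network, and
        run `num_epochs` full-batch Adam updates."""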

loss_grad_fn = jax.value_and_grad(self.forward_loss)
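        # The one-hot depth is the number of distinct raw feature values;
        # inputs are flattened to a (batch, distinct_values * n_columns) matrix.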
self.ncl = len(jnp.unique(X))
X = jax.nn.one_hot(X, num_classes=self.ncl).reshape(X.shape[0], -1)
sample_batch = jnp.ones((1, X.shape[1]))
self.init_internal_mlp(sample_batch)

for i in range(self.num_epochs):
loss_val, grads = loss_grad_fn(self.parameters, X, Y)
updates, self.opt_state = self.tx.update(grads, self.opt_state)
self.parameters = optax.apply_updates(self.parameters, updates)
if print_loss:
                print(f'Loss step {i + 1}: {loss_val.item()}')

def forward_loss(self, parameters, X, Y):
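        """Mean softmax cross-entropy over the batch, vmapped per row."""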

def get_logits(x):
pred_logits = self.mlp_model.apply(parameters, x)
return pred_logits

batch_logits = jax.vmap(get_logits)(X)
one_hot_labels = jax.nn.one_hot(Y, num_classes=2)
loss = optax.softmax_cross_entropy(
logits=batch_logits,
labels=one_hot_labels,
).mean()

return loss

def forward_pass(self, X):
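        """Return per-row class probabilities (softmax over the logits)."""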

        def get_probs(x):
            pred_logits = self.mlp_model.apply(self.parameters, x)
            return jax.nn.softmax(pred_logits)

        return jax.vmap(get_probs)(X)

def init_internal_mlp(self, sample_batch):
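        """Initialise Flax parameters and Adam optimiser state from a sample batch."""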

self.parameters = self.mlp_model.init(key, sample_batch)
self.tx = optax.adam(learning_rate=self.learning_rate)
self.opt_state = self.tx.init(self.parameters)

def selftest(self):
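        """Smoke test on random one-hot data: probabilities sum to 1, the
        untrained loss is near chance, and fitting drives the class-0
        probability down."""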
random_data = random.randint(
minval=0,
maxval=200,
key=key,
shape=(10, 1),
)
ncl = len(jnp.unique(random_data))
random_data = jax.nn.one_hot(random_data, num_classes=ncl).reshape(
random_data.shape[0], -1,
)
sample_batch = jnp.ones((1, random_data.shape[1]))
self.init_internal_mlp(sample_batch)

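        # Rows of a one-hot matrix each sum to 1, so this condition is always
        # true: every self-test label is class 1, and fitting should drive the
        # class-0 probability towards 0 (asserted below on preds[0, 0]).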
is_class1 = jnp.sum(random_data, axis=1) < 500
random_labels = (is_class1 + 0).astype(jnp.int32)
real_output = self.forward_pass(random_data)
        assert jnp.allclose(jnp.sum(real_output, axis=1), 1.0)

forward_loss = self.forward_loss(
self.parameters, random_data,
random_labels,
)
print('Self-test loss:', forward_loss.item())
assert forward_loss.item() > 0.6
self.fit(random_data, random_labels)

preds = self.predict(random_data)
assert len(preds) == 10
assert preds[0, 0] < 0.1

def predict(self, X):
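        """One-hot encode raw inputs with the depth learned in `fit`, then
        return class probabilities; logs an error if called before `fit`."""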

if self.ncl is not None:
X_ohe = jax.nn.one_hot(X, num_classes=self.ncl).reshape(
X.shape[0], -1,
)
return self.forward_pass(X_ohe)
else:
logger.error('number of classes unknown (NNClassifier)!')


if __name__ == '__main__':
clf = NNClassifier(learning_rate=0.01, architecture=[48, 48], epochs=10)
clf.selftest()
4 changes: 4 additions & 0 deletions requirements.txt
@@ -1,7 +1,11 @@
flake8>=6.1.0
flax>=0.8.0
jax==0.4.28
jaxlib==0.4.28
matplotlib>=3.7.2
numba>=0.55.1
numpy>=1.21.6
optax==0.2.2
pandas>=1.3.1
pathos>=0.2.9
pre-commit>=3.4.0
50 changes: 50 additions & 0 deletions tests/nn_module_test.py
@@ -0,0 +1,50 @@
from __future__ import annotations

import sys
import unittest

from jax import numpy as jnp
from jax import random

sys.path.append('./outrank')

from outrank.algorithms.neural.mlp_nn import NNClassifier


class NNClassifierTest(unittest.TestCase):

def setUp(self):
# Common setup operations, run before each test method
self.learning_rate = 0.001
self.architecture = [48, 48]
self.epochs = 100
self.clf = NNClassifier(
self.learning_rate, self.architecture,
self.epochs,
)
self.key = random.PRNGKey(7235123)

def test_imports(self):
# Test imports to ensure they are available
try:
import jax
import optax
import flax
except ImportError as e:
self.fail(f'Import failed: {e}')

def test_initialization(self):
# Check if the NNClassifier is initialized correctly
self.assertIsInstance(self.clf, NNClassifier)
self.assertEqual(self.clf.learning_rate, self.learning_rate)
self.assertEqual(self.clf.architecture, self.architecture)
self.assertEqual(self.clf.num_epochs, self.epochs)

def test_self(self):
# Check if `selftest` runs without assertion errors.
try:
self.clf.selftest()
except AssertionError as e:
self.fail(f'selftest() failed with an AssertionError: {e}')
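
    def test_predict_before_fit(self):
        # Hypothetical extra check (illustrative, not part of the original
        # commit): before fit() has set the one-hot depth, predict() logs
        # an error and returns None.
        fresh_clf = NNClassifier()
        self.assertIsNone(fresh_clf.predict(jnp.ones((2, 1))))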
