Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLP surrogate support #68

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions outrank/algorithms/neural/mlp_nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

import logging

import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from jax import random

logger = logging.getLogger('syn-logger')
logger.setLevel(logging.DEBUG)

key = random.PRNGKey(7235123)


class GenericJaxNN(nn.Module):
num_features: int
architecture: jnp.array

@nn.compact
def __call__(self, x):
for num_units in self.architecture:
x = nn.Dense(features=num_units)(x)
x = nn.relu(x)
x = nn.Dense(features=2)(x)
return x


class NNClassifier:

def __init__(self, learning_rate=0.001, architecture=[48, 48], epochs=100):
self.learning_rate = learning_rate
self.architecture = architecture
self.num_epochs = epochs
self.ncl = None

batch_size = 10
features = 5 # Number of input features to the model

# Pass the features and architecture parameters to the model constructor
self.mlp_model = GenericJaxNN(
num_features=features,
architecture=jnp.array(architecture),
)

def fit(self, X, Y, print_loss=True):

loss_grad_fn = jax.value_and_grad(self.forward_loss)
self.ncl = len(jnp.unique(X))
X = jax.nn.one_hot(X, num_classes=self.ncl).reshape(X.shape[0], -1)
sample_batch = jnp.ones((1, X.shape[1]))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you test perf with other initializers? Not sure how well that behaves with a hardcoded PRNGKey.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't and it doesn't matter for a somewhat ok space of categorical problems we're considering here.

self.init_internal_mlp(sample_batch)

for i in range(self.num_epochs):
loss_val, grads = loss_grad_fn(self.parameters, X, Y)
updates, self.opt_state = self.tx.update(grads, self.opt_state)
self.parameters = optax.apply_updates(self.parameters, updates)
if print_loss:
print(f'Loss step {i + 1}: ', loss_val.item())

def forward_loss(self, parameters, X, Y):

def get_logits(x):
pred_logits = self.mlp_model.apply(parameters, x)
return pred_logits

batch_logits = jax.vmap(get_logits)(X)
one_hot_labels = jax.nn.one_hot(Y, num_classes=2)
loss = optax.softmax_cross_entropy(
logits=batch_logits,
labels=one_hot_labels,
).mean()

return loss

def forward_pass(self, X):

def get_logits(x):
pred_logits = self.mlp_model.apply(self.parameters, x)
return jax.nn.softmax(pred_logits)

batch_logits = jax.vmap(get_logits)(X)
return batch_logits

def init_internal_mlp(self, sample_batch):

self.parameters = self.mlp_model.init(key, sample_batch)
self.tx = optax.adam(learning_rate=self.learning_rate)
self.opt_state = self.tx.init(self.parameters)

def selftest(self):
random_data = random.randint(
minval=0,
maxval=200,
key=key,
shape=(10, 1),
)
ncl = len(jnp.unique(random_data))
random_data = jax.nn.one_hot(random_data, num_classes=ncl).reshape(
random_data.shape[0], -1,
)
sample_batch = jnp.ones((1, random_data.shape[1]))
self.init_internal_mlp(sample_batch)

is_class1 = jnp.sum(random_data, axis=1) < 500
random_labels = (is_class1 + 0).astype(jnp.int32)
real_output = self.forward_pass(random_data)
assert jnp.sum(jnp.sum(real_output, axis=1)) == real_output.shape[0]

forward_loss = self.forward_loss(
self.parameters, random_data,
random_labels,
)
print('Self-test loss:', forward_loss.item())
assert forward_loss.item() > 0.6
self.fit(random_data, random_labels)

preds = self.predict(random_data)
assert len(preds) == 10
assert preds[0, 0] < 0.1

def predict(self, X):

if self.ncl is not None:
X_ohe = jax.nn.one_hot(X, num_classes=self.ncl).reshape(
X.shape[0], -1,
)
return self.forward_pass(X_ohe)
else:
logger.error('number of classes unknown (NNClassifier)!')


if __name__ == '__main__':
clf = NNClassifier(learning_rate=0.01, architecture=[48, 48], epochs=10)
clf.selftest()
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
flake8>=6.1.0
flux==1.3.5
jax==0.4.28
jaxlib==0.4.28
matplotlib>=3.7.2
numba>=0.55.1
numpy>=1.21.6
optax==0.2.2
pandas>=1.3.1
pathos>=0.2.9
pre-commit>=3.4.0
Expand Down
50 changes: 50 additions & 0 deletions tests/nn_module_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import sys
import unittest

import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from jax import random

from outrank.algorithms.neural.mlp_nn import NNClassifier
sys.path.append('./outrank')


class NNClassifierTest(unittest.TestCase):

def setUp(self):
# Common setup operations, run before each test method
self.learning_rate = 0.001
self.architecture = [48, 48]
self.epochs = 100
self.clf = NNClassifier(
self.learning_rate, self.architecture,
self.epochs,
)
self.key = random.PRNGKey(7235123)

def test_imports(self):
# Test imports to ensure they are available
try:
import jax
import optax
import flax
except ImportError as e:
self.fail(f'Import failed: {e}')

def test_initialization(self):
# Check if the NNClassifier is initialized correctly

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was incorrect initialization an issue?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better safe than sorry

self.assertIsInstance(self.clf, NNClassifier)
self.assertEqual(self.clf.learning_rate, self.learning_rate)
self.assertEqual(self.clf.architecture, self.architecture)
self.assertEqual(self.clf.num_epochs, self.epochs)

def test_self(self):
# Check if `selftest` runs without assertion errors.
try:
self.clf.selftest()
except AssertionError as e:
self.fail(f'selftest() failed with an AssertionError: {e}')
Loading