From 990aa3b1c083c145dbe2d719d9d3a27ee4ddb3b5 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 26 Jul 2024 12:35:51 +0200 Subject: [PATCH] threads --- outrank/algorithms/importance_estimator.py | 3 +++ requirements.txt | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/outrank/algorithms/importance_estimator.py b/outrank/algorithms/importance_estimator.py index 56953c5..e19155c 100644 --- a/outrank/algorithms/importance_estimator.py +++ b/outrank/algorithms/importance_estimator.py @@ -18,6 +18,7 @@ from sklearn.preprocessing import OneHotEncoder from sklearn.svm import SVC +from outrank.algorithms.neural.mlp_nn import NNClassifier from outrank.core_utils import is_prior_heuristic logger = logging.getLogger('syn-logger') @@ -224,6 +225,8 @@ def initialize_classifier(surrogate_model: str): return LogisticRegression(max_iter=100000) elif 'surrogate-SVM' in surrogate_model: return SVC(gamma='auto', probability=True) + elif 'surrogate-NN' in surrogate_model: + return NNClassifier() elif 'surrogate-SGD' in surrogate_model: return SGDClassifier(max_iter=100000, loss='log_loss') else: diff --git a/requirements.txt b/requirements.txt index 4407704..1470797 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ flake8>=6.1.0 -flax==0.8.3 -jax==0.4.28 -jaxlib==0.4.28 +flax>=0.8.3 +jax>=0.4.28 +jaxlib>=0.4.28 matplotlib>=3.7.2 numba>=0.55.1 numpy>=1.21.6