From e6f03c37c5af665c7bbcb2774fd5d3c985383c84 Mon Sep 17 00:00:00 2001 From: Mirko Bunse Date: Mon, 23 Sep 2024 11:42:53 +0200 Subject: [PATCH] QuaPyWrapper supports likelihood methods --- qunfold/quapy.py | 6 ++++-- qunfold/tests/__init__.py | 31 +++++++++++++++++++++++++------ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/qunfold/quapy.py b/qunfold/quapy.py index a84d42f..54ac2c1 100644 --- a/qunfold/quapy.py +++ b/qunfold/quapy.py @@ -62,7 +62,7 @@ class QuaPyWrapper(BaseQuantifier): """A thin wrapper for using qunfold methods in QuaPy. Args: - generic_method: A LinearMethod method to wrap. + generic_method: An instance of `qunfold.methods.AbstractMethod` to wrap. Examples: Here, we wrap an instance of ACC to perform a grid search with QuaPy. @@ -87,4 +87,6 @@ def set_params(self, **params): _set_params(self.generic_method, self.get_params(deep=True), **params) return self def get_params(self, deep=True): - return _get_params(self.generic_method, deep, LinearMethod) + if isinstance(self.generic_method, LinearMethod): # use super-class constructor? + return _get_params(self.generic_method, deep, LinearMethod) + return _get_params(self.generic_method, deep) diff --git a/qunfold/tests/__init__.py b/qunfold/tests/__init__.py index d699d09..8c51f98 100644 --- a/qunfold/tests/__init__.py +++ b/qunfold/tests/__init__.py @@ -141,13 +141,18 @@ def test_methods(self): n_estimators = 10, random_state = RNG.randint(np.iinfo("uint16").max), ) - p_acc = QuaPyWrapper(qunfold.ACC(lr)) + wrapped_acc = QuaPyWrapper(qunfold.ACC(lr)) + wrapped_sld = QuaPyWrapper(qunfold.ExpectationMaximizer(lr.estimator)) self.assertEqual( # check that get_params returns the correct settings - p_acc.get_params(deep=True)["representation__classifier__estimator__C"], + wrapped_acc.get_params(deep=True)["representation__classifier__estimator__C"], 1e-2 ) - quapy_method = qp.model_selection.GridSearchQ( - model = p_acc, + self.assertEqual( + wrapped_sld.get_params(deep=True)["classifier__C"], + 1e-2 + ) + cv_acc = qp.model_selection.GridSearchQ( + model = wrapped_acc, param_grid = { "representation__classifier__estimator__C": [1e-1, 1e0, 1e1, 1e2], }, @@ -157,8 +162,22 @@ def test_methods(self): verbose = True, ).fit(qp.data.LabelledCollection(X_trn, y_trn)) self.assertEqual( # check that best parameters are actually used - quapy_method.best_params_["representation__classifier__estimator__C"], - quapy_method.best_model_.generic_method.representation.classifier.estimator.C + cv_acc.best_params_["representation__classifier__estimator__C"], + cv_acc.best_model_.generic_method.representation.classifier.estimator.C + ) + cv_sld = qp.model_selection.GridSearchQ( + model = wrapped_sld, + param_grid = { + "classifier__C": [1e-1, 1e0, 1e1, 1e2], + }, + protocol = SingleSampleProtocol(X_tst, p_tst), + error = "mae", + refit = False, + verbose = True, + ).fit(qp.data.LabelledCollection(X_trn, y_trn)) + self.assertEqual( # check that best parameters are actually used + cv_sld.best_params_["classifier__C"], + cv_sld.best_model_.generic_method.classifier.C ) class TestDistanceRepresentation(TestCase):