Skip to content

Commit

Permalink
QuaPyWrapper supports likelihood methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkobunse committed Sep 23, 2024
1 parent aa58316 commit e6f03c3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
6 changes: 4 additions & 2 deletions qunfold/quapy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class QuaPyWrapper(BaseQuantifier):
"""A thin wrapper for using qunfold methods in QuaPy.
Args:
generic_method: A LinearMethod method to wrap.
generic_method: An instance of `qunfold.methods.AbstractMethod` to wrap.
Examples:
Here, we wrap an instance of ACC to perform a grid search with QuaPy.
Expand All @@ -87,4 +87,6 @@ def set_params(self, **params):
_set_params(self.generic_method, self.get_params(deep=True), **params)
return self
def get_params(self, deep=True):
return _get_params(self.generic_method, deep, LinearMethod)
if isinstance(self.generic_method, LinearMethod): # use super-class constructor?
return _get_params(self.generic_method, deep, LinearMethod)
return _get_params(self.generic_method, deep)
31 changes: 25 additions & 6 deletions qunfold/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,18 @@ def test_methods(self):
n_estimators = 10,
random_state = RNG.randint(np.iinfo("uint16").max),
)
p_acc = QuaPyWrapper(qunfold.ACC(lr))
wrapped_acc = QuaPyWrapper(qunfold.ACC(lr))
wrapped_sld = QuaPyWrapper(qunfold.ExpectationMaximizer(lr.estimator))
self.assertEqual( # check that get_params returns the correct settings
p_acc.get_params(deep=True)["representation__classifier__estimator__C"],
wrapped_acc.get_params(deep=True)["representation__classifier__estimator__C"],
1e-2
)
quapy_method = qp.model_selection.GridSearchQ(
model = p_acc,
self.assertEqual(
wrapped_sld.get_params(deep=True)["classifier__C"],
1e-2
)
cv_acc = qp.model_selection.GridSearchQ(
model = wrapped_acc,
param_grid = {
"representation__classifier__estimator__C": [1e-1, 1e0, 1e1, 1e2],
},
Expand All @@ -157,8 +162,22 @@ def test_methods(self):
verbose = True,
).fit(qp.data.LabelledCollection(X_trn, y_trn))
self.assertEqual( # check that best parameters are actually used
quapy_method.best_params_["representation__classifier__estimator__C"],
quapy_method.best_model_.generic_method.representation.classifier.estimator.C
cv_acc.best_params_["representation__classifier__estimator__C"],
cv_acc.best_model_.generic_method.representation.classifier.estimator.C
)
cv_sld = qp.model_selection.GridSearchQ(
model = wrapped_sld,
param_grid = {
"classifier__C": [1e-1, 1e0, 1e1, 1e2],
},
protocol = SingleSampleProtocol(X_tst, p_tst),
error = "mae",
refit = False,
verbose = True,
).fit(qp.data.LabelledCollection(X_trn, y_trn))
self.assertEqual( # check that best parameters are actually used
cv_sld.best_params_["classifier__C"],
cv_sld.best_model_.generic_method.classifier.C
)

class TestDistanceRepresentation(TestCase):
Expand Down

0 comments on commit e6f03c3

Please sign in to comment.