-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH make tools to list estimators within the project
- Loading branch information
Showing
11 changed files
with
277 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
# Authors: scikit-learn-contrib developers | ||
# License: BSD 3 clause | ||
|
||
__version__ = "0.0.4.dev0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Authors: scikit-learn-contrib developers | ||
# License: BSD 3 clause |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,16 @@ | ||
"""This file shows how to write test based on the scikit-learn common tests.""" | ||
|
||
# Authors: scikit-learn-contrib developers | ||
# License: BSD 3 clause | ||
|
||
from sklearn.utils.estimator_checks import parametrize_with_checks | ||
|
||
from skltemplate import TemplateClassifier, TemplateEstimator, TemplateTransformer | ||
from skltemplate.utils.discovery import all_estimators | ||
|
||
|
||
# parametrize_with_checks generates one test case per (estimator, check) pair,
# which is more fine-grained than running check_estimator on each estimator.
# The estimators are discovered with all_estimators() instead of being listed by
# hand. Only one parametrize_with_checks decorator may be applied: stacking a
# second one would parametrize `estimator` and `check` twice and make pytest
# fail at collection time.
@parametrize_with_checks([est() for _, est in all_estimators()])
def test_estimators(estimator, check, request):
    """Check the compatibility with the scikit-learn API."""
    check(estimator)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Authors: scikit-learn-contrib developers | ||
# License: BSD 3 clause |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
""" | ||
The :mod:`skltemplate.utils.discovery` module includes utilities to discover | ||
objects (i.e. estimators, displays, functions) from the `skltemplate` package. | ||
""" | ||
|
||
# Adapted from scikit-learn | ||
# Authors: scikit-learn-contrib developers | ||
# License: BSD 3 clause | ||
|
||
import inspect | ||
import pkgutil | ||
from importlib import import_module | ||
from operator import itemgetter | ||
from pathlib import Path | ||
|
||
from sklearn.base import ( | ||
BaseEstimator, | ||
ClassifierMixin, | ||
ClusterMixin, | ||
RegressorMixin, | ||
TransformerMixin, | ||
) | ||
from sklearn.utils._testing import ignore_warnings | ||
|
||
_MODULE_TO_IGNORE = {"tests"} | ||
|
||
|
||
def all_estimators(type_filter=None):
    """Get a list of all estimators from `skltemplate`.

    This function crawls the module and gets all classes that inherit
    from `BaseEstimator`. Classes that are defined in test-modules are not
    included.

    Parameters
    ----------
    type_filter : {"classifier", "regressor", "cluster", "transformer"} \
            or list of such str, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned. Possible values are
        'classifier', 'regressor', 'cluster' and 'transformer' to get
        estimators only of these specific types, or a list of these to
        get the estimators that fit at least one of the types. A tuple or a
        set of these values is accepted as well, and repeated values are
        harmless.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class.

    Raises
    ------
    ValueError
        If `type_filter` contains a value that is not one of the supported
        estimator types.

    Examples
    --------
    >>> from skltemplate.utils.discovery import all_estimators
    >>> estimators = all_estimators()
    >>> type(estimators)
    <class 'list'>
    """

    def is_abstract(cls):
        # A class is abstract when it still lists unimplemented abstract
        # methods; classes without the attribute are concrete.
        return bool(getattr(cls, "__abstractmethods__", ()))

    all_classes = []
    root = str(Path(__file__).parent.parent)  # skltemplate package
    # Ignore FutureWarning triggered at import time and from walking packages
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(
            path=[root], prefix="skltemplate."
        ):
            if any(part in _MODULE_TO_IGNORE for part in module_name.split(".")):
                continue
            module = import_module(module_name)
            all_classes.extend(
                (name, est_cls)
                for name, est_cls in inspect.getmembers(module, inspect.isclass)
                if not name.startswith("_")
            )

    # The same class is usually reachable from several modules (re-exports),
    # hence the set. Keep concrete BaseEstimator subclasses only.
    estimators = [
        (name, est_cls)
        for name, est_cls in set(all_classes)
        if issubclass(est_cls, BaseEstimator)
        and name != "BaseEstimator"
        and not is_abstract(est_cls)
    ]

    if type_filter is not None:
        if isinstance(type_filter, (list, tuple, set, frozenset)):
            remaining = list(type_filter)  # copy, never mutate the caller's object
        else:
            remaining = [type_filter]
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "transformer": TransformerMixin,
            "cluster": ClusterMixin,
        }
        filtered_estimators = []
        for kind, mixin in filters.items():
            if kind in remaining:
                # Drop every occurrence so that a repeated valid value is not
                # reported as invalid below.
                remaining = [value for value in remaining if value != kind]
                filtered_estimators.extend(
                    est for est in estimators if issubclass(est[1], mixin)
                )
        estimators = filtered_estimators
        # Whatever is left did not match any supported estimator type.
        if remaining:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'transformer', 'cluster' or "
                "None, got"
                f" {repr(remaining)}."
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))
|
||
|
||
def all_displays():
    """Get a list of all displays from `skltemplate`.

    Every non-test module of the package is imported and searched for public
    classes whose name ends with ``Display``.

    Returns
    -------
    displays : list of tuples
        List of (name, class), where ``name`` is the display class name as
        string and ``class`` is the actual type of the class.

    Examples
    --------
    >>> from skltemplate.utils.discovery import all_displays
    >>> displays = all_displays()
    """
    package_root = str(Path(__file__).parent.parent)  # skltemplate package
    found = []
    # FutureWarning raised while importing or walking the packages is silenced
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(
            path=[package_root], prefix="skltemplate."
        ):
            if any(part in _MODULE_TO_IGNORE for part in module_name.split(".")):
                continue
            members = inspect.getmembers(import_module(module_name), inspect.isclass)
            for name, display_class in members:
                if name.startswith("_") or not name.endswith("Display"):
                    continue
                found.append((name, display_class))

    # Deduplicate classes seen through several modules and sort by name only
    return sorted(set(found), key=itemgetter(0))
|
||
|
||
def _is_checked_function(item):
    """Tell whether `item` is a public function defined in `skltemplate`.

    Parameters
    ----------
    item : object
        Any object, typically a module member handed over by
        `inspect.getmembers`.

    Returns
    -------
    bool
        True when `item` is a Python function whose name does not start with
        an underscore and whose defining module lives under ``skltemplate.``
        without ending with ``estimator_checks``; False otherwise.
    """
    if not inspect.isfunction(item):
        return False

    if item.__name__.startswith("_"):
        return False

    mod = item.__module__
    # `__module__` is None when the defining module is unavailable; such a
    # function cannot be attributed to `skltemplate`, so reject it instead of
    # failing on `None.startswith`.
    if not isinstance(mod, str):
        return False

    return mod.startswith("skltemplate.") and not mod.endswith("estimator_checks")
|
||
|
||
def all_functions():
    """Get a list of all functions from `skltemplate`.

    Every non-test module of the package is imported and its members are
    filtered with `_is_checked_function`.

    Returns
    -------
    functions : list of tuples
        List of (name, function), where ``name`` is the function name as
        string and ``function`` is the actual function.

    Examples
    --------
    >>> from skltemplate.utils.discovery import all_functions
    >>> functions = all_functions()
    """
    package_root = str(Path(__file__).parent.parent)  # skltemplate package
    discovered = []
    # FutureWarning raised while importing or walking the packages is silenced
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(
            path=[package_root], prefix="skltemplate."
        ):
            if any(part in _MODULE_TO_IGNORE for part in module_name.split(".")):
                continue

            module = import_module(module_name)
            members = inspect.getmembers(module, _is_checked_function)
            for member_name, func in members:
                if member_name.startswith("_"):
                    continue
                # The entry is keyed on the function's own name, not on the
                # name it is bound to in the module.
                discovered.append((func.__name__, func))

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(discovered), key=itemgetter(0))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Authors: scikit-learn-contrib developers | ||
# License: BSD 3 clause |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Authors: scikit-learn-contrib developers | ||
# License: BSD 3 clause | ||
|
||
from skltemplate.utils.discovery import all_estimators, all_displays, all_functions | ||
|
||
|
||
def test_all_estimators():
    """`all_estimators` must report exactly three estimators."""
    assert len(all_estimators()) == 3
|
||
|
||
def test_all_displays():
    """`all_displays` must report no display class."""
    assert len(all_displays()) == 0
|
||
|
||
def test_all_functions():
    """`all_functions` must report exactly three public functions."""
    assert len(all_functions()) == 3