diff --git a/doc/api.rst b/doc/api.rst
index 6053b4b..0bbac09 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -32,3 +32,15 @@ Predictor
    :template: class.rst
 
    TemplateClassifier
+
+
+Utilities
+=========
+
+.. autosummary::
+   :toctree: generated/
+   :template: functions.rst
+
+   utils.discovery.all_estimators
+   utils.discovery.all_displays
+   utils.discovery.all_functions
diff --git a/skltemplate/__init__.py b/skltemplate/__init__.py
index 879c4c6..66806e5 100644
--- a/skltemplate/__init__.py
+++ b/skltemplate/__init__.py
@@ -1,3 +1,6 @@
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
+
 from ._template import TemplateClassifier, TemplateEstimator, TemplateTransformer
 from ._version import __version__
 
diff --git a/skltemplate/_template.py b/skltemplate/_template.py
index aa21f69..ce41d0f 100644
--- a/skltemplate/_template.py
+++ b/skltemplate/_template.py
@@ -1,6 +1,10 @@
 """
 This is a module to be used as a reference for building other modules
 """
+
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
+
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, _fit_context
 from sklearn.metrics import euclidean_distances
diff --git a/skltemplate/_version.py b/skltemplate/_version.py
index d1b5dc8..d1cbff4 100644
--- a/skltemplate/_version.py
+++ b/skltemplate/_version.py
@@ -1 +1,4 @@
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
+
 __version__ = "0.0.4.dev0"
diff --git a/skltemplate/tests/__init__.py b/skltemplate/tests/__init__.py
index e69de29..52e3f96 100644
--- a/skltemplate/tests/__init__.py
+++ b/skltemplate/tests/__init__.py
@@ -0,0 +1,2 @@
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
diff --git a/skltemplate/tests/test_common.py b/skltemplate/tests/test_common.py
index 381c11b..4249e7f 100644
--- a/skltemplate/tests/test_common.py
+++ b/skltemplate/tests/test_common.py
@@ -1,13 +1,16 @@
 """This file shows how to write test based on the scikit-learn common tests."""
 
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
+
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
-from skltemplate import TemplateClassifier, TemplateEstimator, TemplateTransformer
+from skltemplate.utils.discovery import all_estimators
 
 
 # parametrize_with_checks allows to get a generator of check that is more fine-grained
 # than check_estimator
-@parametrize_with_checks([TemplateEstimator(), TemplateTransformer(), TemplateClassifier()])
+@parametrize_with_checks([est() for _, est in all_estimators()])
 def test_estimators(estimator, check, request):
     """Check the compatibility with scikit-learn API"""
     check(estimator)
diff --git a/skltemplate/tests/test_template.py b/skltemplate/tests/test_template.py
index 6803945..b34b106 100644
--- a/skltemplate/tests/test_template.py
+++ b/skltemplate/tests/test_template.py
@@ -6,6 +6,8 @@
 
 from skltemplate import TemplateClassifier, TemplateEstimator, TemplateTransformer
 
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
 
 @pytest.fixture
 def data():
diff --git a/skltemplate/utils/__init__.py b/skltemplate/utils/__init__.py
new file mode 100644
index 0000000..52e3f96
--- /dev/null
+++ b/skltemplate/utils/__init__.py
@@ -0,0 +1,2 @@
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
diff --git a/skltemplate/utils/discovery.py b/skltemplate/utils/discovery.py
new file mode 100644
index 0000000..843aa7a
--- /dev/null
+++ b/skltemplate/utils/discovery.py
@@ -0,0 +1,223 @@
+"""
+The :mod:`skltemplate.utils.discovery` module includes utilities to discover
+objects (i.e. estimators, displays, functions) from the `skltemplate` package.
+"""
+
+# Adapted from scikit-learn
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
+
+import inspect
+import pkgutil
+from importlib import import_module
+from operator import itemgetter
+from pathlib import Path
+
+from sklearn.base import (
+    BaseEstimator,
+    ClassifierMixin,
+    ClusterMixin,
+    RegressorMixin,
+    TransformerMixin,
+)
+from sklearn.utils._testing import ignore_warnings
+
+_MODULE_TO_IGNORE = {"tests"}
+
+
+def all_estimators(type_filter=None):
+    """Get a list of all estimators from `skltemplate`.
+
+    This function crawls the module and gets all classes that inherit
+    from `BaseEstimator`. Classes that are defined in test-modules are not
+    included.
+
+    Parameters
+    ----------
+    type_filter : {"classifier", "regressor", "cluster", "transformer"} \
+            or list of such str, default=None
+        Which kind of estimators should be returned. If None, no filter is
+        applied and all estimators are returned. Possible values are
+        'classifier', 'regressor', 'cluster' and 'transformer' to get
+        estimators only of these specific types, or a list of these to
+        get the estimators that fit at least one of the types.
+
+    Returns
+    -------
+    estimators : list of tuples
+        List of (name, class), where ``name`` is the class name as string
+        and ``class`` is the actual type of the class.
+
+    Examples
+    --------
+    >>> from skltemplate.utils.discovery import all_estimators
+    >>> estimators = all_estimators()
+    >>> type(estimators)
+    <class 'list'>
+    """
+
+    def is_abstract(c):
+        if not (hasattr(c, "__abstractmethods__")):
+            return False
+        if not len(c.__abstractmethods__):
+            return False
+        return True
+
+    all_classes = []
+    root = str(Path(__file__).parent.parent)  # skltemplate package
+    # Ignore deprecation warnings triggered at import time and from walking
+    # packages
+    with ignore_warnings(category=FutureWarning):
+        for _, module_name, _ in pkgutil.walk_packages(
+            path=[root], prefix="skltemplate."
+        ):
+            module_parts = module_name.split(".")
+            if (
+                any(part in _MODULE_TO_IGNORE for part in module_parts)
+            ):
+                continue
+            module = import_module(module_name)
+            classes = inspect.getmembers(module, inspect.isclass)
+            classes = [
+                (name, est_cls) for name, est_cls in classes if not name.startswith("_")
+            ]
+
+            all_classes.extend(classes)
+
+    all_classes = set(all_classes)
+
+    estimators = [
+        c
+        for c in all_classes
+        if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
+    ]
+    # get rid of abstract base classes
+    estimators = [c for c in estimators if not is_abstract(c[1])]
+
+    if type_filter is not None:
+        if not isinstance(type_filter, list):
+            type_filter = [type_filter]
+        else:
+            type_filter = list(type_filter)  # copy
+        filtered_estimators = []
+        filters = {
+            "classifier": ClassifierMixin,
+            "regressor": RegressorMixin,
+            "transformer": TransformerMixin,
+            "cluster": ClusterMixin,
+        }
+        for name, mixin in filters.items():
+            if name in type_filter:
+                type_filter.remove(name)
+                filtered_estimators.extend(
+                    [est for est in estimators if issubclass(est[1], mixin)]
+                )
+        estimators = filtered_estimators
+        if type_filter:
+            raise ValueError(
+                "Parameter type_filter must be 'classifier', "
+                "'regressor', 'transformer', 'cluster' or "
+                "None, got"
+                f" {repr(type_filter)}."
+            )
+
+    # drop duplicates, sort for reproducibility
+    # itemgetter is used to ensure the sort does not extend to the 2nd item of
+    # the tuple
+    return sorted(set(estimators), key=itemgetter(0))
+
+
+def all_displays():
+    """Get a list of all displays from `skltemplate`.
+
+    Returns
+    -------
+    displays : list of tuples
+        List of (name, class), where ``name`` is the display class name as
+        string and ``class`` is the actual type of the class.
+
+    Examples
+    --------
+    >>> from skltemplate.utils.discovery import all_displays
+    >>> displays = all_displays()
+    """
+    all_classes = []
+    root = str(Path(__file__).parent.parent)  # skltemplate package
+    # Ignore deprecation warnings triggered at import time and from walking
+    # packages
+    with ignore_warnings(category=FutureWarning):
+        for _, module_name, _ in pkgutil.walk_packages(
+            path=[root], prefix="skltemplate."
+        ):
+            module_parts = module_name.split(".")
+            if (
+                any(part in _MODULE_TO_IGNORE for part in module_parts)
+            ):
+                continue
+            module = import_module(module_name)
+            classes = inspect.getmembers(module, inspect.isclass)
+            classes = [
+                (name, display_class)
+                for name, display_class in classes
+                if not name.startswith("_") and name.endswith("Display")
+            ]
+            all_classes.extend(classes)
+
+    return sorted(set(all_classes), key=itemgetter(0))
+
+
+def _is_checked_function(item):
+    if not inspect.isfunction(item):
+        return False
+
+    if item.__name__.startswith("_"):
+        return False
+
+    mod = item.__module__
+    if not mod.startswith("skltemplate.") or mod.endswith("estimator_checks"):
+        return False
+
+    return True
+
+
+def all_functions():
+    """Get a list of all functions from `skltemplate`.
+
+    Returns
+    -------
+    functions : list of tuples
+        List of (name, function), where ``name`` is the function name as
+        string and ``function`` is the actual function.
+
+    Examples
+    --------
+    >>> from skltemplate.utils.discovery import all_functions
+    >>> functions = all_functions()
+    """
+    all_functions = []
+    root = str(Path(__file__).parent.parent)  # skltemplate package
+    # Ignore deprecation warnings triggered at import time and from walking
+    # packages
+    with ignore_warnings(category=FutureWarning):
+        for _, module_name, _ in pkgutil.walk_packages(
+            path=[root], prefix="skltemplate."
+        ):
+            module_parts = module_name.split(".")
+            if (
+                any(part in _MODULE_TO_IGNORE for part in module_parts)
+            ):
+                continue
+
+            module = import_module(module_name)
+            functions = inspect.getmembers(module, _is_checked_function)
+            functions = [
+                (func.__name__, func)
+                for name, func in functions
+                if not name.startswith("_")
+            ]
+            all_functions.extend(functions)
+
+    # drop duplicates, sort for reproducibility
+    # itemgetter is used to ensure the sort does not extend to the 2nd item of
+    # the tuple
+    return sorted(set(all_functions), key=itemgetter(0))
diff --git a/skltemplate/utils/tests/__init__.py b/skltemplate/utils/tests/__init__.py
new file mode 100644
index 0000000..52e3f96
--- /dev/null
+++ b/skltemplate/utils/tests/__init__.py
@@ -0,0 +1,2 @@
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
diff --git a/skltemplate/utils/tests/test_discovery.py b/skltemplate/utils/tests/test_discovery.py
new file mode 100644
index 0000000..26dbdef
--- /dev/null
+++ b/skltemplate/utils/tests/test_discovery.py
@@ -0,0 +1,19 @@
+# Authors: scikit-learn-contrib developers
+# License: BSD 3 clause
+
+from skltemplate.utils.discovery import all_estimators, all_displays, all_functions
+
+
+def test_all_estimators():
+    estimators = all_estimators()
+    assert len(estimators) == 3
+
+
+def test_all_displays():
+    displays = all_displays()
+    assert len(displays) == 0
+
+
+def test_all_functions():
+    functions = all_functions()
+    assert len(functions) == 3
\ No newline at end of file