Skip to content

Commit

Permalink
ENH make tools to list estimators within the project
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed May 1, 2024
1 parent 251b4f9 commit 7d15d78
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 4 deletions.
12 changes: 12 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,15 @@ Predictor
:template: class.rst

TemplateClassifier


Utilities
=========

.. autosummary::
:toctree: generated/
:template: functions.rst

utils.discovery.all_estimators
utils.discovery.all_displays
utils.discovery.all_functions
3 changes: 3 additions & 0 deletions skltemplate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Authors: scikit-learn-contrib developers
# License: BSD 3 clause

from ._template import TemplateClassifier, TemplateEstimator, TemplateTransformer
from ._version import __version__

Expand Down
8 changes: 6 additions & 2 deletions skltemplate/_template.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""
This is a module to be used as a reference for building other modules
"""

# Authors: scikit-learn-contrib developers
# License: BSD 3 clause

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, _fit_context
from sklearn.metrics import euclidean_distances
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import check_is_fitted


class TemplateEstimator(BaseEstimator):
Expand Down Expand Up @@ -301,4 +305,4 @@ def _more_tags(self):
# https://scikit-learn.org/dev/developers/develop.html#estimator-tags
# Here, our transformer does not do any operation in `fit` and only validate
# the parameters. Thus, it is stateless.
return {'stateless': True}
return {"stateless": True}
3 changes: 3 additions & 0 deletions skltemplate/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
# Authors: scikit-learn-contrib developers
# License: BSD 3 clause

__version__ = "0.0.4.dev0"
2 changes: 2 additions & 0 deletions skltemplate/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Authors: scikit-learn-contrib developers
# License: BSD 3 clause
7 changes: 5 additions & 2 deletions skltemplate/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""This file shows how to write test based on the scikit-learn common tests."""

# Authors: scikit-learn-contrib developers
# License: BSD 3 clause

from sklearn.utils.estimator_checks import parametrize_with_checks

from skltemplate import TemplateClassifier, TemplateEstimator, TemplateTransformer
from skltemplate.utils.discovery import all_estimators


# parametrize_with_checks allows to get a generator of check that is more fine-grained
# than check_estimator
@parametrize_with_checks([TemplateEstimator(), TemplateTransformer(), TemplateClassifier()])
@parametrize_with_checks([est() for _, est in all_estimators()])
def test_estimators(estimator, check, request):
"""Check the compatibility with scikit-learn API"""
check(estimator)
3 changes: 3 additions & 0 deletions skltemplate/tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from skltemplate import TemplateClassifier, TemplateEstimator, TemplateTransformer

# Authors: scikit-learn-contrib developers
# License: BSD 3 clause


@pytest.fixture
def data():
Expand Down
2 changes: 2 additions & 0 deletions skltemplate/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Authors: scikit-learn-contrib developers
# License: BSD 3 clause
217 changes: 217 additions & 0 deletions skltemplate/utils/discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
"""
The :mod:`skltemplate.utils.discovery` module includes utilities to discover
objects (i.e. estimators, displays, functions) from the `skltemplate` package.
"""

# Adapted from scikit-learn
# Authors: scikit-learn-contrib developers
# License: BSD 3 clause

import inspect
import pkgutil
from importlib import import_module
from operator import itemgetter
from pathlib import Path

from sklearn.base import (
BaseEstimator,
ClassifierMixin,
ClusterMixin,
RegressorMixin,
TransformerMixin,
)
from sklearn.utils._testing import ignore_warnings

_MODULE_TO_IGNORE = {"tests"}


def all_estimators(type_filter=None):
"""Get a list of all estimators from `skltemplate`.
This function crawls the module and gets all classes that inherit
from `BaseEstimator`. Classes that are defined in test-modules are not
included.
Parameters
----------
type_filter : {"classifier", "regressor", "cluster", "transformer"} \
or list of such str, default=None
Which kind of estimators should be returned. If None, no filter is
applied and all estimators are returned. Possible values are
'classifier', 'regressor', 'cluster' and 'transformer' to get
estimators only of these specific types, or a list of these to
get the estimators that fit at least one of the types.
Returns
-------
estimators : list of tuples
List of (name, class), where ``name`` is the class name as string
and ``class`` is the actual type of the class.
Examples
--------
>>> from skltemplate.utils.discovery import all_estimators
>>> estimators = all_estimators()
>>> type(estimators)
<class 'list'>
"""

def is_abstract(c):
if not (hasattr(c, "__abstractmethods__")):
return False
if not len(c.__abstractmethods__):
return False
return True

all_classes = []
root = str(Path(__file__).parent.parent) # skltemplate package
# Ignore deprecation warnings triggered at import time and from walking
# packages
with ignore_warnings(category=FutureWarning):
for _, module_name, _ in pkgutil.walk_packages(
path=[root], prefix="skltemplate."
):
module_parts = module_name.split(".")
if any(part in _MODULE_TO_IGNORE for part in module_parts):
continue
module = import_module(module_name)
classes = inspect.getmembers(module, inspect.isclass)
classes = [
(name, est_cls) for name, est_cls in classes if not name.startswith("_")
]

all_classes.extend(classes)

all_classes = set(all_classes)

estimators = [
c
for c in all_classes
if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
]
# get rid of abstract base classes
estimators = [c for c in estimators if not is_abstract(c[1])]

if type_filter is not None:
if not isinstance(type_filter, list):
type_filter = [type_filter]
else:
type_filter = list(type_filter) # copy
filtered_estimators = []
filters = {
"classifier": ClassifierMixin,
"regressor": RegressorMixin,
"transformer": TransformerMixin,
"cluster": ClusterMixin,
}
for name, mixin in filters.items():
if name in type_filter:
type_filter.remove(name)
filtered_estimators.extend(
[est for est in estimators if issubclass(est[1], mixin)]
)
estimators = filtered_estimators
if type_filter:
raise ValueError(
"Parameter type_filter must be 'classifier', "
"'regressor', 'transformer', 'cluster' or "
"None, got"
f" {repr(type_filter)}."
)

# drop duplicates, sort for reproducibility
# itemgetter is used to ensure the sort does not extend to the 2nd item of
# the tuple
return sorted(set(estimators), key=itemgetter(0))


def all_displays():
"""Get a list of all displays from `skltemplate`.
Returns
-------
displays : list of tuples
List of (name, class), where ``name`` is the display class name as
string and ``class`` is the actual type of the class.
Examples
--------
>>> from skltemplate.utils.discovery import all_displays
>>> displays = all_displays()
"""
all_classes = []
root = str(Path(__file__).parent.parent) # skltemplate package
# Ignore deprecation warnings triggered at import time and from walking
# packages
with ignore_warnings(category=FutureWarning):
for _, module_name, _ in pkgutil.walk_packages(
path=[root], prefix="skltemplate."
):
module_parts = module_name.split(".")
if any(part in _MODULE_TO_IGNORE for part in module_parts):
continue
module = import_module(module_name)
classes = inspect.getmembers(module, inspect.isclass)
classes = [
(name, display_class)
for name, display_class in classes
if not name.startswith("_") and name.endswith("Display")
]
all_classes.extend(classes)

return sorted(set(all_classes), key=itemgetter(0))


def _is_checked_function(item):
if not inspect.isfunction(item):
return False

if item.__name__.startswith("_"):
return False

mod = item.__module__
if not mod.startswith("skltemplate.") or mod.endswith("estimator_checks"):
return False

return True


def all_functions():
"""Get a list of all functions from `skltemplate`.
Returns
-------
functions : list of tuples
List of (name, function), where ``name`` is the function name as
string and ``function`` is the actual function.
Examples
--------
>>> from skltemplate.utils.discovery import all_functions
>>> functions = all_functions()
"""
all_functions = []
root = str(Path(__file__).parent.parent) # skltemplate package
# Ignore deprecation warnings triggered at import time and from walking
# packages
with ignore_warnings(category=FutureWarning):
for _, module_name, _ in pkgutil.walk_packages(
path=[root], prefix="skltemplate."
):
module_parts = module_name.split(".")
if any(part in _MODULE_TO_IGNORE for part in module_parts):
continue

module = import_module(module_name)
functions = inspect.getmembers(module, _is_checked_function)
functions = [
(func.__name__, func)
for name, func in functions
if not name.startswith("_")
]
all_functions.extend(functions)

# drop duplicates, sort for reproducibility
# itemgetter is used to ensure the sort does not extend to the 2nd item of
# the tuple
return sorted(set(all_functions), key=itemgetter(0))
2 changes: 2 additions & 0 deletions skltemplate/utils/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Authors: scikit-learn-contrib developers
# License: BSD 3 clause
19 changes: 19 additions & 0 deletions skltemplate/utils/tests/test_discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Authors: scikit-learn-contrib developers
# License: BSD 3 clause

from skltemplate.utils.discovery import all_displays, all_estimators, all_functions


def test_all_estimators():
estimators = all_estimators()
assert len(estimators) == 3


def test_all_displays():
displays = all_displays()
assert len(displays) == 0


def test_all_functions():
functions = all_functions()
assert len(functions) == 3

0 comments on commit 7d15d78

Please sign in to comment.