Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add get providers from entry points #469

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
96 changes: 93 additions & 3 deletions notifiers/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import importlib.machinery
import importlib.util
import logging
from abc import ABC
from abc import abstractmethod
from importlib.metadata import entry_points

import jsonschema
import requests
Expand Down Expand Up @@ -340,16 +343,103 @@ def get_notifier(provider_name: str, strict: bool = False) -> Provider:
:return: :class:`Provider` or None
:raises ValueError: In case ``strict`` is True and provider not found
"""
if provider_name in _all_providers:
providers = get_all_providers()
if provider_name in providers:
log.debug("found a match for '%s', returning", provider_name)
return _all_providers[provider_name]()
return providers[provider_name]()
elif strict:
raise NoSuchNotifierError(name=provider_name)


def load_provider_from_points(entry_points: str) -> "Provider":
"""Load a Provider class from a given entry point string.

This function takes an entry point string in the format
'module_path:class_name' and dynamically imports the specified Provider class.
It performs validation to ensure the loaded class is a valid Provider.

:param entry_points: A string in the format 'module_path:class_name' (e.g. 'myapp.providers:EmailProvider')
:return: :class:`Provider` The loaded Provider class
:raises ValueError: If the entry_points string format is invalid
:raises ImportError: If the specified module cannot be imported
:raises AttributeError: If the specified class cannot be found in the module
:raises TypeError: If the loaded class is not a subclass of Provider

Example:
>>> entry_points = "myapp.providers:EmailProvider"
>>> provider_class = load_provider_from_points(entry_points)
>>> provider = provider_class()
"""
if not entry_points or ":" not in entry_points:
raise ValueError(
f"Invalid entry point format: {entry_points}. "
"Expected format: 'module_path:class_name'"
)

try:
module_path, class_name = entry_points.split(":", 1)
except ValueError as e:
raise ValueError(
f"Multiple colons found in entry point: {entry_points}. "
"Expected format: 'module_path:class_name'"
) from e

try:
module = importlib.import_module(module_path.strip())
except ImportError as e:
raise ImportError(f"Failed to import module '{module_path}': {str(e)}") from e

try:
provider_class = getattr(module, class_name.strip())
except AttributeError as e:
raise AttributeError(
f"Class '{class_name}' not found in module '{module_path}'"
) from e

if not (isinstance(provider_class, type) and issubclass(provider_class, Provider)):
raise TypeError(f"'{module_path}:{class_name}' must be a subclass of Provider")

return provider_class


def get_providers_from_entry_points(group_name: str = "notifiers") -> dict:
"""
Get a dictionary of plugins from the entry points based on the given group name.

This function will search for the entry points with the specified group name
and return a dictionary where the keys are the names of the entry points and
the values are the corresponding entry point values.

:param group_name: The group name of the entry points to search for.
:return: Dict: A dictionary containing the entry point names as keys and their corresponding values as values.

Example:
>>> get_providers_from_entry_points("notifiers")
{"plugin1": "package.module:PluginClass", "plugin2": "package2.module:OtherPluginClass"}
"""
result: dict = {}
points = entry_points()
for item in points.get(group_name, []):
if item.group == group_name:
result[item.name] = load_provider_from_points(item.value)
return result


def get_all_providers() -> dict:
"""Get all providers from the entry points and the default providers.

:return: Dict: A dictionary containing the entry point names as keys and their corresponding values as values.

"""
default_providers = _all_providers.copy()
entry_point_providers = get_providers_from_entry_points()
default_providers.update(entry_point_providers)
return default_providers


def all_providers() -> list:
"""Returns a list of all :class:`~notifiers.core.Provider` names"""
return list(_all_providers.keys())
return list(get_all_providers().keys())


def notify(provider_name: str, **kwargs) -> Response:
Expand Down
Loading