Skip to content

Commit

Permalink
feat(notifiers/core): Improve load_provider_from_points function with…
Browse files Browse the repository at this point in the history
… better validation and type hints
  • Loading branch information
looghao committed Dec 9, 2024
1 parent 1de1315 commit 4d46a17
Showing 1 changed file with 46 additions and 16 deletions.
62 changes: 46 additions & 16 deletions notifiers/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib.machinery
import importlib.util
import logging
from abc import ABC
from abc import abstractmethod
Expand Down Expand Up @@ -353,27 +351,59 @@ def get_notifier(provider_name: str, strict: bool = False) -> Provider:
raise NoSuchNotifierError(name=provider_name)


def load_provider_from_points(entry_points: str):
def load_provider_from_points(entry_points: str) -> 'Provider':
"""Load a Provider class from a given entry point string.
This function takes an entry point string in the format
'module_path:class_name', imports the module dynamically using
importlib, and returns the specified class from the module.
'module_path:class_name' and dynamically imports the specified Provider class.
It performs validation to ensure the loaded class is a valid Provider.
:param entry_points: A string specifying the module path and class name, separated by a colon.
:return: :class:`Provider` or None
:raises ImportError: If the specified module cannot be imported.
:raises AttributeError: If the specified class cannot be found in the module.
:param entry_points: A string in the format 'module_path:class_name' (e.g. 'myapp.providers:EmailProvider')
:return: :class:`Provider` The loaded Provider class
:raises ValueError: If the entry_points string format is invalid
:raises ImportError: If the specified module cannot be imported
:raises AttributeError: If the specified class cannot be found in the module
:raises TypeError: If the loaded class is not a subclass of Provider
Example:
>>> entry_points = "xxx.provider:Provider"
>>> plugin_class = load_provider_from_points(entry_points)
>>> plugin_instance = plugin_class()
>>> entry_points = "myapp.providers:EmailProvider"
>>> provider_class = load_provider_from_points(entry_points)
>>> provider = provider_class()
"""
instance_path, instance = entry_points.split(":")
module = importlib.import_module(instance_path)
provider = getattr(module, instance)
return provider
if not entry_points or ':' not in entry_points:
raise ValueError(
f"Invalid entry point format: {entry_points}. "
"Expected format: 'module_path:class_name'"
)

try:
module_path, class_name = entry_points.split(":", 1)
except ValueError as e:
raise ValueError(
f"Multiple colons found in entry point: {entry_points}. "
"Expected format: 'module_path:class_name'"
) from e

try:
module = importlib.import_module(module_path.strip())
except ImportError as e:
raise ImportError(
f"Failed to import module '{module_path}': {str(e)}"
) from e

try:
provider_class = getattr(module, class_name.strip())
except AttributeError as e:
raise AttributeError(
f"Class '{class_name}' not found in module '{module_path}'"
) from e

if not (isinstance(provider_class, type) and issubclass(provider_class, Provider)):
raise TypeError(
f"'{module_path}:{class_name}' must be a subclass of Provider"
)

return provider_class


def get_providers_from_entry_points(group_name: str = "notifiers") -> dict:
Expand Down

0 comments on commit 4d46a17

Please sign in to comment.