diff --git a/changes/2377.feature.md b/changes/2377.feature.md new file mode 100644 index 0000000000..69d390098e --- /dev/null +++ b/changes/2377.feature.md @@ -0,0 +1 @@ +Implement scanning plugin entrypoints of external packages \ No newline at end of file diff --git a/py b/py index 58001570f4..3d8df6a547 100755 --- a/py +++ b/py @@ -11,8 +11,4 @@ if [ $? -ne 0 ]; then fi LOCKSET=${LOCKSET:-python-default/$PYTHON_VERSION} source dist/export/python/virtualenvs/$LOCKSET/bin/activate -PYTHONPATH="${PYTHONPATH}" -for plugin_dir in $(ls -d plugins/*/ 2>/dev/null); do - PYTHONPATH="${plugin_dir}/src:${PYTHONPATH}" -done PYTHONPATH="src:${PYTHONPATH}" exec python "$@" diff --git a/src/ai/backend/agent/docker/metadata/server.py b/src/ai/backend/agent/docker/metadata/server.py index 32f981e8af..8f456cdec7 100644 --- a/src/ai/backend/agent/docker/metadata/server.py +++ b/src/ai/backend/agent/docker/metadata/server.py @@ -187,7 +187,7 @@ async def load_metadata_plugins(self): plugin_ctx = MetadataPluginContext(root_ctx.etcd, root_ctx.local_config) await plugin_ctx.init() root_ctx.metadata_plugin_ctx = plugin_ctx - log.debug("Available plugins: {}", plugin_ctx.plugins) + log.debug("Available metadata plugins: {}", plugin_ctx.plugins) for plugin_name, plugin_instance in plugin_ctx.plugins.items(): log.info("Loading metadata plugin: {0}", plugin_name) subapp, global_middlewares, route_structure = await plugin_instance.create_app() diff --git a/src/ai/backend/common/BUILD b/src/ai/backend/common/BUILD index 4cd757e910..19b816be19 100644 --- a/src/ai/backend/common/BUILD +++ b/src/ai/backend/common/BUILD @@ -12,6 +12,7 @@ visibility_private_component( "//src/ai/backend/agent/**", "//src/ai/backend/client/**", "//src/ai/backend/cli/**", + "//src/ai/backend/plugin/**", "//src/ai/backend/storage/**", "//src/ai/backend/web/**", "//src/ai/backend/accelerator/**", diff --git a/src/ai/backend/plugin/BUILD b/src/ai/backend/plugin/BUILD index f375baf14e..13754f14cf 100644 --- a/src/ai/backend/plugin/BUILD +++ b/src/ai/backend/plugin/BUILD @@ -20,6 +20,7 @@ visibility_private_component( ], allowed_dependencies=[ "//src/ai/backend/cli/**", + "//src/ai/backend/common/**", ], ) diff --git a/src/ai/backend/plugin/cli.py b/src/ai/backend/plugin/cli.py index 63e3dc04a5..a91ea9d752 100644 --- a/src/ai/backend/plugin/cli.py +++ b/src/ai/backend/plugin/cli.py @@ -1,29 +1,70 @@ +from __future__ import annotations + import enum import itertools import json +import logging from collections import defaultdict +from typing import Self import click import colorama import tabulate from colorama import Fore, Style +from ai.backend.common.logging import AbstractLogger, LocalLogger +from ai.backend.common.types import LogSeverity + from .entrypoint import ( + prepare_wheelhouse, scan_entrypoint_from_buildscript, scan_entrypoint_from_package_metadata, scan_entrypoint_from_plugin_checkouts, ) +log = logging.getLogger(__spec__.name) # type: ignore[name-defined] + class FormatOptions(enum.StrEnum): CONSOLE = "console" JSON = "json" +class CLIContext: + _logger: AbstractLogger + + def __init__(self, log_level: LogSeverity) -> None: + self.log_level = log_level + + def __enter__(self) -> Self: + self._logger = LocalLogger({ + "level": self.log_level, + "pkg-ns": { + "": LogSeverity.WARNING, + "ai.backend": self.log_level, + }, + }) + self._logger.__enter__() + return self + + def __exit__(self, *exc_info) -> None: + self._logger.__exit__() + + @click.group() -def main(): +@click.option( + "--debug", + is_flag=True, + help="Set the logging level to DEBUG", +) +@click.pass_context +def main( + ctx: click.Context, + debug: bool, +) -> None: """The root entrypoint for unified CLI of the plugin subsystem""" - pass + log_level = LogSeverity.DEBUG if debug else LogSeverity.INFO + ctx.obj = ctx.with_resource(CLIContext(log_level)) @main.command() @@ -35,17 +76,23 @@ def main(): show_default=True, help="Set the output format.", ) -def scan(group_name: str, format: FormatOptions) -> None: - duplicate_count: dict[str, int] = defaultdict(int) +def scan( + group_name: str, + format: FormatOptions, +) -> None: + sources: dict[str, set[str]] = defaultdict(set) rows = [] + + prepare_wheelhouse() for source, entrypoint in itertools.chain( (("buildscript", item) for item in scan_entrypoint_from_buildscript(group_name)), (("plugin-checkout", item) for item in scan_entrypoint_from_plugin_checkouts(group_name)), (("python-package", item) for item in scan_entrypoint_from_package_metadata(group_name)), ): - duplicate_count[entrypoint.name] += 1 + sources[entrypoint.name].add(source) rows.append((source, entrypoint.name, entrypoint.module)) rows.sort(key=lambda row: (row[2], row[1], row[0])) + match format: case FormatOptions.CONSOLE: if not rows: @@ -53,6 +100,7 @@ def scan(group_name: str, format: FormatOptions) -> None: return colorama.init(autoreset=True) ITALIC = colorama.ansi.code_to_chars(3) + STRIKETHR = colorama.ansi.code_to_chars(9) src_style = { "buildscript": Fore.LIGHTYELLOW_EX, "plugin-checkout": Fore.LIGHTGREEN_EX, @@ -62,22 +110,44 @@ def scan(group_name: str, format: FormatOptions) -> None: f"{ITALIC}Source{Style.RESET_ALL}", f"{ITALIC}Name{Style.RESET_ALL}", f"{ITALIC}Module Path{Style.RESET_ALL}", + f"{ITALIC}Note{Style.RESET_ALL}", ) display_rows = [] - has_duplicate = False + duplicates = set() + warnings: dict[str, str] = dict() for source, name, module_path in rows: + note = "" name_style = Style.BRIGHT - if duplicate_count[name] > 1: - has_duplicate = True + has_plugin_checkout = "plugin-checkout" in sources[name] + duplication_threshold = 2 if has_plugin_checkout else 1 + if len(sources[name]) > duplication_threshold: + duplicates.add(name) name_style = Fore.RED + Style.BRIGHT + if source == "plugin-checkout": + name_style = Style.DIM + STRIKETHR + if "python-package" in sources[name]: + note = "Loaded via the python-package source" + else: + note = "Ignored when loading plugins unless installed as editable" display_rows.append(( f"{src_style[source]}{source}{Style.RESET_ALL}", f"{name_style}{name}{Style.RESET_ALL}", module_path, + note, )) print(tabulate.tabulate(display_rows, display_headers)) - if has_duplicate: - print(f"\n💥 {Fore.LIGHTRED_EX}Detected duplicated entrypoint(s)!{Style.RESET_ALL}") + for name, msg in warnings.items(): + print(msg) + if duplicates: + duplicate_list = ", ".join(duplicates) + print( + f"\n{Fore.LIGHTRED_EX}\u26a0 Detected duplicated entrypoint(s): {Style.BRIGHT}{duplicate_list}{Style.RESET_ALL}" + ) + if "accelerator" in group_name: + print( + f"{Fore.LIGHTRED_EX} You should check [agent].allow-compute-plugins in " + f"agent.toml to activate only one accelerator implementation for each name.{Style.RESET_ALL}" + ) case FormatOptions.JSON: output_rows = [] for source, name, module_path in rows: diff --git a/src/ai/backend/plugin/entrypoint.py b/src/ai/backend/plugin/entrypoint.py index 703b01ef48..326d91abf5 100644 --- a/src/ai/backend/plugin/entrypoint.py +++ b/src/ai/backend/plugin/entrypoint.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import ast import collections import configparser import itertools import logging import os +import sys +import zipfile from importlib.metadata import EntryPoint, entry_points from pathlib import Path from typing import Iterable, Iterator, Optional @@ -19,9 +23,10 @@ def scan_entrypoints( if blocklist is None: blocklist = set() existing_names: dict[str, EntryPoint] = {} + + prepare_wheelhouse() for entrypoint in itertools.chain( scan_entrypoint_from_buildscript(group_name), - scan_entrypoint_from_plugin_checkouts(group_name), scan_entrypoint_from_package_metadata(group_name), ): if allowlist is not None and not match_plugin_list(entrypoint.value, allowlist): @@ -62,6 +67,8 @@ def match_plugin_list(entry_path: str, plugin_list: set[str]) -> bool: def scan_entrypoint_from_package_metadata(group_name: str) -> Iterator[EntryPoint]: + log.debug("scan_entrypoint_from_package_metadata(%r)", group_name) + yield from entry_points().select(group=group_name) @@ -143,6 +150,20 @@ def scan_entrypoint_from_plugin_checkouts(group_name: str) -> Iterator[EntryPoin yield from entrypoints.values() +def prepare_wheelhouse(base_dir: Path | None = None) -> None: + if base_dir is None: + base_dir = Path.cwd() + for whl_path in (base_dir / "wheelhouse").glob("*.whl"): + extracted_path = whl_path.with_suffix("") # strip the extension + log.debug("prepare_wheelhouse(): loading %s", whl_path) + if not extracted_path.exists(): + with zipfile.ZipFile(whl_path, "r") as z: + z.extractall(extracted_path) + decoded_path = os.fsdecode(extracted_path) + if decoded_path not in sys.path: + sys.path.append(decoded_path) + + def find_build_root(path: Optional[Path] = None) -> Path: if env_build_root := os.environ.get("BACKEND_BUILD_ROOT", None): return Path(env_build_root) diff --git a/wheelhouse/.gitignore b/wheelhouse/.gitignore index 704d307510..ce6d54956c 100644 --- a/wheelhouse/.gitignore +++ b/wheelhouse/.gitignore @@ -1 +1,2 @@ -*.whl +/*.whl +/*/