Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Scanning plugin entrypoints of external packages #2377

Merged
merged 16 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2377.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement scanning plugin entrypoints of external packages
4 changes: 0 additions & 4 deletions py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,4 @@ if [ $? -ne 0 ]; then
fi
LOCKSET=${LOCKSET:-python-default/$PYTHON_VERSION}
source dist/export/python/virtualenvs/$LOCKSET/bin/activate
PYTHONPATH="${PYTHONPATH}"
for plugin_dir in $(ls -d plugins/*/ 2>/dev/null); do
PYTHONPATH="${plugin_dir}/src:${PYTHONPATH}"
done
PYTHONPATH="src:${PYTHONPATH}" exec python "$@"
2 changes: 1 addition & 1 deletion src/ai/backend/agent/docker/metadata/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def load_metadata_plugins(self):
plugin_ctx = MetadataPluginContext(root_ctx.etcd, root_ctx.local_config)
await plugin_ctx.init()
root_ctx.metadata_plugin_ctx = plugin_ctx
log.debug("Available plugins: {}", plugin_ctx.plugins)
log.debug("Available metadata plugins: {}", plugin_ctx.plugins)
for plugin_name, plugin_instance in plugin_ctx.plugins.items():
log.info("Loading metadata plugin: {0}", plugin_name)
subapp, global_middlewares, route_structure = await plugin_instance.create_app()
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ visibility_private_component(
"//src/ai/backend/agent/**",
"//src/ai/backend/client/**",
"//src/ai/backend/cli/**",
"//src/ai/backend/plugin/**",
"//src/ai/backend/storage/**",
"//src/ai/backend/web/**",
"//src/ai/backend/accelerator/**",
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/plugin/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ visibility_private_component(
],
allowed_dependencies=[
"//src/ai/backend/cli/**",
"//src/ai/backend/common/**",
],
)

Expand Down
90 changes: 80 additions & 10 deletions src/ai/backend/plugin/cli.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,70 @@
from __future__ import annotations

import enum
import itertools
import json
import logging
from collections import defaultdict
from typing import Self

import click
import colorama
import tabulate
from colorama import Fore, Style

from ai.backend.common.logging import AbstractLogger, LocalLogger
from ai.backend.common.types import LogSeverity

from .entrypoint import (
prepare_wheelhouse,
scan_entrypoint_from_buildscript,
scan_entrypoint_from_package_metadata,
scan_entrypoint_from_plugin_checkouts,
)

log = logging.getLogger(__spec__.name) # type: ignore[name-defined]


class FormatOptions(enum.StrEnum):
CONSOLE = "console"
JSON = "json"


class CLIContext:
_logger: AbstractLogger

def __init__(self, log_level: LogSeverity) -> None:
self.log_level = log_level

def __enter__(self) -> Self:
self._logger = LocalLogger({
"level": self.log_level,
"pkg-ns": {
"": LogSeverity.WARNING,
"ai.backend": self.log_level,
},
})
self._logger.__enter__()
return self

def __exit__(self, *exc_info) -> None:
self._logger.__exit__()


@click.group()
def main():
@click.option(
"--debug",
is_flag=True,
help="Set the logging level to DEBUG",
)
@click.pass_context
def main(
ctx: click.Context,
debug: bool,
) -> None:
"""The root entrypoint for unified CLI of the plugin subsystem"""
pass
log_level = LogSeverity.DEBUG if debug else LogSeverity.INFO
ctx.obj = ctx.with_resource(CLIContext(log_level))


@main.command()
Expand All @@ -35,24 +76,31 @@ def main():
show_default=True,
help="Set the output format.",
)
def scan(group_name: str, format: FormatOptions) -> None:
duplicate_count: dict[str, int] = defaultdict(int)
def scan(
group_name: str,
format: FormatOptions,
) -> None:
sources: dict[str, set[str]] = defaultdict(set)
rows = []

prepare_wheelhouse()
for source, entrypoint in itertools.chain(
(("buildscript", item) for item in scan_entrypoint_from_buildscript(group_name)),
(("plugin-checkout", item) for item in scan_entrypoint_from_plugin_checkouts(group_name)),
(("python-package", item) for item in scan_entrypoint_from_package_metadata(group_name)),
):
duplicate_count[entrypoint.name] += 1
sources[entrypoint.name].add(source)
rows.append((source, entrypoint.name, entrypoint.module))
rows.sort(key=lambda row: (row[2], row[1], row[0]))

match format:
case FormatOptions.CONSOLE:
if not rows:
print(f"No plugins found for the entrypoint {group_name!r}")
return
colorama.init(autoreset=True)
ITALIC = colorama.ansi.code_to_chars(3)
STRIKETHR = colorama.ansi.code_to_chars(9)
src_style = {
"buildscript": Fore.LIGHTYELLOW_EX,
"plugin-checkout": Fore.LIGHTGREEN_EX,
Expand All @@ -62,22 +110,44 @@ def scan(group_name: str, format: FormatOptions) -> None:
f"{ITALIC}Source{Style.RESET_ALL}",
f"{ITALIC}Name{Style.RESET_ALL}",
f"{ITALIC}Module Path{Style.RESET_ALL}",
f"{ITALIC}Note{Style.RESET_ALL}",
)
display_rows = []
has_duplicate = False
duplicates = set()
warnings: dict[str, str] = dict()
for source, name, module_path in rows:
note = ""
name_style = Style.BRIGHT
if duplicate_count[name] > 1:
has_duplicate = True
has_plugin_checkout = "plugin-checkout" in sources[name]
duplication_threshold = 2 if has_plugin_checkout else 1
if len(sources[name]) > duplication_threshold:
duplicates.add(name)
name_style = Fore.RED + Style.BRIGHT
if source == "plugin-checkout":
name_style = Style.DIM + STRIKETHR
if "python-package" in sources[name]:
note = "Loaded via the python-package source"
else:
note = "Ignored when loading plugins unless installed as editable"
display_rows.append((
f"{src_style[source]}{source}{Style.RESET_ALL}",
f"{name_style}{name}{Style.RESET_ALL}",
module_path,
note,
))
print(tabulate.tabulate(display_rows, display_headers))
if has_duplicate:
print(f"\n💥 {Fore.LIGHTRED_EX}Detected duplicated entrypoint(s)!{Style.RESET_ALL}")
for name, msg in warnings.items():
print(msg)
if duplicates:
duplicate_list = ", ".join(duplicates)
print(
f"\n{Fore.LIGHTRED_EX}\u26a0 Detected duplicated entrypoint(s): {Style.BRIGHT}{duplicate_list}{Style.RESET_ALL}"
)
if "accelerator" in group_name:
print(
f"{Fore.LIGHTRED_EX} You should check [agent].allow-compute-plugins in "
f"agent.toml to activate only one accelerator implementation for each name.{Style.RESET_ALL}"
)
case FormatOptions.JSON:
output_rows = []
for source, name, module_path in rows:
Expand Down
23 changes: 22 additions & 1 deletion src/ai/backend/plugin/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import ast
import collections
import configparser
import itertools
import logging
import os
import sys
import zipfile
from importlib.metadata import EntryPoint, entry_points
from pathlib import Path
from typing import Iterable, Iterator, Optional
Expand All @@ -19,9 +23,10 @@ def scan_entrypoints(
if blocklist is None:
blocklist = set()
existing_names: dict[str, EntryPoint] = {}

prepare_wheelhouse()
for entrypoint in itertools.chain(
scan_entrypoint_from_buildscript(group_name),
scan_entrypoint_from_plugin_checkouts(group_name),
scan_entrypoint_from_package_metadata(group_name),
):
if allowlist is not None and not match_plugin_list(entrypoint.value, allowlist):
Expand Down Expand Up @@ -62,6 +67,8 @@ def match_plugin_list(entry_path: str, plugin_list: set[str]) -> bool:


def scan_entrypoint_from_package_metadata(group_name: str) -> Iterator[EntryPoint]:
log.debug("scan_entrypoint_from_package_metadata(%r)", group_name)

yield from entry_points().select(group=group_name)


Expand Down Expand Up @@ -143,6 +150,20 @@ def scan_entrypoint_from_plugin_checkouts(group_name: str) -> Iterator[EntryPoin
yield from entrypoints.values()


def prepare_wheelhouse(base_dir: Path | None = None) -> None:
if base_dir is None:
base_dir = Path.cwd()
for whl_path in (base_dir / "wheelhouse").glob("*.whl"):
extracted_path = whl_path.with_suffix("") # strip the extension
log.debug("prepare_wheelhouse(): loading %s", whl_path)
if not extracted_path.exists():
with zipfile.ZipFile(whl_path, "r") as z:
z.extractall(extracted_path)
decoded_path = os.fsdecode(extracted_path)
if decoded_path not in sys.path:
sys.path.append(decoded_path)


def find_build_root(path: Optional[Path] = None) -> Path:
if env_build_root := os.environ.get("BACKEND_BUILD_ROOT", None):
return Path(env_build_root)
Expand Down
3 changes: 2 additions & 1 deletion wheelhouse/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
*.whl
/*.whl
/*/