Skip to content

Commit

Permalink
✨ Basic ConfigModel
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Dec 30, 2024
1 parent 8ef06cd commit 23cb1a7
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 32 deletions.
1 change: 1 addition & 0 deletions arclet/entari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from . import command as command
from . import scheduler as scheduler
from .config import BasicConfModel as BasicConfModel
from .config import load_config as load_config
from .core import Entari as Entari
from .event import MessageCreatedEvent as MessageCreatedEvent
Expand Down
17 changes: 8 additions & 9 deletions arclet/entari/builtins/auto_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
raise ImportError("Please install `watchfiles` first. Install with `pip install arclet-entari[reload]`")

from arclet.entari import add_service, declare_static, load_plugin, metadata, plugin_config, unload_plugin
from arclet.entari.config import EntariConfig
from arclet.entari.config import BasicConfModel, EntariConfig, field
from arclet.entari.event.config import ConfigReload
from arclet.entari.logger import log
from arclet.entari.plugin import find_plugin, find_plugin_by_file

declare_static()


class Config:
watch_dirs: list[str] = ["."]
class Config(BasicConfModel):
watch_dirs: list[Union[str, Path]] = field(default_factory=lambda: ["."])
watch_config: bool = False


Expand Down Expand Up @@ -190,11 +190,9 @@ async def launch(self, manager: Launart):
self.fail.clear()


conf = plugin_config()
watch_dirs = conf.get("watch_dirs", ["."])
watch_config = conf.get("watch_config", False)
conf = plugin_config(Config)

add_service(serv := Watcher(watch_dirs, watch_config))
add_service(serv := Watcher(conf.watch_dirs, conf.watch_config))


@es.on(ConfigReload)
Expand All @@ -203,6 +201,7 @@ def handle_config_reload(event: ConfigReload):
return
if event.key not in ("::auto_reload", "arclet.entari.builtins.auto_reload"):
return
serv.dirs = event.value.get("watch_dirs", ["."])
serv.is_watch_config = event.value.get("watch_config", False)
new_conf = event.plugin_config(Config)
serv.dirs = new_conf.watch_dirs
serv.is_watch_config = new_conf.watch_config
return True
32 changes: 15 additions & 17 deletions arclet/entari/builtins/help.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from dataclasses import field
from typing import Optional

from arclet.alconna import (
Expand All @@ -16,22 +17,19 @@
)
from tarina import lang

from arclet.entari import Session, command, metadata, plugin_config
from arclet.entari import BasicConfModel, Session, command, metadata, plugin_config

config = plugin_config()
help_command: str = config.get("help_command", "help")
help_alias: list[str] = config.get("help_alias", ["帮助", "命令帮助"])
help_all_alias: list[str] = config.get("help_all_alias", ["所有帮助", "所有命令帮助"])
page_size: Optional[int] = config.get("page_size", None)


class Config:
class Config(BasicConfModel):
help_command: str = "help"
help_alias: list[str] = ["帮助", "命令帮助"]
help_all_alias: list[str] = ["所有帮助", "所有命令帮助"]
help_alias: list[str] = field(default_factory=lambda: ["帮助", "命令帮助"])
help_all_alias: list[str] = field(default_factory=lambda: ["所有帮助", "所有命令帮助"])
page_size: Optional[int] = None


config = plugin_config(Config)


metadata(
"help",
["RF-Tar-Railt <[email protected]>"],
Expand All @@ -44,7 +42,7 @@ class Config:
ns.disable_builtin_options = {"shortcut"}

help_cmd = Alconna(
help_command,
config.help_command,
Args[
"query#选择某条命令的id或者名称查看具体帮助;/?",
str,
Expand All @@ -70,13 +68,13 @@ class Config:
meta=CommandMeta(
description="显示所有命令帮助",
usage="可以使用 --hide 参数来显示隐藏命令,使用 -P 参数来显示命令所属插件名称",
example=f"${help_command} 1",
example=f"${config.help_command} 1",
),
)

for alias in set(help_alias):
for alias in set(config.help_alias):
help_cmd.shortcut(alias, {"prefix": True, "fuzzy": False})
for alias in set(help_all_alias):
for alias in set(config.help_all_alias):
help_cmd.shortcut(alias, {"args": ["--hide"], "prefix": True, "fuzzy": False})


Expand Down Expand Up @@ -122,7 +120,7 @@ def help_cmd_handle(arp: Arparma, interactive: bool = False):
return f"{command_string}\n{footer}"
return slot.get_help()

if not page_size:
if not config.page_size:
header = lang.require("manager", "help_header")
command_string = "\n".join(
(
Expand All @@ -134,10 +132,10 @@ def help_cmd_handle(arp: Arparma, interactive: bool = False):
)
return f"{header}\n{command_string}\n{footer}"

max_page = len(cmds) // page_size + 1
max_page = len(cmds) // config.page_size + 1
if page < 1 or page > max_page:
page = 1
max_length = page_size
max_length = config.page_size
if interactive:
footer += "\n" + "输入 '<', 'a' 或 '>', 'd' 来翻页"

Expand Down
70 changes: 68 additions & 2 deletions arclet/entari/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import dataclass, fields, is_dataclass
from dataclasses import field as field
from inspect import Signature
import json
import os
from pathlib import Path
from typing import Any, Callable, ClassVar, TypedDict
from typing import Any, Callable, ClassVar, TypedDict, TypeVar, get_args, get_origin
from typing_extensions import dataclass_transform
import warnings

_available_dc_attrs = set(Signature.from_callable(dataclass).parameters.keys())


class BasicConfig(TypedDict, total=False):
network: list[dict[str, Any]]
Expand Down Expand Up @@ -124,3 +129,64 @@ def _updater(self: EntariConfig):


load_config = EntariConfig.load


_config_model_validators = {}

C = TypeVar("C")


def config_validator_register(base: type):
def wrapper(func: Callable[[dict[str, Any], type[C]], C]):
_config_model_validators[base] = func
return func

return wrapper


def config_model_validate(base: type[C], data: dict[str, Any]) -> C:
for b in base.__mro__[-2::-1]:
if b in _config_model_validators:
return _config_model_validators[b](data, base)
return base(**data)


@dataclass_transform(kw_only_default=True)
class BasicConfModel:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
dataclass(**{k: v for k, v in kwargs.items() if k in _available_dc_attrs})(cls)


@config_validator_register(BasicConfModel)
def _basic_config_validate(data: dict[str, Any], base: type[C]) -> C:
def _nested_validate(namespace: dict[str, Any], cls):
result = {}
for field_ in fields(cls):
if field_.name not in namespace:
continue
if is_dataclass(field_.type):
result[field_.name] = _nested_validate(namespace[field_.name], field_.type)
elif get_origin(field_.type) is list and is_dataclass(get_args(field_.type)[0]):
result[field_.name] = [_nested_validate(d, get_args(field_.type)[0]) for d in namespace[field_.name]]
elif get_origin(field_.type) is set and is_dataclass(get_args(field_.type)[0]):
result[field_.name] = {_nested_validate(d, get_args(field_.type)[0]) for d in namespace[field_.name]}
elif get_origin(field_.type) is dict and is_dataclass(get_args(field_.type)[1]):
result[field_.name] = {
k: _nested_validate(v, get_args(field_.type)[1]) for k, v in namespace[field_.name].items()
}
elif get_origin(field_.type) is tuple:
args = get_args(field_.type)
result[field_.name] = tuple(
_nested_validate(d, args[i]) if is_dataclass(args[i]) else d
for i, d in enumerate(namespace[field_.name])
)
else:
result[field_.name] = namespace[field_.name]
return cls(**result)

return _nested_validate(data, base)
17 changes: 16 additions & 1 deletion arclet/entari/event/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Optional, overload

from arclet.letoderea import make_event

from ..config import C, config_model_validate


@dataclass
@make_event(name="entari.event/config_reload")
Expand All @@ -13,3 +15,16 @@ class ConfigReload:
old: Optional[Any] = None

__result_type__: type[bool] = bool

@overload
def plugin_config(self) -> dict[str, Any]: ...

@overload
def plugin_config(self, model_type: type[C]) -> C: ...

def plugin_config(self, model_type: Optional[type[C]] = None):
if self.scope != "plugin":
raise ValueError("not a plugin config")
if model_type:
return config_model_validate(model_type, self.value)
return self.value
16 changes: 13 additions & 3 deletions arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from os import PathLike
from pathlib import Path
from typing import Any
from typing import Any, overload

from tarina import init_spec

from ..config import EntariConfig
from ..config import C, EntariConfig, config_model_validate
from ..logger import log
from .model import PluginMetadata as PluginMetadata
from .model import RegisterNotInPluginError
Expand Down Expand Up @@ -116,10 +116,20 @@ def metadata(data: PluginMetadata):
plugin._metadata = data # type: ignore


def plugin_config() -> dict[str, Any]:
@overload
def plugin_config() -> dict[str, Any]: ...


@overload
def plugin_config(model_type: type[C]) -> C: ...


def plugin_config(model_type: type[C] | None = None):
"""获取当前插件的配置"""
if not (plugin := _current_plugin.get(None)):
raise LookupError("no plugin context found")
if model_type:
return config_model_validate(model_type, plugin.config)
return plugin.config


Expand Down

0 comments on commit 23cb1a7

Please sign in to comment.