From 1f4dfa4a3998bd0696d88b79b95c35f2092f2068 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Tue, 13 Aug 2024 19:51:06 +0800 Subject: [PATCH] :bug: fix shortcut's dump & load --- src/arclet/alconna/_internal/_analyser.py | 16 +++++--- src/arclet/alconna/manager.py | 49 +++++++++++++++++------ src/arclet/alconna/typing.py | 22 ++++++++++ tests/core_test.py | 3 ++ 4 files changed, 73 insertions(+), 17 deletions(-) diff --git a/src/arclet/alconna/_internal/_analyser.py b/src/arclet/alconna/_internal/_analyser.py index 26661cca..44920860 100644 --- a/src/arclet/alconna/_internal/_analyser.py +++ b/src/arclet/alconna/_internal/_analyser.py @@ -288,7 +288,7 @@ def __repr__(self): def shortcut( self, argv: Argv[TDC], data: list[Any], short: Arparma | InnerShortcutArgs, reg: Match | None = None - ) -> Arparma[TDC]: + ) -> Arparma[TDC] | None: """处理被触发的快捷命令 Args: @@ -298,11 +298,13 @@ def shortcut( reg (Match | None): 可能的正则匹配结果 Returns: - Arparma[TDC]: Arparma 解析结果 + Arparma[TDC] | None: Arparma 解析结果 Raises: ParamsUnmatched: 若不允许快捷命令后随其他参数,则抛出此异常 """ + self.reset() + if isinstance(short, Arparma): return short @@ -326,7 +328,7 @@ def shortcut( argv.ndata = 0 argv.current_index = 0 argv.addon(data) - return self.process(argv) + return def process(self, argv: Argv[TDC]) -> Arparma[TDC]: """主体解析函数, 应针对各种情况进行解析 @@ -366,8 +368,12 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]: argv.context[SHORTCUT_ARGS] = short argv.context[SHORTCUT_REST] = rest argv.context[SHORTCUT_REGEX_MATCH] = mat - self.reset() - return self.shortcut(argv, rest, short, mat) + + if arp := self.shortcut(argv, rest, short, mat): + return arp + + self.header_result = self.header_handler(self.command_header, argv) + self.header_result.origin = _next except RuntimeError as e: exc = InvalidParam(lang.require("header", "error").format(target=argv.release(recover=True)[0])) diff --git a/src/arclet/alconna/manager.py b/src/arclet/alconna/manager.py index fe10bd00..31383bcb 100644 --- a/src/arclet/alconna/manager.py +++ b/src/arclet/alconna/manager.py @@ -10,6 +10,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Match, MutableSet, Union from weakref import WeakValueDictionary +from pathlib import Path from nepattern import TPattern from tarina import LRU, lang @@ -47,7 +48,6 @@ def max_count(self) -> int: __shortcuts: dict[str, tuple[dict[str, Union[Arparma, InnerShortcutArgs]], dict[str, Union[Arparma, InnerShortcutArgs]]]] def __init__(self): - self.cache_path = f"{__file__.replace('manager.py', '')}manager_cache.db" self.sign = "ALCONNA::" self.current_count = 0 @@ -71,24 +71,49 @@ def _del(): weakref.finalize(self, _del) - def load_cache(self) -> None: + def load_shortcuts(self, file: str | Path | None = None) -> None: """加载缓存""" + path = Path(file or (Path.cwd() / "shortcut.db")) with contextlib.suppress(FileNotFoundError, KeyError): - with shelve.open(self.cache_path) as db: - self.__shortcuts = dict(db["shortcuts"]) # type: ignore - - def dump_cache(self) -> None: + with shelve.open(path.resolve().as_posix()) as db: + data: dict[str, tuple[dict, dict]] = dict(db["shortcuts"]) # type: ignore + for cmd, shorts in data.items(): + _data = self.__shortcuts.setdefault(cmd, ({}, {})) + for key, short in shorts[0].items(): + if isinstance(short, dict): + _data[0][key] = InnerShortcutArgs.load(short) + else: + _data[0][key] = short + for key, short in shorts[1].items(): + if isinstance(short, dict): + _data[1][key] = InnerShortcutArgs.load(short) + else: + _data[1][key] = short + + load_cache = load_shortcuts + + def dump_shortcuts(self, file: str | Path | None = None) -> None: """保存缓存""" data = {} - for key, short in self.__shortcuts.items(): - if isinstance(short, dict): - data[key] = {k: v for k, v in short.items() if k != "wrapper"} - else: - data[key] = short - with shelve.open(self.cache_path) as db: + for cmd, shorts in self.__shortcuts.items(): + _data = data.setdefault(cmd, ({}, {})) + for key, short in shorts[0].items(): + if isinstance(short, InnerShortcutArgs): + _data[0][key] = short.dump() + else: + _data[0][key] = short + for key, short in shorts[1].items(): + if isinstance(short, InnerShortcutArgs): + _data[1][key] = short.dump() + else: + _data[1][key] = short + path = Path(file or (Path.cwd() / "shortcut.db")) + with shelve.open(path.resolve().as_posix()) as db: db["shortcuts"] = data data.clear() + dump_cache = dump_shortcuts + @property def get_loaded_namespaces(self): """获取所有命名空间 diff --git a/src/arclet/alconna/typing.py b/src/arclet/alconna/typing.py index a893dfe3..a1758fc0 100644 --- a/src/arclet/alconna/typing.py +++ b/src/arclet/alconna/typing.py @@ -104,6 +104,28 @@ def __init__( def __repr__(self): return f"ShortcutArgs({self.command!r}, args={self.args!r}, fuzzy={self.fuzzy}, prefix={self.prefix})" + def dump(self): + return { + "command": self.command, + "args": self.args, + "fuzzy": self.fuzzy, + "prefix": self.prefix, + "prefixes": self.prefixes, + "flags": self.flags, + } + + @classmethod + def load(cls, data: dict[str, Any]) -> InnerShortcutArgs: + return cls( + data["command"], + data.get("args"), + data.get("fuzzy", True), + data.get("prefix", False), + data.get("prefixes"), + data.get("wrapper"), + data.get("flags", 0), + ) + @runtime_checkable class DataCollection(Protocol[DataUnit]): diff --git a/tests/core_test.py b/tests/core_test.py index 58eabc38..38c2dc96 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -466,6 +466,7 @@ def test_shortcut(): # 构造体缩写传入;{i} 将被可能的正则匹配替换 alc16.shortcut(r"TEST(\d+)(.+)", {"args": ["{0}", "bar {1}"]}) res = alc16.parse("TEST123aa") + assert res.header_match.origin == "TEST123aa" assert res.matched is True assert res.foo == 123 assert res.baz == "aa" @@ -488,9 +489,11 @@ def test_shortcut(): alc16_1.shortcut("echo", command="exec print({%0})") alc16_1.shortcut("echo1", command="exec print(\\'{*\n}\\')") res5 = alc16_1.parse("echo 123") + assert res5.header_match.origin == "echo" assert res5.content == "print(123)" assert not alc16_1.parse("echo 123 456").matched res6 = alc16_1.parse(["echo1", "123", "456 789"]) + assert res6.header_match.origin == "echo1" assert res6.content == "print('123\n456\n789')" res7 = alc16_1.parse([123]) assert not res7.matched