From 0a97d5fb2e27244885e29eea2d180aafb3677fac Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt <3165388245@qq.com> Date: Sat, 25 Jun 2022 18:02:41 +0800 Subject: [PATCH] :tada: 1.0.0 --- README-EN.md | 6 +- README.md | 15 +- README.rst | 7 +- alconna.svg | 906 ++++++++++--------- benchmark.py | 72 ++ changelog.md | 24 + commander/broadcast.py | 142 --- commander/commander.py | 140 +++ commander/letoderea.py | 84 -- commander/test_commander_bcc.py | 34 - commander/test_commander_leto.py | 30 - dev_tools/benchmark.py | 50 - dev_tools/test_alconna_1.py | 211 ----- dev_tools/test_alconna_2.py | 124 --- dev_tools/test_alconna_decorate.py | 62 -- dev_tools/test_alconna_fire.py | 77 -- dev_tools/test_args.py | 69 -- dev_tools/test_behavior.py | 38 - dev_tools/test_duplication.py | 28 - dev_tools/test_formatter.py | 25 - dev_tools/test_manager.py | 5 - pyproject.toml | 4 +- requirements.txt | 2 +- requirements_text.txt | 1 + setup.py | 10 +- src/arclet/alconna/__init__.py | 13 +- src/arclet/alconna/analysis/analyser.py | 269 +++--- src/arclet/alconna/analysis/base.py | 56 +- src/arclet/alconna/analysis/parts.py | 285 +++--- src/arclet/alconna/analysis/special.py | 31 +- src/arclet/alconna/arpamar.py | 36 +- src/arclet/alconna/base.py | 178 ++-- src/arclet/alconna/builtin/actions.py | 26 +- src/arclet/alconna/builtin/analyser.py | 65 +- src/arclet/alconna/builtin/checker.py | 37 + src/arclet/alconna/builtin/construct.py | 248 ++--- src/arclet/alconna/builtin/formatter.py | 121 ++- src/arclet/alconna/builtin/pattern.py | 169 +--- src/arclet/alconna/components/action.py | 25 +- src/arclet/alconna/components/duplication.py | 14 +- src/arclet/alconna/components/output.py | 160 ++-- src/arclet/alconna/components/stub.py | 6 +- src/arclet/alconna/{lang.py => config.py} | 27 +- src/arclet/alconna/core.py | 124 ++- src/arclet/alconna/exceptions.py | 4 +- src/arclet/alconna/manager.py | 138 ++- src/arclet/alconna/typing.py | 420 ++++----- src/arclet/alconna/util.py | 139 +-- test_alconna/analyser_test.py | 70 ++ test_alconna/args_test.py | 154 ++++ test_alconna/base_test.py | 74 ++ test_alconna/components_test.py | 77 ++ test_alconna/config_test.py | 23 + test_alconna/construct_test.py | 123 +++ test_alconna/core_test.py | 275 ++++++ test_alconna/entry_test.py | 13 + test_alconna/type_test.py | 119 +++ test_alconna/util_test.py | 55 ++ 58 files changed, 2861 insertions(+), 2879 deletions(-) create mode 100644 benchmark.py delete mode 100644 commander/broadcast.py create mode 100644 commander/commander.py delete mode 100644 commander/letoderea.py delete mode 100644 commander/test_commander_bcc.py delete mode 100644 commander/test_commander_leto.py delete mode 100644 dev_tools/benchmark.py delete mode 100644 dev_tools/test_alconna_1.py delete mode 100644 dev_tools/test_alconna_2.py delete mode 100644 dev_tools/test_alconna_decorate.py delete mode 100644 dev_tools/test_alconna_fire.py delete mode 100644 dev_tools/test_args.py delete mode 100644 dev_tools/test_behavior.py delete mode 100644 dev_tools/test_duplication.py delete mode 100644 dev_tools/test_formatter.py delete mode 100644 dev_tools/test_manager.py create mode 100644 requirements_text.txt create mode 100644 src/arclet/alconna/builtin/checker.py rename src/arclet/alconna/{lang.py => config.py} (73%) create mode 100644 test_alconna/analyser_test.py create mode 100644 test_alconna/args_test.py create mode 100644 test_alconna/base_test.py create mode 100644 test_alconna/components_test.py create mode 100644 test_alconna/config_test.py create mode 100644 test_alconna/construct_test.py create mode 100644 test_alconna/core_test.py create mode 100644 test_alconna/entry_test.py create mode 100644 test_alconna/type_test.py create mode 100644 test_alconna/util_test.py diff --git a/README-EN.md b/README-EN.md index fe31f6e9..6c9ccde1 100644 --- a/README-EN.md +++ b/README-EN.md @@ -46,12 +46,12 @@ cmd = Alconna( ) result = cmd.parse("/pip install cesloi --upgrade") # This method returns an 'Arpamar' class instance. -print(result.get('install')) # Or result.install +print(result.query('install')) # Or result.install ``` Output as follows: ``` -{'pak_name': 'cesloi', 'upgrade': Ellipsis} +{'value': None, 'args': {'pak_name': 'cesloi'}, 'options': {'upgrade': Ellipsis}} ``` @@ -83,7 +83,7 @@ QQ Group: [Link](https://jq.qq.com/?_wv=1027&k=PUPOnCSH) ## Features -* High Performance. On i5-10210U, performance is about `41000~101000 msg/s`; test script: [benchmark](dev_tools/benchmark.py) +* High Performance. On i5-10210U, performance is about `41000~101000 msg/s`; test script: [benchmark](benchmark.py) * Simple and Flexible Constructor * Powerful Automatic Type Parse and Conversion * Support Synchronous and Asynchronous Actions diff --git a/README.md b/README.md index 7bd518ef..34494bc3 100644 --- a/README.md +++ b/README.md @@ -17,17 +17,22 @@ ## 关于 -`Alconna` 隶属于 `ArcletProject`, 是 `CommandAnalysis` 的重构版,是一个简单、灵活、高效的命令参数解析器, 并不局限于解析字符串。 +`Alconna` 隶属于 `ArcletProject`, 是一个简单、灵活、高效的命令参数解析器, 并且不局限于解析命令式字符串。 `Alconna` 拥有复杂的解析功能与命令组件,但 一般情况下请当作~~奇妙~~简易的消息链解析器/命令解析器(雾) ## 安装 pip -``` +```bash pip install --upgrade arclet-alconna ``` +完整安装 +```bash +pip install --upgrade arclet-alconna[full] +``` + ## 文档 文档链接: [👉指路](https://arcletproject.github.io/docs/alconna/tutorial) @@ -48,11 +53,11 @@ cmd = Alconna( ) result = cmd.parse("/pip install cesloi --upgrade") # 该方法返回一个Arpamar类的实例 -print(result.get('install')) # 或者 result.install +print(result.query('install')) # 或者 result.install ``` 其结果为 ``` -{'pak_name': 'cesloi', 'upgrade': Ellipsis} +{'value': None, 'args': {'pak_name': 'cesloi'}, 'options': {'upgrade': Ellipsis}} ``` ### 搭配响应函数 @@ -82,7 +87,7 @@ QQ 交流群: [链接](https://jq.qq.com/?_wv=1027&k=PUPOnCSH) ## 特点 -* 高效. 在 i5-10210U 处理器上, 性能大约为 `41000~101000 msg/s`; 测试脚本: [benchmark](dev_tools/benchmark.py) +* 高效. 在 i5-10210U 处理器上, 性能大约为 `41000~101000 msg/s`; 测试脚本: [benchmark](benchmark.py) * 精简、多样的构造方法 * 强大的类型解析与转换功能 * 可传入同步与异步的 action 函数 diff --git a/README.rst b/README.rst index c291587a..2c384689 100644 --- a/README.rst +++ b/README.rst @@ -7,8 +7,7 @@ **English**: `README `__ -``Alconna`` 隶属于 ``ArcletProject``, 是 ``CommandAnalysis`` -的高级版, 支持解析消息链或者其他原始消息数据 +``Alconna`` 隶属于 ``ArcletProject``, 是一个简单、灵活、高效的命令参数解析器, 并且不局限于解析命令式字符串。 ``Alconna`` 拥有复杂的解析功能与命令组件,但 一般情况下请当作\ [STRIKEOUT:奇妙]\ 简易的消息链解析器/命令解析器(雾) @@ -46,13 +45,13 @@ pip ) result = cmd.parse("/pip install cesloi --upgrade") # 该方法返回一个Arpamar类的实例 - print(result.get('install')) # 或者 result.install + print(result.query('install')) # 或者 result.install 其结果为 :: - {'pak_name': 'cesloi', 'upgrade': Ellipsis} + {'value': None, 'args': {'pak_name': 'cesloi'}, 'options': {'upgrade': Ellipsis}} 讨论 ---- diff --git a/alconna.svg b/alconna.svg index 461152d2..c5e65277 100644 --- a/alconna.svg +++ b/alconna.svg @@ -4,805 +4,823 @@ - - + + G - + alconna - -alconna + +alconna alconna_analysis - -alconna. -analysis + +alconna. +analysis alconna_analysis->alconna - - + + - + alconna_core - -alconna.core + +alconna.core alconna_analysis->alconna_core - - + + - + alconna_manager - -alconna. -manager + +alconna. +manager alconna_analysis->alconna_manager - - + + + alconna_analysis_analyser - -alconna. -analysis. -analyser + +alconna. +analysis. +analyser alconna_analysis_base - -alconna. -analysis. -base + +alconna. +analysis. +base alconna_analysis_analyser->alconna_analysis_base - - - - - + + + alconna_analysis_parts - -alconna. -analysis. -parts + +alconna. +analysis. +parts alconna_analysis_analyser->alconna_analysis_parts - + + alconna_analysis_special - -alconna. -analysis. -special + +alconna. +analysis. +special alconna_analysis_analyser->alconna_analysis_special - + + alconna_analysis_analyser->alconna_manager - - - + + + alconna_analysis_base->alconna - + + - + +alconna_analysis_base->alconna_manager + + + + + alconna_analysis_parts->alconna_analysis_base - + + - + alconna_analysis_parts->alconna_analysis_special - - - + + alconna_arpamar - -alconna. -arpamar + +alconna. +arpamar - + alconna_arpamar->alconna - + - + alconna_arpamar->alconna_analysis_analyser - + - + alconna_arpamar->alconna_analysis_base - - - + + - + alconna_components_action - -alconna. -components. -action + +alconna. +components. +action - + alconna_arpamar->alconna_components_action - - + - + alconna_components_behavior - -alconna. -components. -behavior + +alconna. +components. +behavior - + alconna_arpamar->alconna_components_behavior - - - + + + - + alconna_components_duplication - -alconna. -components. -duplication + +alconna. +components. +duplication + + + +alconna_arpamar->alconna_components_duplication + + + - + alconna_arpamar->alconna_manager - - - + + + alconna_base - -alconna.base + +alconna.base - + alconna_base->alconna - - - + - + alconna_base->alconna_analysis_analyser - - + + + - + alconna_base->alconna_analysis_base - - + + + - + alconna_base->alconna_analysis_parts - - - + + - + alconna_base->alconna_analysis_special - - - + + + - + alconna_base->alconna_arpamar - - + + - + alconna_components_output - -alconna. -components. -output + +alconna. +components. +output - + alconna_base->alconna_components_output - + + + - + alconna_components_stub - -alconna. -components. -stub + +alconna. +components. +stub - + alconna_base->alconna_components_stub - - + + - + alconna_base->alconna_core - - + + alconna_builtin - -alconna. -builtin + +alconna. +builtin - + alconna_builtin->alconna - - + - + alconna_builtin->alconna_core - - + + alconna_builtin_actions - -alconna. -builtin. -actions + +alconna. +builtin. +actions - + alconna_builtin_actions->alconna - + - + alconna_builtin_construct - -alconna. -builtin. -construct + +alconna. +builtin. +construct - + alconna_builtin_actions->alconna_builtin_construct - - + + alconna_builtin_analyser - -alconna. -builtin. -analyser + +alconna. +builtin. +analyser - + alconna_builtin_analyser->alconna_core - - - + + + + + +alconna_builtin_checker + +alconna. +builtin. +checker - + alconna_builtin_construct->alconna - + - + alconna_builtin_formatter - -alconna. -builtin. -formatter + +alconna. +builtin. +formatter - + alconna_builtin_formatter->alconna - + + + + + - + alconna_builtin_formatter->alconna_core - + + - + alconna_builtin_pattern - -alconna. -builtin. -pattern + +alconna. +builtin. +pattern - + alconna_builtin_pattern->alconna - + - + alconna_components - -alconna. -components + +alconna. +components - + alconna_components->alconna - - + - + alconna_components->alconna_analysis_special - - - + + - + alconna_components->alconna_arpamar - - + - + alconna_components->alconna_base - - - + + + + - + alconna_components->alconna_core - - + + - + alconna_components_action->alconna_base - - - + + + - + alconna_components_action->alconna_components_output - - + + - + alconna_components_action->alconna_core - - + + - + +alconna_components_behavior->alconna + + + + + alconna_components_behavior->alconna_components_action - - - + + - + alconna_components_behavior->alconna_core - + + + - + alconna_components_duplication->alconna - - - - - - -alconna_components_duplication->alconna_arpamar - - - + - + alconna_components_output->alconna - + + - + alconna_components_output->alconna_analysis_special - + + - + alconna_components_stub->alconna - + + - + alconna_components_stub->alconna_components_duplication - - + + - - -alconna_core->alconna - - - - - -alconna_core->alconna_analysis_analyser - - - + + +alconna_config + +alconna.config - + -alconna_core->alconna_analysis_base - - - +alconna_config->alconna + + - + -alconna_core->alconna_arpamar - - - +alconna_config->alconna_analysis_analyser + + + - + -alconna_core->alconna_components_duplication - - - +alconna_config->alconna_analysis_parts + + + - + -alconna_core->alconna_components_output - - - +alconna_config->alconna_arpamar + + + - + -alconna_core->alconna_manager - - - +alconna_config->alconna_base + + - - -alconna_exceptions - -alconna. -exceptions - - + -alconna_exceptions->alconna - - +alconna_config->alconna_components_action + + + + - + -alconna_exceptions->alconna_analysis_analyser - - - +alconna_config->alconna_components_duplication + + - + -alconna_exceptions->alconna_analysis_parts - - +alconna_config->alconna_components_stub + + - + -alconna_exceptions->alconna_arpamar - - - +alconna_config->alconna_core + + + - + -alconna_exceptions->alconna_base - - - +alconna_config->alconna_manager + + - + + +alconna_typing + +alconna.typing + + -alconna_exceptions->alconna_components_action - - +alconna_config->alconna_typing + + - + -alconna_exceptions->alconna_manager - - - - -alconna_typing - -alconna.typing +alconna_core->alconna + + + + - + -alconna_exceptions->alconna_typing - - - - - -alconna_lang - -alconna.lang +alconna_core->alconna_analysis_analyser + + + - + -alconna_lang->alconna - +alconna_core->alconna_analysis_base + + + - + -alconna_lang->alconna_analysis_analyser - +alconna_core->alconna_arpamar + + + - + -alconna_lang->alconna_analysis_parts - - +alconna_core->alconna_components_duplication + + + - + -alconna_lang->alconna_arpamar - - - - +alconna_core->alconna_components_output + + + - + -alconna_lang->alconna_base - - - +alconna_core->alconna_manager + + + - + + +alconna_exceptions + +alconna. +exceptions + + -alconna_lang->alconna_components_action - - +alconna_exceptions->alconna + + - + -alconna_lang->alconna_components_duplication - - +alconna_exceptions->alconna_analysis_analyser + + + - + -alconna_lang->alconna_components_stub - - - +alconna_exceptions->alconna_analysis_parts + - + -alconna_lang->alconna_core - - +alconna_exceptions->alconna_arpamar + - + -alconna_lang->alconna_manager - +alconna_exceptions->alconna_base + + + + - + -alconna_lang->alconna_typing - - +alconna_exceptions->alconna_components_action + - + -alconna_manager->alconna - - - +alconna_exceptions->alconna_manager + + + - - -alconna_manager->alconna_analysis_parts - - + + +alconna_exceptions->alconna_typing + + - - -alconna_manager->alconna_components_action - - + + +alconna_manager->alconna + + alconna_typing->alconna - - - + + + alconna_typing->alconna_analysis_analyser - - + + alconna_typing->alconna_analysis_base - - + + alconna_typing->alconna_analysis_parts - + alconna_typing->alconna_arpamar - - - - - + + alconna_typing->alconna_base - + + alconna_typing->alconna_components_action - - - + + alconna_typing->alconna_components_stub - - + + + alconna_typing->alconna_core - - + + alconna_typing->alconna_manager - - - - + - + alconna_util - -alconna.util + +alconna.util alconna_util->alconna - - - - - - + alconna_util->alconna_analysis_analyser - - + + alconna_util->alconna_analysis_parts - - - + + - + +alconna_util->alconna_components_action + + + + + alconna_util->alconna_components_output - - - + + + - + alconna_util->alconna_manager - - + + - + alconna_util->alconna_typing - - + + - + typing_extensions - -typing_extensions + +typing_extensions + + + +typing_extensions->alconna_builtin_checker + + - + typing_extensions->alconna_typing - - + + diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 00000000..39968659 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,72 @@ +import time +from arclet.alconna import Alconna, Args, AnyOne, compile, command_manager, config +import cProfile +import pstats + + +class Plain: + type = "Plain" + text: str + + def __init__(self, t: str): + self.text = t + + +class At: + type = "At" + target: int + + def __init__(self, t: int): + self.target = t + + +alc = Alconna( + headers=["."], + command="test", + main_args=Args["bar", AnyOne] +) +compile_alc = compile(alc) + +msg = [Plain(".test"), At(124)] +count = 20000 + +config.enable_message_cache = True + +if __name__ == "__main__": + + sec = 0.0 + for _ in range(count): + st = time.time() + compile_alc.process_message(msg) + compile_alc.analyse() + ed = time.time() + sec += ed - st + print(f"Alconna: {count / sec:.2f}msg/s") + + print("RUN 2:") + li = [] + + pst = time.time() + for _ in range(count): + st = time.thread_time_ns() + compile_alc.process_message(msg) + compile_alc.analyse() + ed = time.thread_time_ns() + li.append(ed - st) + led = time.time() + + print(f"Alconna: {sum(li) / count} ns per loop with {count} loops") + + command_manager.records.clear() + + prof = cProfile.Profile() + prof.enable() + for _ in range(count): + compile_alc.process_message(msg) + compile_alc.analyse() + prof.create_stats() + + stats = pstats.Stats(prof) + stats.strip_dirs() + stats.sort_stats('tottime') + stats.print_stats(20) diff --git a/changelog.md b/changelog.md index ea37307b..f58fe839 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,27 @@ +# Alconna 1.0.x: + +\^o^/ + +## Alconna 1.0.0: +1. 将`lang`迁移到新增的`config`中,并为`config`加入了如全局分隔、开启缓存等选项 +2. 压缩代码量并规范化 +3. `--help` 选项允许在命令任何部位生效, 并且会根据当前命令选择是否展示选项的帮助信息 +4. `Args` name 的flag附加现在不需要以`|`分隔 +5. `Args` name 允许用`#...`为单个Arg提供注释, 其会展示在帮助信息内 +6. `Args` 允许传入 `Callable[[A], B]` 作为表达, 会自动解析输入类型与输出类型 +7. 完善了测试代码, 位于[测试文件夹](test_alconna)内, 通过[入口文件](test_alconna/entry_test.py)可执行全部测试 +8. 加入一个类似`beartype`的[`checker`](src/arclet/alconna/builtin/checker.py) +9. 命令头部允许使用非str类型, 即可以`Alconna(int)` +10. 解析器增加预处理器选项, 允许在分划数据单元前进行转化处理 +11. 性能提升, 理想情况最快约为 20w msg/s +12. 删除`Alconna.set_action` +13. 重构 `ObjectPattern` +14. 增加 `datetime`的 BasePattern, 支持传入时间戳或日期文字 +15. `Analyser` 的字段修改, `next_data` -> `popitem`, `reduce_data` -> `pushback` +16. `output_send` 合并到 `output_manager` +17. `Option` 添加参数`priority`, 仅在需要选项重载时安排优先级 +18. 修复bugs + # Alconna 0.9.x: ## Alconna 0.9.0 - 0.9.0.3: diff --git a/commander/broadcast.py b/commander/broadcast.py deleted file mode 100644 index 0bfc653e..00000000 --- a/commander/broadcast.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import List, Callable, Dict, Union, Type, Any -import re -from arclet.alconna import Alconna, Arpamar, Default, Option -from arclet.alconna.types import AnyStr, Bool -from arclet.alconna.util import split_once -from graia.broadcast.entities.decorator import Decorator -from graia.broadcast.interfaces.decorator import DecoratorInterface -from graia.broadcast.entities.dispatcher import BaseDispatcher -from graia.broadcast.interfaces.dispatcher import DispatcherInterface -from graia.broadcast.utilles import argument_signature - - -class Positional: - - def __init__( - self, - position: Union[str, int], - *, - type: Type = str, - default: Any = None - ): - self.position = position - self.type = type - self.default = default - - def convert(self): - alc_type = Bool if self.type == bool else AnyStr - if not self.default: - return alc_type - return Default(alc_type, self.default) - - -class AdditionParam(Decorator): - pre = True - - def __init__( - self, - params: List[str], - *, - type: Type = str, - default: Any = None, - return_pos: Union[str, int] = None - ): - self.params = params - self.type = type - self.default = default - self.return_pos = return_pos - - def convert(self, key: str): - opt_list = [] - alc_type = Bool if self.type == bool else AnyStr - for param in self.params: - if re.match(r"^.+{(.+)}", param): - if self.default: - opt_list.append(Option(split_once(param, " ")[0], **{key: Default(alc_type, self.default)})) - else: - opt_list.append(Option(split_once(param, " ")[0], **{key: alc_type})) - else: - opt_list.append(Option(param)) - return opt_list - - async def target(self, interface: DecoratorInterface): - if interface.name in interface.local_storage: - return interface.local_storage.get(interface.name) - elif self.default is not None: - return self.default - - -class BaseCommand: - name: str - alconna: Alconna - callable_func: Callable - result: Arpamar - - def __init__(self, alc: Alconna, func: Callable): - self.alconna = alc - self.name = alc.command - self.callable_func = func - - def __repr__(self): - return f'' - - def __eq__(self, other: "BaseCommand"): - return self.alconna == other.alconna - - def exec(self, msg: str) -> Dict: - self.result = self.alconna.analyse_message(msg) - return self.result.option_args - - -class AlconnaCommander: - command_list: List[BaseCommand] - - def __init__(self, broadcast): - self.broadcast = broadcast - self.command_list = [] - - def command(self, format_string: str): - def wrapper(func): - params = argument_signature(func) - format_args, reflect_map, option_list = self.param_handler(params) - bc = BaseCommand(Alconna.format(format_string, format_args, reflect_map), func) - if option_list: - bc.alconna.add_options(option_list) - self.command_list.append(bc) - return func - - return wrapper - - @staticmethod - def param_handler(param): - result_dict = {} - reflect_map = {} - option_list = [] - for name, _, default in param: - if isinstance(default, Positional): - index, args = default.position, default.convert() - result_dict[str(index)] = args - reflect_map[str(index)] = name - elif isinstance(default, AdditionParam): - option_list.extend(default.convert(name)) - return result_dict, reflect_map, option_list - - @staticmethod - def dispatcher_generator(opt_args): - class _Dispatcher(BaseDispatcher): - async def catch(self, interface: DispatcherInterface): - interface.execution_contexts[-1].local_storage = opt_args - if interface.name in opt_args: - return opt_args.get(interface.name) - - return _Dispatcher() - - def post_message(self, msg): - for command in self.command_list: - if args := command.exec(msg): - self.broadcast.loop.create_task( - self.broadcast.Executor( - command.callable_func, - dispatchers=[self.dispatcher_generator(args)] - ) - ) diff --git a/commander/commander.py b/commander/commander.py new file mode 100644 index 00000000..1ab2e441 --- /dev/null +++ b/commander/commander.py @@ -0,0 +1,140 @@ +from arclet.alconna import ( + Alconna, + AlconnaString, + output_manager, + command_manager, + split_once, + config +) +from arclet.letoderea.entities.subscriber import Subscriber +from arclet.letoderea.handler import await_exec_target +from typing import Callable, Dict, Type, Optional +from arclet.edoves.main.interact.module import BaseModule, ModuleMetaComponent, Component +from arclet.edoves.main.typings import TProtocol +from arclet.edoves.main.utilles.security import EDOVES_DEFAULT + +from arclet.edoves.builtin.event.message import MessageReceived +from arclet.edoves.builtin.medium import Message + + +class CommandParser: + command: Alconna + param_reaction: Callable + + def __init__(self, alconna: Alconna, func: Callable): + self.command = alconna + self.param_reaction = Subscriber(func) + + async def exec(self, params): + await await_exec_target(self.param_reaction, MessageReceived.param_export(**params)) + + +class CommanderData(ModuleMetaComponent): + verify_code: str = EDOVES_DEFAULT + identifier = "edoves.builtin.commander" + name = "Builtin Commander Module" + description = "Based on Edoves and Arclet-Alconna" + usage = """\n@commander.command("test ")\ndef test(foo: str):\n\t...""" + command_namespace: str + max_command_length: int = 10 + + +class CommandParsers(Component): + io: "Commander" + parsers: Dict[str, CommandParser] + + def __init__(self, io: "Commander"): + super(CommandParsers, self).__init__(io) + self.parsers = {} + + def command( + self, + command: str, + *option: str, + sep: str = " " + ): + alc = AlconnaString(command, *option, sep=sep) + + def __wrapper(func): + cmd = CommandParser(alc, func) + self.parsers.setdefault(alc.name, cmd) + return command + + return __wrapper + + def shortcut(self, shortcut: str, command: str): + name = split_once(command, " ")[0] + cmd = self.parsers.get(name) + if cmd is None: + return + cmd.command.shortcut(shortcut, command) + + def remove_handler(self, command: str): + del self.parsers[command] + + +class Commander(BaseModule): + prefab_metadata = CommanderData + command_parsers: CommandParsers + + __slots__ = ["command_parsers"] + + def __init__(self, protocol: TProtocol, namespace: Optional[str] = None): + super().__init__(protocol) + self.metadata.command_namespace = namespace or protocol.current_scene.scene_name + "_Commander" + self.command_parsers = CommandParsers(self) + if self.local_storage.get(self.__class__): + for k, v in self.local_storage[self.__class__].items(): + self.get_component(CommandParsers).parsers.setdefault(k, v) + config.set_loop(self.protocol.screen.edoves.loop) + + @self.behavior.add_handlers(MessageReceived) + async def command_message_handler(message: Message): + async def _action(doc: str): + await message.set(doc).send() + + for cmd, psr in self.command_parsers.parsers.items(): + output_manager.set_send_action(_action, psr.command.name) + result = psr.command.parse(message.content) + if result.matched: + await psr.exec( + { + **result.all_matched_args, + "result": result, + "message": message, + "sender": message.purveyor, + "edoves": self.protocol.screen.edoves, + "scene": self.protocol.current_scene + } + ) + break + + @self.command("help #显示帮助") + async def _(message: Message, page: int): + await message.set(command_manager.all_command_help( + self.metadata.command_namespace, + max_length=self.metadata.max_command_length, + page=page + )).send() + + def command( + __commander_self__, + command: str, + *option: str, + sep: str = " " + ): + alc = AlconnaString(command, *option, sep=sep).reset_namespace( + __commander_self__.metadata.command_namespace + ) + + def __wrapper(func): + cmd = CommandParser(alc, func) + try: + __commander_self__.command_parsers.parsers.setdefault(alc.headers[0], cmd) + except AttributeError: + if not __commander_self__.local_storage.get(__commander_self__.__class__): + __commander_self__.local_storage.setdefault(__commander_self__.__class__, {}) + __commander_self__.local_storage[__commander_self__.__class__].setdefault(alc.headers[0], cmd) + return command + + return __wrapper diff --git a/commander/letoderea.py b/commander/letoderea.py deleted file mode 100644 index f0775dcd..00000000 --- a/commander/letoderea.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import List, Callable, Dict, Union, Type, Any -from arclet.letoderea.utils import argument_analysis -from arclet.letoderea.handler import await_exec_target -from arclet.alconna import Alconna, Arpamar, Default -from arclet.alconna.types import AnyStr, Bool - - -class Positional: - - def __init__( - self, - position: Union[str, int], - *, - type: Type = str, - default: Any = None - ): - self.position = position - self.type = type - self.default = default - - def convert(self): - alc_type = Bool if self.type == bool else AnyStr - if not self.default: - return alc_type - return Default(alc_type, self.default) - - -class BaseCommand: - name: str - alconna: Alconna - callable_func: Callable - result: Arpamar - - def __init__(self, alc: Alconna, func: Callable): - self.alconna = alc - self.name = alc.command - self.callable_func = func - - def __repr__(self): - return f'' - - def __eq__(self, other: "BaseCommand"): - return self.alconna == other.alconna - - def exec(self, msg: str) -> Dict: - self.result = self.alconna.analyse_message(msg) - return self.result.option_args - - -class AlconnaCommander: - command_list: List[BaseCommand] - - def __init__(self, broadcast): - self.broadcast = broadcast - self.command_list = [] - - def command(self, format_string: str): - def wrapper(func): - params = argument_analysis(func) - format_args, reflect_map, option_list = self.param_handler(params) - bc = BaseCommand(Alconna.format(format_string, format_args, reflect_map), func) - if option_list: - bc.alconna.add_options(option_list) - self.command_list.append(bc) - return func - - return wrapper - - @staticmethod - def param_handler(param): - result_dict = {} - reflect_map = {} - option_list = [] - for name, _, default in param: - index, args = default.position, default.convert() - result_dict[str(index)] = args - reflect_map[str(index)] = name - return result_dict, reflect_map, option_list - - def post_message(self, msg): - for command in self.command_list: - if args := command.exec(msg): - self.broadcast.loop.create_task(await_exec_target(command.callable_func, args)) - diff --git a/commander/test_commander_bcc.py b/commander/test_commander_bcc.py deleted file mode 100644 index 61af900f..00000000 --- a/commander/test_commander_bcc.py +++ /dev/null @@ -1,34 +0,0 @@ -from commander.broadcast import AlconnaCommander, Positional, AdditionParam -from devtools import debug -from asyncio import sleep -from graia.broadcast import Broadcast - -bcc = Broadcast() -commander = AlconnaCommander(bcc) - - -@commander.command("lp user {0} permission set {1} {2}") -async def user_permission_set( - target: str = Positional(0, type=str), - perm_node: str = Positional(1, type=str), - perm_value: bool = Positional(2, type=bool, default=True), - param1: bool = AdditionParam(['-p1', '--param1'], type=bool, default=False), - param2: str = AdditionParam(['-p2 {0}', '--param2 {0}'], type=str, default="default") - -): - print("target", target) - print("perm_node", perm_node) - print("perm_value", perm_value) - print("param1", param1) - print("param2", param2) - - -debug(commander) - - -async def main(): - commander.post_message("lp user AAA permission set admin False -p2 a") - await sleep(0.1) - - -commander.broadcast.loop.run_until_complete(main()) diff --git a/commander/test_commander_leto.py b/commander/test_commander_leto.py deleted file mode 100644 index 6d83e8bd..00000000 --- a/commander/test_commander_leto.py +++ /dev/null @@ -1,30 +0,0 @@ -from commander.letoderea import AlconnaCommander, Positional -from devtools import debug -from asyncio import sleep -from arclet.letoderea import EventSystem - -es = EventSystem() -commander = AlconnaCommander(es) - - -@commander.command("lp user {0} permission set {1} {2}") -async def user_permission_set( - target: str = Positional(0, type=str), - perm_node: str = Positional(1, type=str), - perm_value: bool = Positional(2, type=bool, default=True), - -): - print("target", target) - print("perm_node", perm_node) - print("perm_value", perm_value) - - -debug(commander) - - -async def main(): - commander.post_message("lp user AAA permission set admin") - await sleep(0.1) - - -commander.broadcast.loop.run_until_complete(main()) diff --git a/dev_tools/benchmark.py b/dev_tools/benchmark.py deleted file mode 100644 index 53dfd395..00000000 --- a/dev_tools/benchmark.py +++ /dev/null @@ -1,50 +0,0 @@ -import time -from arclet.alconna import Alconna, Arpamar, Args, AnyOne, compile, command_manager -import cProfile -import pstats - - -class Plain: - type = "Plain" - text: str - - def __init__(self, t: str): - self.text = t - - -class At: - type = "At" - target: int - - def __init__(self, t: int): - self.target = t - - -ping = Alconna( - headers=["."], - command="test", - main_args=Args["bar", AnyOne] -) -s_ping = compile(ping) - -msg = [Plain(".test"), At(124)] -count = 10000 - -if __name__ == "__main__": - st = time.time() - - for _ in range(count): - s_ping.analyse(msg) - ed = time.time() - print(f"Alconna: {count / (ed - st):.2f}msg/s") - command_manager.records.clear() - prof = cProfile.Profile() - prof.enable() - for _ in range(count): - s_ping.analyse(msg) - prof.create_stats() - - stats = pstats.Stats(prof) - stats.strip_dirs() - stats.sort_stats('tottime') - stats.print_stats(20) diff --git a/dev_tools/test_alconna_1.py b/dev_tools/test_alconna_1.py deleted file mode 100644 index 752826b8..00000000 --- a/dev_tools/test_alconna_1.py +++ /dev/null @@ -1,211 +0,0 @@ -from typing import Union, Any - -from arclet.alconna import Alconna, Args, AlconnaString, Subcommand, Option -from graia.ariadne.message.chain import MessageChain -from arclet.alconna.builtin.formatter import ArgParserTextFormatter - -from graia.ariadne.message.element import At - -ar = Args["test", bool, True]["aaa", str, "bbb"] << Args["perm", str, ...] + ["month", int] -a = "bbb" -b = str -c = "fff" -ar1 = Args[a, b, c] -ar["foo"] = ["bar", ...] -print(ar) -print(ar1) - -ping = Alconna( - headers=["/", "!"], - command="ping", - options=[ - Subcommand( - "test", [Option("-u", Args["username", str], help_text="输入用户名")], args=Args["test", "Test"], - help_text="测试用例" - ), - Option("-n|--num", Args["count", int, 123], help_text="输入数字"), - Option("-u", Args(At=At), help_text="输入需要At的用户") - ], - main_args=Args["IP", "ip"], - help_text="简单的ping指令" -) -print(ping.get_help()) -msg = MessageChain.create("/ping -u", At(123), "test Test -u AAA -n 222 127.0.0.1") -print(msg) -print(ping.parse(msg)) - -msg1 = MessageChain.create("/ping 127.0.0.1 -u", At(123)) -print(msg1) -print(ping.parse(msg1).all_matched_args) - -msg2 = MessageChain.create("/ping a") -print(msg2) -result = ping.parse(msg2) -print(result.header) -print(result.head_matched) - -pip = Alconna( - command="/pip", - options=[ - Subcommand("install", [Option("--upgrade", help_text="升级包")], Args["pak", str], help_text="安装一个包"), - Option("--retries", Args["retries", int], help_text="设置尝试次数"), - Option("-t|--timeout", Args["sec", int], help_text="设置超时时间"), - Option("--exists-action", Args["action", str], help_text="添加行为"), - Option("--trusted-host", Args["host_name", "url"], help_text="选择可信赖地址") - ], - help_text="简单的pip指令", - formatter_type=ArgParserTextFormatter -) -print(pip.get_help()) -msg = "/pip install ces --upgrade -t 6 --trusted-host http://pypi.douban.com/simple" -print(msg) -print(pip.parse(msg).all_matched_args) - -aaa = Alconna(headers=[".", "!"], command="摸一摸", main_args=Args["At", At]) -msg = MessageChain.create(".摸一摸", At(123)) -print(msg) -print(aaa.parse(msg).matched) - -ccc = Alconna( - headers=[""], - command="4help", - main_args=Args["aaa", str], -) -msg = "4help 'what he say?'" -print(msg) -result = ccc.parse(msg) -print(result.main_args) - -eee = Alconna("RD{r:int}?=={e:int}") -msg = "RD100==36" -result = eee.parse(msg) -print(result.header) - -weather = Alconna( - headers=['渊白', 'cmd.', '/bot '], - command="{city}天气", - options=[ - Option("时间", "days:str").separate('='), - ], -) -msg = MessageChain.create('渊白桂林天气 时间=明天') -result = weather.parse(msg) -print(result) -print(result.header) - -msg = MessageChain.create('渊白桂林天气 aaa bbb') -result = weather.parse(msg) -print(result) - -msg = MessageChain.create(At(123)) -result = weather.parse(msg) -print(result) - -ddd = Alconna( - command="Cal", - options=[ - Subcommand( - "-div", - options=[ - Option( - "--round|-r", - args=Args(decimal=int), - action=lambda x: f"{x}a", - help_text="保留n位小数", - ) - ], - args=Args(num_a=int, num_b=int), - help_text="除法计算", - ) - ], -) - -msg = "Cal -div 12 23 --round 2" -print(msg) -print(ddd.get_help()) -result = ddd.parse(msg) -print(result.query('div')) - -ddd = Alconna( - "点歌" -).add_option( - "歌名", sep=":", args=Args(song_name=str) -).add_option( - "歌手", sep=":", args=Args(singer_name=str) -) -msg = "点歌 歌名:Freejia" -print(msg) -result = ddd.parse(msg, static=False) -print(result.all_matched_args) - -give = AlconnaString("give ") -print(give) -print(give.parse("give")) - - -def test_act(content): - print(content) - return content - - -wild = Alconna( - headers=[At(12345)], - command="丢漂流瓶", - main_args=Args["wild", Any], - action=test_act, - help_text="丢漂流瓶" -) -# print(wild.parse("丢漂流瓶 aaa bbb ccc").all_matched_args) -msg = MessageChain.create(At(12345), " 丢漂流瓶 aa\t\nvv") -print(wild.parse(msg)) - -get_ap = Alconna( - command="AP", - main_args=Args(type=str, text=str) -) - -test = Alconna( - command="test", - main_args=Args(t=int) -).reset_namespace("TEST") -print(test) -print(test.parse([get_ap.parse("AP Plain test"), get_ap.parse("AP At 123")])) - -# print(command_manager.commands) - -double_default = Alconna( - command="double", - main_args=Args(num=int).default(num=22), - options=[ - Option("--d", Args(num1=int).default(num1=22)) - ] -) - -result = double_default.parse("double --d") -print(result) - -choice = Alconna( - command="choice", - main_args=Args["part", ["a", "b", "c"]], - help_text="选择一个部分" -) -print(choice.parse("choice d")) -print(choice.get_help()) - -sub = Alconna( - command="test_sub_main", - main_args="baz:int", - options=[ - Subcommand( - "sub", - options=[Option("--subOption", Args["subOption", Union[At, int]])], - args=Args.foo[str] - ) - ] -) -print(sub.get_help()) -res = sub.parse("test_sub_main 1 sub --subOption 123 a") -print(res) -print(res.query('sub.foo')) - - diff --git a/dev_tools/test_alconna_2.py b/dev_tools/test_alconna_2.py deleted file mode 100644 index e173b8f1..00000000 --- a/dev_tools/test_alconna_2.py +++ /dev/null @@ -1,124 +0,0 @@ -from typing import Union, Dict - -from arclet.alconna import compile -from arclet.alconna.builtin.construct import AlconnaString, AlconnaFormat -from arclet.alconna.typing import pattern_gen -from arclet.alconna import Alconna, Args, Option -from arclet.alconna import command_manager -from graia.ariadne.message.element import At -from devtools import debug - -print(command_manager) - -ping = "Test" / AlconnaString("ping ") -ping1 = AlconnaString("ping ") - -Alconna.set_custom_types(digit=int) -alc = AlconnaFormat( - "lp user {target} perm set {perm} {default}", - {"target": str, "perm": str, "default": Args["de", bool, True]}, -) -alcc = AlconnaFormat( - "lp1 user {target}", - {"target": str} -) - -alcf = AlconnaFormat("music {artist} {title:str} singer {name:str}") -alcf.parse("music --help") -debug(alc) -alc.exception_in_time = False -debug(alc.parse("lp user AAA perm set admin")) - -aaa = AlconnaFormat("a {num}", {"num": int}) -r = aaa.parse("a 1") -print(aaa) -print(r) -print('\n') - - -def test(wild, text: str, num: int, boolean: bool = False): - print('wild:', wild) - print('text:', text) - print('num:', num) - print('boolean:', boolean) - - -alc1 = Alconna("test5", action=test) - -print(alc1) - - -@pattern_gen("test_type", r"(\[.*?])") -def test_type(text: str): - return eval(text) - - -alc2 = Alconna("test", help_text="测试help直接发送") + \ - Option("foo", Args["bar", str]["bar1", int, 12345]["bar2", test_type]) -alc2.parse("test --help") - -alc4 = Alconna( - command="test_multi", - options=[ - Option("--foo", Args["tags;S", str, 1]["str1", str]), - Option("--bar", Args["num", int]), - ] -) - -print(alc4.parse("test_multi --foo ab --bar 1")) -alc4.shortcut("st", "test_multi --foo ab --bar 1") -result = alc4.parse("st") -print(result) -print(result.query("foo")) - -alc5 = Alconna("test_anti", "path;A:int") -print(alc5.parse("test_anti a")) - -alc6 = Alconna("test_union", main_args=Args.path[[int, float, 'abc']]) -print(alc6.parse("test_union abc")) -print(alc6.parse(["test_union 123"])) - -alc7 = Alconna("test_list", main_args=Args.seq[list]) -print(alc7) -print(alc7.parse("test_list \"['1', '2', '3']\"")) - -alc8 = Alconna("test_dict", main_args=Args.map[Dict[str, int]]) -print(alc8) -print(alc8.parse("test_dict {'a':1,'b':2}")) - -alc9 = Alconna("test_str", main_args="foo;K:str, bar:list, baz;O:int") -print(alc9) -print(alc9.parse("test_str foo=a \"[1]\"")) - -alc10 = Alconna("test_bool", main_args="foo;H|O:str") -print(alc10.parse(["test_bool", 1])) -print(alc10.get_help()) - -alc11 = Alconna("test_header", headers=[(At(123456), "abc")]) -print("alc11:", alc11.parse([At(123456), "abctest_header"])) - -alc12 = Alconna("test_str1", Args["abcd"]["1234", "1234"]) -print("alc12:", alc12.parse("test_str1 abcd 1234")) - -alc13 = Alconna("image", Args["--width;O|K", int, 1920]["--height;O|K", int, 1080]) -print("alc13:", alc13.parse("image --height=720")) - -alc14 = Alconna(main_args="foo:str", headers=['!test_fuzzy'], is_fuzzy_match=True) -print(alc14.parse("test_fuzy foo bar")) - -alc15 = AlconnaString("my_string", "--foo [bar:bool]", "--bar &True") -print(alc15.parse("my_string --foo 123 --bar")) - -alc16 = Alconna( - "发涩图", - Args["min", r"(\d+)张"]["max;O", r"(\d+)张"] / "到", - options=[Option("从", Args["tags;3", str] / "和")], - action=lambda x, y: (int(x), int(y)) -) -s_alc15 = compile(alc15) -s_alc16 = compile(alc16) -print(s_alc16.analyse("发涩图 3张到5张 从 方舟和德能和拉德")) -print(s_alc15.analyse("发涩图 3张到5张 从 方舟和德能和拉德")) - -print(s_alc16.analyse("发涩图 3张到5张 从 方舟和德能和拉德")) -print(s_alc16.analyse("发涩图 3张到5张 从 德能和拉德")) \ No newline at end of file diff --git a/dev_tools/test_alconna_decorate.py b/dev_tools/test_alconna_decorate.py deleted file mode 100644 index e9ca804b..00000000 --- a/dev_tools/test_alconna_decorate.py +++ /dev/null @@ -1,62 +0,0 @@ -from arclet.alconna import Args -from arclet.alconna.builtin.construct import AlconnaDecorate -from graia.broadcast import Broadcast, DispatcherInterface -from graia.broadcast.entities.dispatcher import BaseDispatcher -from graia.broadcast.entities.exectarget import ExecTarget -from arclet.letoderea.handler import await_exec_target -from arclet.letoderea.entities.event import TemplateEvent -import asyncio - - -loop = asyncio.new_event_loop() - -c1 = AlconnaDecorate(loop=loop) -c2 = AlconnaDecorate(loop=loop) - - -@c1.set_default_parser -def _(func, args, l_args, loo): - class _D(BaseDispatcher): - async def catch(self, interface: DispatcherInterface): - if interface.name in args: - return args[interface.name] - if interface.name in l_args: - return l_args[interface.name] - if interface.default: - return interface.default - - target = ExecTarget(callable=func) - loo.run_until_complete(bcc.Executor(target, [_D()])) - - -bcc = Broadcast(loop=loop) - - -@c2.set_default_parser -def _(func, args, l_args, loo): - loo.run_until_complete( - await_exec_target( - func, - TemplateEvent.param_export(**{**args, **l_args}) - ) - ) - - -@c1.build_command() -@c1.option("--count", Args["num", int], help="Test Option Count") -@c1.option("--foo", Args["bar", str], help="Test Option Foo") -def hello(bar: str, num: int = 1): - """测试DOC""" - print(bar * num) - - -@c2.build_command("halo") -@c2.option("--count", Args["num", int], help="Test Option Count") -@c2.option("--foo", Args["bar", str], help="Test Option Foo") -def halo(bar: str, num: int = 1): - """测试DOC""" - print(bar * num) - - -if __name__ == "__main__": - hello("hello --foo John") diff --git a/dev_tools/test_alconna_fire.py b/dev_tools/test_alconna_fire.py deleted file mode 100644 index 19c1b02f..00000000 --- a/dev_tools/test_alconna_fire.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional - -from arclet.alconna import Args, Option, Subcommand - -from arclet.alconna.builtin.construct import AlconnaFire, delegate - - -class Test: - """测试从类中构建对象""" - - def __init__(self, sender: Optional[str]): - """Constructor""" - self.sender = sender - - def talk(self, name="world"): - """Test Function""" - print(f"Hello {name} from {self.sender}") - - class MySub: - """Sub Class""" - - def __init__(self): - """Constructor""" - self.sender = "sub_command" - - def set(self, name="hello"): - """Test Function""" - print(f"SUBCOMMAND {name} from {self.sender}") - - class SubConfig: - command = "subcmd" - - class Config: - headers = ["!"] - command = "test_fire" - get_subcommand = True - - -alc = AlconnaFire(Test) -alc.parse("!test_fire alc talk") -alc.parse("!test_fire --help") -alc.parse("!test_fire talk ALC subcmd set") - - -def test_function(name="world"): - """测试从函数中构建对象""" - print(f"Hello {name}!") - - -alc1 = AlconnaFire(test_function) -alc1.parse("test_function --help") - - -class Test1: - def __init__(self): - ... - - def calculator(self, a, b, c, *nums: int, **kwargs: str): - """calculator""" - print(a, b, c) - print(nums, kwargs) - print(sum(nums)) - - -alc = AlconnaFire(Test1) -alc.parse("Test1 --help") -alc.parse("Test1 calculator 1 2 3 4 5 d=4 f=5") - - -@delegate -class Test: - args = Args["foo", int]["bar", str] - opt1 = Option("--opt", alias=["-o"]) - sub1 = Subcommand("sub1", args=Args["baz", int]) - - -print(Test.parse("Test --opt sub1 1 123 abc")) diff --git a/dev_tools/test_args.py b/dev_tools/test_args.py deleted file mode 100644 index 1554d269..00000000 --- a/dev_tools/test_args.py +++ /dev/null @@ -1,69 +0,0 @@ -import time -from typing import Union - -from arclet.alconna import Args -from arclet.alconna.analysis.base import analyse_args - -print("\nArgs KVWord construct:") -arg = Args(pak=str, upgrade=bool).default(upgrade=True) -print("arg:", arg) -print(analyse_args(arg, "arclet-alconna True")) - -print("\nArgs Magic construct:") -arg1 = Args["round", float]["test", bool, True]["aaa", str] << Args["perm", str, ...] + ["month", int] -arg1["foo"] = ["bar", ...] -arg11: Args = Args.baz[int] -arg11.add_argument("foo", value=int, default=1) -print("arg1:", arg1) -print("arg11:", arg11) - -print("\nArgs Feature: Default value") -arg2 = Args["foo", int]["de", bool, True] -print("arg2:", arg2) -print(analyse_args(arg2, "123 False")) -print(analyse_args(arg2, "123")) - -print("\nArgs Feature: Choice") -arg3 = Args["choice", ("a", "b", "c")] -print("arg3:", arg3) -print(analyse_args(arg3, "a")) # OK -time.sleep(0.1) -print(analyse_args(arg3, "d")) # error - -print("\nArgs Feature: Multi") -arg4 = Args["multi;S", str] -print("arg4:", arg4) -print(analyse_args(arg4, "a b c d")) -arg44 = Args["kwargs;W", str] -print("arg44:", arg44) -print(analyse_args(arg44, "a=b c=d")) - -print("\nArgs Feature: Anti") -arg5 = Args["anti;A", r"(.+?)/(.+?)\.py"] -print("arg5:", arg5) -print(analyse_args(arg5, "a/b.mp3")) # OK -time.sleep(0.1) -print(analyse_args(arg5, "a/b.py")) # error - -print("\nArgs Feature: Union") -arg6 = Args["bar", Union[float, int]] -print("arg6:", arg6) -print(analyse_args(arg6, "1.2")) # OK -time.sleep(0.1) -print(analyse_args(arg6, "1")) # OK - -print("\nArgs Feature: Force") -arg7 = Args["bar;F", bool] -print("arg7:", arg7) -print(analyse_args(arg7, "True")) # error - -print("\nArgs Feature: Optional") -arg8 = Args["bar;O", int] -print("arg8:", arg8) -print(analyse_args(arg8, "abc")) # OK - -print("\nArgs Feature: KWord") -arg9 = Args["bar;K", int] -print("arg9:", arg9) -print(analyse_args(arg9, "bar=123")) # OK -print(analyse_args(arg9, "123")) # error diff --git a/dev_tools/test_behavior.py b/dev_tools/test_behavior.py deleted file mode 100644 index d3f846f2..00000000 --- a/dev_tools/test_behavior.py +++ /dev/null @@ -1,38 +0,0 @@ -from arclet.alconna import Alconna, Args, Option, Arpamar -from arclet.alconna.components.behavior import ArpamarBehavior -from arclet.alconna.builtin.actions import set_default, exclusion, cool_down -import time - -alc = Alconna("command", Args["bar", int]) + Option("foo") - - -@alc.behaviors.append -class Test(ArpamarBehavior): - requires = [set_default(321, option="foo")] - - @classmethod - def operate(cls, interface: "Arpamar"): - print(interface.query("options.foo.value")) - interface.matched = False - - -print(alc.parse(["command", "123"])) - -alc1 = Alconna( - "test_exclusion", - options=[ - Option("foo"), - Option("bar"), - ], - behaviors=[exclusion(target_path="options.foo", other_path="options.bar")] -) -print(alc1.parse("test_exclusion\nfoo")) - -alc2 = Alconna( - "test_cool_down", - main_args=Args["bar", int], - behaviors=[cool_down(0.3)] -) -for i in range(4): - time.sleep(0.2) - print(alc2.parse(f"test_cool_down {i}")) diff --git a/dev_tools/test_duplication.py b/dev_tools/test_duplication.py deleted file mode 100644 index 31599c02..00000000 --- a/dev_tools/test_duplication.py +++ /dev/null @@ -1,28 +0,0 @@ -from arclet.alconna import Alconna, Args, Option, Subcommand -from arclet.alconna.components.duplication import AlconnaDuplication -from arclet.alconna.components.stub import ArgsStub, OptionStub, SubcommandStub - - -class Demo(AlconnaDuplication): - testArgs: ArgsStub - bar: OptionStub - sub: SubcommandStub - - -alc = Alconna( - "test", - Args["foo", int], - options=[ - Option("--bar", Args["bar", str]), - Subcommand("sub", options=[Option("--sub1", Args["baz", str])]) - ] -) -result = alc.parse("test 123 --bar abc sub --sub1 xyz") -print(result) -duplication = alc.parse("test 123 --bar abc sub --sub1 xyz", duplication=Demo) -print(duplication) -print(duplication.testArgs.get('foo')) -print(duplication.bar.available) -print(duplication.bar.args[0]) -print(duplication.subcommand("sub").available) -print(duplication.subcommand("sub").option("sub1").args.first_arg) diff --git a/dev_tools/test_formatter.py b/dev_tools/test_formatter.py deleted file mode 100644 index 0fec8b13..00000000 --- a/dev_tools/test_formatter.py +++ /dev/null @@ -1,25 +0,0 @@ -from arclet.alconna import Alconna, Args, Option, Subcommand -from arclet.alconna.builtin.formatter import DefaultTextFormatter, ArgParserTextFormatter - - -alc = Alconna("test_line", main_args="line:'...'") -print(alc.parse("test_line\nfoo\nbar\n")) - -alc1 = Alconna( - command="test", - help_text="test_help", - options=[ - Option("test", Args.foo[str], help_text="test_option"), - Subcommand( - "sub", - options=[ - Option("suboption", Args.foo[str], help_text="sub_option"), - Option("suboption2", Args.foo[str], help_text="sub_option2"), - ] - ) - ] -) -b = DefaultTextFormatter(alc1) -c = ArgParserTextFormatter(alc1) -print(b.format_node()) -print(c.format_node()) diff --git a/dev_tools/test_manager.py b/dev_tools/test_manager.py deleted file mode 100644 index 2d242738..00000000 --- a/dev_tools/test_manager.py +++ /dev/null @@ -1,5 +0,0 @@ -from dev_tools.test_alconna_2 import * - -print("\n\n## ------------- Test Manager -------------## \n\n") -print(command_manager.all_command_help(max_length=6, page=3, pages="[{current}/{total}]")) -print("\n") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index de413418..a13afabb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,8 +5,8 @@ build-backend = "setuptools.build_meta" # Project metadata [project] name = "arclet-alconna" -version = "0.9.4" -description = "A Fast Command Analyser based on Dict" +version = "1.0.0" +description = "A High-performance, Generality, Humane Command Line Arguments Parser Library." readme = "README.md" requires_python = ">=3.8" license = "MIT" diff --git a/requirements.txt b/requirements.txt index a7cfe568..2126f50f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -typing_extensions~=4.1.1 \ No newline at end of file +typing_extensions~=4.2.0 \ No newline at end of file diff --git a/requirements_text.txt b/requirements_text.txt new file mode 100644 index 00000000..a984e155 --- /dev/null +++ b/requirements_text.txt @@ -0,0 +1 @@ +pytest~=7.1.2 \ No newline at end of file diff --git a/setup.py b/setup.py index 4d01acf6..67ebcf4b 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,10 @@ setuptools.setup( name="arclet-alconna", - version="0.9.4", + version="1.0.0", author="ArcletProject", author_email="rf_tar_railt@qq.com", - description="A Fast Command Analyser based on Dict", + description="A High-performance, Generality, Humane Command Line Arguments Parser Library.", license='MIT', long_description=long_description, long_description_content_type="text/rst", @@ -23,13 +23,17 @@ 'cli': [ 'arclet-alconna-cli' ], + 'full': [ + 'arclet-alconna-cli', 'arclet-alconna-graia' + ] }, classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Operating System :: OS Independent", ], include_package_data=True, diff --git a/src/arclet/alconna/__init__.py b/src/arclet/alconna/__init__.py index 33d469f9..5c66fc82 100644 --- a/src/arclet/alconna/__init__.py +++ b/src/arclet/alconna/__init__.py @@ -4,23 +4,24 @@ from .util import split_once, split, LruCache, Singleton from .base import CommandNode, Args, Option, Subcommand -from .typing import DataUnit, DataCollection, AnyOne, AllParam, Empty, PatternModel, set_converter, pattern_gen, Bind -from .exceptions import ParamsUnmatched, NullTextMessage, InvalidParam +from .typing import DataCollection, AnyOne, AllParam, Empty, PatternModel, set_converter, Bind +from .exceptions import ParamsUnmatched, NullMessage, InvalidParam from .analysis.base import compile, analyse from .core import Alconna from .arpamar import Arpamar from .manager import command_manager -from .lang import load_lang_file, lang_config +from .config import config, load_lang_file from .builtin.actions import store_value, set_default, exclusion, cool_down -from .builtin.construct import AlconnaDecorate, AlconnaFormat, AlconnaString, AlconnaFire, Argument +from .builtin.construct import AlconnaDecorate, AlconnaFormat, AlconnaString, AlconnaFire, Argument, delegate from .builtin.formatter import ArgParserTextFormatter, DefaultTextFormatter from .builtin.pattern import ObjectPattern -from .components.output import output_send, output_manager, AbstractTextFormatter +from .components.behavior import ArpamarBehavior +from .components.output import output_manager, AbstractTextFormatter from .components.duplication import AlconnaDuplication from .components.stub import ArgsStub, OptionStub, SubcommandStub -alconna_version = (0, 9, 4) +alconna_version = (1, 0, 0) if TYPE_CHECKING: from .builtin.actions import version diff --git a/src/arclet/alconna/analysis/analyser.py b/src/arclet/alconna/analysis/analyser.py index 88c1eea2..697fbe35 100644 --- a/src/arclet/alconna/analysis/analyser.py +++ b/src/arclet/alconna/analysis/analyser.py @@ -1,16 +1,17 @@ import re import traceback +from weakref import finalize +from copy import copy from abc import ABCMeta, abstractmethod -from typing import Dict, Union, List, Optional, TYPE_CHECKING, Tuple, Any, Pattern, Generic, TypeVar, \ - Set +from typing import Dict, Union, List, Optional, TYPE_CHECKING, Tuple, Any, Generic, TypeVar, Set, Callable from ..manager import command_manager -from ..exceptions import NullTextMessage -from ..base import Args, Option, Subcommand, Sentence +from ..exceptions import NullMessage +from ..base import Args, Option, Subcommand, Sentence, StrMounter from ..arpamar import Arpamar from ..util import split_once, split -from ..typing import DataUnit, DataCollection, pattern_map -from ..lang import lang_config +from ..typing import DataCollection, pattern_map, BasePattern, args_type_parser, TPattern +from ..config import config if TYPE_CHECKING: from ..core import Alconna @@ -19,27 +20,22 @@ class Analyser(Generic[T_Origin], metaclass=ABCMeta): - """ - Alconna使用的分析器基类, 实现了一些通用的方法 + """ Alconna使用的分析器基类, 实现了一些通用的方法 """ + preprocessors: Dict[str, Callable[..., Any]] = {} + text_sign: str = 'text' - Attributes: - current_index(int): 记录解析时当前数据的index - content_index(int): 记录内部index - head_matched: 是否匹配了命令头部 - """ alconna: 'Alconna' # Alconna实例 current_index: int # 当前数据的index content_index: int # 内部index is_str: bool # 是否是字符串 - raw_data: List[Union[Any, List[str]]] # 原始数据 + raw_data: List[Union[Any, StrMounter]] # 原始数据 ndata: int # 原始数据的长度 command_params: Dict[str, Union[Sentence, List[Option], Subcommand]] param_ids: Set[str] # 命令头部 command_header: Union[ - Pattern, - Tuple[Union[Tuple[List[Any], Pattern], List[Any]], Pattern], - List[Tuple[Any, Pattern]] + Union[TPattern, BasePattern], List[Tuple[Any, TPattern]], + Tuple[Union[Tuple[List[Any], TPattern], List[Any]], Union[TPattern, BasePattern]], ] separators: Set[str] # 分隔符 is_raise_exception: bool # 是否抛出异常 @@ -49,6 +45,7 @@ class Analyser(Generic[T_Origin], metaclass=ABCMeta): header: Optional[Union[Dict[str, Any], bool]] # 命令头部 need_main_args: bool # 是否需要主参数 head_matched: bool # 是否匹配了命令头部 + head_pos: Tuple[int, int] part_len: range # 分段长度 default_main_only: bool # 默认只有主参数 self_args: Args # 自身参数 @@ -58,14 +55,16 @@ class Analyser(Generic[T_Origin], metaclass=ABCMeta): temp_token: int # 临时token used_tokens: Set[int] # 已使用的token sentences: List[str] # 存放解析到的所有句子 + default_separate: bool def __init_subclass__(cls, **kwargs): if not hasattr(cls, "filter_out"): - raise TypeError(lang_config.analyser_filter_missing) + raise TypeError(config.lang.analyser_filter_missing) @staticmethod - def generate_token(data: List[Union[Any, List[str]]], hs=hash) -> int: - return hs(repr(data)) + def generate_token(data: List[Union[Any, List[str]]]) -> int: + # return hash(str(data)) + return hash(''.join(i.__str__() for i in data)) def __init__(self, alconna: "Alconna"): self.reset() @@ -77,10 +76,21 @@ def __init__(self, alconna: "Alconna"): self.is_raise_exception = alconna.is_raise_exception self.need_main_args = False self.default_main_only = False + self.default_separate = True + self.param_ids = set() + self.command_params = {} self.__handle_main_args__(alconna.args, alconna.nargs) self.__init_header__(alconna.command, alconna.headers) self.__init_actions__() + def _clr(a: 'Analyser'): + a.reset() + a.used_tokens.clear() + del a.origin_data + del a.alconna + + finalize(self, _clr, self) + def __handle_main_args__(self, main_args: Args, nargs: Optional[int] = None): nargs = nargs or len(main_args) if nargs > 0 and nargs > main_args.optional_count: @@ -91,15 +101,15 @@ def __handle_main_args__(self, main_args: Args, nargs: Optional[int] = None): def __init_header__( self, - command_name: str, - headers: Union[List[Union[str, DataUnit]], List[Tuple[DataUnit, str]]] + command_name: Union[str, type, BasePattern], + headers: Union[List[Union[str, Any]], List[Tuple[Any, str]]] ): - if len(parts := re.split("({.*?})", command_name)) > 1: + if isinstance(command_name, str) and len(parts := re.split(r"(\{.*?})", command_name)) > 1: for i, part in enumerate(parts): if not part: continue - if res := re.match("{(.*?)}", part): - _res = res.group(1) + if res := re.match(r"\{(.*?)}", part): + _res = res[1] if not _res: parts[i] = ".+?" continue @@ -113,14 +123,22 @@ def __init_header__( elif not _parts[1]: parts[i] = f"(?P<{_parts[0]}>.+?)" else: - parts[i] = f"(?P<{_parts[0]}>{pattern_map.get(_parts[1], _parts[1])})" + parts[i] = ( + f"(?P<{_parts[0]}>" + f"{pattern_map[_parts[1]].pattern if _parts[1] in pattern_map else _parts[1]})" + ) command_name = "".join(parts) + + if isinstance(command_name, str): + _command_name, _command_str = re.compile(command_name), command_name + else: + _command_name, _command_str = copy(args_type_parser(command_name)), str(command_name) if headers == [""]: - self.command_header = re.compile(command_name) + self.command_header = _command_name # type: ignore elif isinstance(headers[0], tuple): - mixins = [(h[0], re.compile(re.escape(h[1]) + command_name)) for h in headers] # type: ignore + mixins = [(h[0], re.compile(re.escape(h[1]) + _command_str)) for h in headers] # type: ignore self.command_header = mixins else: elements = [] @@ -131,11 +149,16 @@ def __init_header__( else: elements.append(h) if not elements: - self.command_header = re.compile(f"(?:{ch_text[:-1]})" + command_name) # noqa + if isinstance(_command_name, TPattern): + self.command_header = re.compile(f"(?:{ch_text[:-1]}){_command_str}") # noqa + else: + _command_name.pattern = f"(?:{ch_text[:-1]}){_command_name.pattern}" # type: ignore + _command_name.regex_pattern = re.compile(_command_name.pattern) # type: ignore + self.command_header = _command_name # type: ignore elif not ch_text: - self.command_header = (elements, re.compile(command_name)) + self.command_header = (elements, _command_name) # type: ignore else: - self.command_header = (elements, re.compile(f"(?:{ch_text[:-1]})")), re.compile(command_name) # noqa + self.command_header = (elements, re.compile(f"(?:{ch_text[:-1]})")), _command_name # type: ignore # noqa def __init_actions__(self): actions = self.alconna.action_list @@ -151,97 +174,78 @@ def __init_actions__(self): actions['subcommands'][f"{opt.dest}.{option.dest}"] = option.action @staticmethod - def default_params_generator(analyser: "Analyser"): - analyser.param_ids = set() - analyser.command_params = {} - analyser.part_len = range(len(analyser.alconna.options) + 1) + def default_params_compiler(analyser: "Analyser"): + analyser.part_len = range(len(analyser.alconna.options) + (1 if analyser.need_main_args else 0)) for opts in analyser.alconna.options: if isinstance(opts, Option): - if analyser.command_params.get(opts.name): - analyser.command_params[opts.name].append(opts) # type: ignore - else: - analyser.command_params[opts.name] = [opts] + for al in opts.aliases: + if (li := analyser.command_params.get(al)) and isinstance(li, list): + li.append(opts) # type: ignore + li.sort(key=lambda x: x.priority, reverse=True) + else: + analyser.command_params[al] = [opts] analyser.param_ids.update(opts.aliases) elif isinstance(opts, Subcommand): analyser.command_params[opts.name] = opts analyser.param_ids.add(opts.name) - opts.sub_part_len = range(len(opts.options) + 1) + opts.sub_part_len = range(len(opts.options) + (1 if opts.nargs else 0)) for sub_opts in opts.options: - if opts.sub_params.get(sub_opts.name): - opts.sub_params[sub_opts.name].append(sub_opts) # type: ignore - else: - opts.sub_params[sub_opts.name] = [sub_opts] + for al in sub_opts.aliases: + if (li := opts.sub_params.get(al)) and isinstance(li, list): + li.append(sub_opts) # type: ignore + li.sort(key=lambda x: x.priority, reverse=True) + else: + opts.sub_params[al] = [sub_opts] if sub_opts.requires: - opts.sub_params.update({k: Sentence(name=k) for k in sub_opts.requires}) + for k in sub_opts.requires: + opts.sub_params.setdefault(k, Sentence(name=k)) analyser.param_ids.update(sub_opts.aliases) + if not analyser.separators.issuperset(opts.separators): + analyser.default_separate &= False if opts.requires: analyser.param_ids.update(opts.requires) - analyser.command_params.update({k: Sentence(name=k) for k in opts.requires}) + for k in opts.requires: + analyser.command_params.setdefault(k, Sentence(name=k)) def __repr__(self): - return f"<{self.__class__.__name__}>" - - def __del__(self): - self.reset() + return f"<{self.__class__.__name__} of {self.alconna.path}>" def reset(self): """重置分析器""" - self.current_index = 0 - self.content_index = 0 - self.is_str = False - self.temporary_data = {} - self.raw_data = [] - self.head_matched = False - self.ndata = 0 - self.origin_data = None - self.temp_token = 0 - self.header = None - self.main_args = {} - self.options = {} - self.subcommands = {} - self.sentences = [] + self.current_index, self.content_index, self.ndata, self.temp_token = 0, 0, 0, 0 + self.is_str, self.head_matched = False, False + self.temporary_data, self.main_args, self.options, self.subcommands = {}, {}, {}, {} + self.raw_data, self.sentences = [], [] + self.origin_data, self.header = None, None + self.head_pos = (0, 0) - def next_data(self, separate: Optional[Set[str]] = None, pop: bool = True) -> Tuple[Union[str, Any], bool]: + def popitem(self, separate: Optional[Set[str]] = None, move: bool = True) -> Tuple[Union[str, Any], bool]: """获取解析需要的下个数据""" - if "separators" in self.temporary_data: - self.temporary_data.pop("separators", None) + self.temporary_data["separators"] = None if self.current_index == self.ndata: return "", True _current_data = self.raw_data[self.current_index] - if isinstance(_current_data, list): + if isinstance(_current_data, StrMounter): _rest_text: str = "" _text = _current_data[self.content_index] if separate and not self.separators.issuperset(separate): - _text, _rest_text = split_once(_text, separate) - if pop: + _text, _rest_text = split_once(_text, tuple(separate)) + if move: if _rest_text: # 这里实际上还是pop了 self.temporary_data["separators"] = separate - _current_data[self.content_index] = _rest_text # self.raw_data[self.current_index] + _current_data[self.content_index] = _rest_text else: self.content_index += 1 if len(_current_data) == self.content_index: self.current_index += 1 self.content_index = 0 return _text, True - if pop: + if move: self.current_index += 1 return _current_data, False - def rest_count(self, separate: Optional[Set[str]] = None) -> int: - """获取剩余的数据个数""" - _result = 0 - is_cur = False - for _data in self.raw_data[self.current_index:]: - if isinstance(_data, list): - for s in (_data[self.content_index:] if not is_cur else _data): - is_cur = True - _result += len(split(s, separate)) if separate and not self.separators.issuperset(separate) else 1 - else: - _result += 1 - return _result - - def reduce_data(self, data: Union[str, Any], replace=False): - """把pop的数据放回 (实际只是‘指针’移动)""" + def pushback(self, data: Union[str, Any], replace=False): + """把 pop的数据放回 (实际只是‘指针’移动)""" if not data: return if self.current_index == self.ndata: @@ -255,9 +259,9 @@ def reduce_data(self, data: Union[str, Any], replace=False): self.raw_data[self.current_index] = data else: _current_data = self.raw_data[self.current_index] - if isinstance(_current_data, list) and isinstance(data, str): + if isinstance(_current_data, StrMounter) and isinstance(data, str): if seps := self.temporary_data.get("separators", None): - _current_data[self.content_index] = f"{data}{seps.copy().pop()}{_current_data[self.content_index]}" + _current_data[self.content_index] = f"{data}{tuple(seps)[0]}{_current_data[self.content_index]}" else: self.content_index -= 1 if replace: @@ -267,71 +271,53 @@ def reduce_data(self, data: Union[str, Any], replace=False): if replace: self.raw_data[self.current_index] = data - def recover_raw_data(self) -> List[Union[str, Any]]: - """将处理过的命令数据大概还原""" + def release(self, separate: Optional[Set[str]] = None) -> List[Union[str, Any]]: _result = [] is_cur = False for _data in self.raw_data[self.current_index:]: - if isinstance(_data, list): - if not is_cur: - _result.append(f'{self.separators.copy().pop()}'.join(_data[self.content_index:])) + if isinstance(_data, StrMounter): + for s in _data[0 if is_cur else self.content_index:]: + _result.extend( + split(s, tuple(separate)) if separate and not self.separators.issuperset(separate) else [s]) is_cur = True - else: - _result.append(f'{self.separators.copy().pop()}'.join(_data)) else: _result.append(_data) - self.current_index = self.ndata - self.content_index = 0 return _result - def process_message(self, data: Union[str, DataCollection]) -> 'Analyser': + def process(self, data: DataCollection[Union[str, Any]]) -> 'Analyser': """命令分析功能, 传入字符串或消息链, 应当在失败时返回fail的arpamar""" self.origin_data = data if isinstance(data, str): self.is_str = True - if not (res := split(data.lstrip(), self.separators)): - exp = NullTextMessage(lang_config.analyser_handle_null_message.format(target=data)) - if self.is_raise_exception: - raise exp - self.temporary_data["fail"] = exp - else: - self.raw_data = [res] - self.ndata = 1 - self.temp_token = self.generate_token(self.raw_data) - else: - separates = self.separators - i, __t, exc = 0, False, None - raw_data = self.raw_data - for unit in data: # type: ignore - if text := getattr(unit, 'text', None): - if not (res := split(text.lstrip(), separates)): - continue - raw_data.append(res) - __t = True - elif isinstance(unit, str): - if not (res := split(unit.lstrip(), separates)): - continue - raw_data.append(res) - __t = True - elif unit.__class__.__name__ not in self.filter_out: - raw_data.append(unit) - else: + data = [data] + i, exc = 0, None + raw_data = self.raw_data + for unit in data: + if (uname := unit.__class__.__name__) in self.filter_out: + continue + if (proc := self.preprocessors.get(uname)) and (res := proc(unit)): + unit = res + if text := getattr(unit, self.text_sign, unit if isinstance(unit, str) else None): + if not (res := split(text.strip(), tuple(self.separators))): continue - i += 1 - if __t is False: - exp = NullTextMessage(lang_config.analyser_handle_null_message.format(target=data)) - if self.is_raise_exception: - raise exp - self.temporary_data["fail"] = exp + raw_data.append(StrMounter(res)) else: - self.ndata = i + raw_data.append(unit) + i += 1 + if i < 1: + exp = NullMessage(config.lang.analyser_handle_null_message.format(target=data)) + if self.is_raise_exception: + raise exp + self.temporary_data["fail"] = exp + else: + self.ndata = i + if config.enable_message_cache: self.temp_token = self.generate_token(raw_data) return self @abstractmethod - def analyse(self, message: Union[str, DataCollection, None] = None) -> Arpamar: + def analyse(self, message: Union[DataCollection[Union[str, Any]], None] = None) -> Arpamar: """主体解析函数, 应针对各种情况进行解析""" - pass @staticmethod def converter(command: str) -> T_Origin: @@ -341,15 +327,14 @@ def export(self, exception: Optional[BaseException] = None, fail: bool = False) """创建arpamar, 其一定是一次解析的最后部分""" result = Arpamar(self.alconna) result.head_matched = self.head_matched + result.matched = not fail if fail: - tb = traceback.format_exc(limit=1) - result.error_info = repr(exception or tb) - result.error_data = self.recover_raw_data() - result.matched = False + result.error_info = repr(exception or traceback.format_exc(limit=1)) + result.error_data = self.release() else: - result.matched = True result.encapsulate_result(self.header, self.main_args, self.options, self.subcommands) - command_manager.record(self.temp_token, self.origin_data, result) # type: ignore - self.used_tokens.add(self.temp_token) + if config.enable_message_cache: + command_manager.record(self.temp_token, self.origin_data, result) # type: ignore + self.used_tokens.add(self.temp_token) self.reset() return result diff --git a/src/arclet/alconna/analysis/base.py b/src/arclet/alconna/analysis/base.py index bafd975b..5e9af574 100644 --- a/src/arclet/alconna/analysis/base.py +++ b/src/arclet/alconna/analysis/base.py @@ -11,19 +11,17 @@ from ..core import Alconna -def compile(alconna: "Alconna", params_generator: Optional[Callable[[Analyser], None]] = None): +def compile(alconna: "Alconna", params_compiler: Optional[Callable[[Analyser], None]] = None): _analyser = alconna.analyser_type(alconna) - if params_generator: - params_generator(_analyser) + if params_compiler: + params_compiler(_analyser) else: - Analyser.default_params_generator(_analyser) + Analyser.default_params_compiler(_analyser) return _analyser -def analyse(alconna: "Alconna", command: Union[str, DataCollection]) -> "Arpamar": - ana = compile(alconna) - ana.process_message(command) - return ana.analyse().execute() +def analyse(alconna: "Alconna", command: DataCollection[Union[str, Any]]) -> "Arpamar": + return compile(alconna).process(command).analyse().execute() class AnalyseError(Exception): @@ -35,28 +33,27 @@ class _DummyAnalyser(Analyser): class _DummyALC: is_fuzzy_match = False + options = [] def __new__(cls, *args, **kwargs): cls.alconna = cls._DummyALC() # type: ignore cls.command_params = {} cls.param_ids = set() + cls.default_separate = True return super().__new__(cls) - def analyse(self, message: Union[str, DataCollection, None] = None): + def analyse(self, message: Union[DataCollection[Union[str, Any]], None] = None): pass -def analyse_args( - args: Args, - command: Union[str, DataCollection], - raise_exception: bool = True -): +def analyse_args(args: Args, command: DataCollection[Union[str, Any]], raise_exception: bool = True): _analyser = _DummyAnalyser.__new__(_DummyAnalyser) _analyser.reset() _analyser.separators = {' '} + _analyser.need_main_args = True _analyser.is_raise_exception = True try: - _analyser.process_message(command) + _analyser.process(command) return ala(_analyser, args, len(args)) except Exception as e: if raise_exception: @@ -67,17 +64,18 @@ def analyse_args( def analyse_header( headers: Union[List[Union[str, Any]], List[Tuple[Any, str]]], command_name: str, - command: Union[str, DataCollection], + command: DataCollection[Union[str, Any]], sep: str = " ", raise_exception: bool = True ): _analyser = _DummyAnalyser.__new__(_DummyAnalyser) _analyser.reset() _analyser.separators = {sep} + _analyser.need_main_args = False _analyser.is_raise_exception = True _analyser.__init_header__(command_name, headers) try: - _analyser.process_message(command) + _analyser.process(command) return alh(_analyser) except Exception as e: if raise_exception: @@ -85,17 +83,17 @@ def analyse_header( return -def analyse_option( - option: Option, - command: Union[str, DataCollection], - raise_exception: bool = True -): +def analyse_option(option: Option, command: DataCollection[Union[str, Any]], raise_exception: bool = True): _analyser = _DummyAnalyser.__new__(_DummyAnalyser) _analyser.reset() _analyser.separators = {" "} + _analyser.need_main_args = False _analyser.is_raise_exception = True + _analyser.alconna.options.append(option) + _analyser.default_params_compiler(_analyser) + _analyser.alconna.options.clear() try: - _analyser.process_message(command) + _analyser.process(command) return alo(_analyser, option) except Exception as e: if raise_exception: @@ -103,17 +101,17 @@ def analyse_option( return -def analyse_subcommand( - subcommand: Subcommand, - command: Union[str, DataCollection], - raise_exception: bool = True -): +def analyse_subcommand(subcommand: Subcommand, command: DataCollection[Union[str, Any]], raise_exception: bool = True): _analyser = _DummyAnalyser.__new__(_DummyAnalyser) _analyser.reset() _analyser.separators = {" "} + _analyser.need_main_args = False _analyser.is_raise_exception = True + _analyser.alconna.options.append(subcommand) + _analyser.default_params_compiler(_analyser) + _analyser.alconna.options.clear() try: - _analyser.process_message(command) + _analyser.process(command) return als(_analyser, subcommand) except Exception as e: if raise_exception: diff --git a/src/arclet/alconna/analysis/parts.py b/src/arclet/alconna/analysis/parts.py index 66fe645c..20e5d9b7 100644 --- a/src/arclet/alconna/analysis/parts.py +++ b/src/arclet/alconna/analysis/parts.py @@ -1,18 +1,17 @@ import re -from typing import Iterable, Union, List, Any, Dict, Pattern, Tuple, Set +from typing import Iterable, Union, List, Any, Dict, Tuple, Set from .analyser import Analyser from ..exceptions import ParamsUnmatched, ArgumentMissing, FuzzyMatchSuccess -from ..typing import AllParam, Empty, DataUnit, MultiArg, BasePattern +from ..typing import AllParam, Empty, MultiArg, BasePattern, TPattern from ..base import Args, Option, Subcommand, OptionResult, SubcommandResult, Sentence from ..util import levenshtein_norm, split_once -from ..manager import command_manager -from ..lang import lang_config +from ..config import config def multi_arg_handler( analyser: Analyser, - may_arg: Union[str, DataUnit], + may_arg: Union[str, Any], key: str, value: MultiArg, default: Any, @@ -23,24 +22,24 @@ def multi_arg_handler( # 当前args 已经解析 m 个参数, 总共需要 n 个参数,总共剩余p个参数, # q = n - m 为剩余需要参数(包括自己), p - q + 1 为自己可能需要的参数个数 _m_rest_arg = nargs - len(result_dict) - 1 - _m_all_args_count = analyser.rest_count(seps) - _m_rest_arg + 1 + _m_all_args_count = len(analyser.release(seps)) - _m_rest_arg + 1 if value.array_length: _m_all_args_count = min(_m_all_args_count, value.array_length) - analyser.reduce_data(may_arg) + analyser.pushback(may_arg) if value.flag == 'args': result = [] for i in range(_m_all_args_count): - _m_arg, _m_str = analyser.next_data(seps) + _m_arg, _m_str = analyser.popitem(seps) if _m_str and _m_arg in analyser.param_ids: - analyser.reduce_data(_m_arg) + analyser.pushback(_m_arg) for ii in range(min(len(result), _m_rest_arg)): - analyser.reduce_data(result.pop(-1)) + analyser.pushback(result.pop(-1)) break - try: - result.append(value.match(_m_arg)) - except ParamsUnmatched: - analyser.reduce_data(_m_arg) + res, s = value.validate(_m_arg) + if s != 'V': + analyser.pushback(_m_arg) break + result.append(res) if len(result) == 0: result = [default] if default else [] result_dict[key] = tuple(result) @@ -48,37 +47,35 @@ def multi_arg_handler( result = {} def __putback(data): - analyser.reduce_data(data) + analyser.pushback(data) for _ in range(min(len(result), _m_rest_arg)): arg = result.popitem() # type: ignore - analyser.reduce_data(f'{arg[0]}={arg[1]}') + analyser.pushback(f'{arg[0]}={arg[1]}') for i in range(_m_all_args_count): - _m_arg, _m_str = analyser.next_data(seps) + _m_arg, _m_str = analyser.popitem(seps) if not _m_str: - analyser.reduce_data(_m_arg) + analyser.pushback(_m_arg) break if _m_str and _m_arg in analyser.command_params: __putback(_m_arg) break if _kwarg := re.match(r'^([^=]+)=([^=]+?)$', _m_arg): - _key = _kwarg.group(1) _m_arg = _kwarg.group(2) - try: - result[_key] = value.match(_m_arg) - except ParamsUnmatched: - analyser.reduce_data(_m_arg) + res, s = value.validate(_m_arg) + if s != 'V': + analyser.pushback(_m_arg) break + result[_kwarg.group(1)] = res elif _kwarg := re.match(r'^([^=]+)=\s?$', _m_arg): - _key = _kwarg.group(1) - _m_arg, _m_str = analyser.next_data(seps) - try: - result[_key] = value.match(_m_arg) - except ParamsUnmatched: + _m_arg, _m_str = analyser.popitem(seps) + res, s = value.validate(_m_arg) + if s != 'V': __putback(_m_arg) break + result[_kwarg.group(1)] = res else: - analyser.reduce_data(_m_arg) + analyser.pushback(_m_arg) break if len(result) == 0: result = [default] if default else [] @@ -106,73 +103,61 @@ def analyse_args( for key, arg in opt_args.argument.items(): value = arg['value'] default = arg['default'] - kwonly = arg['kwonly'] optional = arg['optional'] - may_arg, _str = analyser.next_data(seps) - if not may_arg: + may_arg, _str = analyser.popitem(seps) + if not may_arg or (_str and may_arg in analyser.param_ids): + analyser.pushback(may_arg) if default is None: if optional: continue - raise ArgumentMissing(lang_config.args_missing.format(key=key)) + raise ArgumentMissing(config.lang.args_missing.format(key=key)) option_dict[key] = None if default is Empty else default continue - if kwonly: + if arg['kwonly']: _kwarg = re.findall(f'^{key}=(.*)$', may_arg) if not _kwarg: - analyser.reduce_data(may_arg) + analyser.pushback(may_arg) if analyser.alconna.is_fuzzy_match and (k := may_arg.split('=')[0]) != may_arg: - if levenshtein_norm(k, key) >= 0.6: - raise FuzzyMatchSuccess(lang_config.common_fuzzy_matched.format(source=k, target=key)) + if levenshtein_norm(k, key) >= config.fuzzy_threshold: + raise FuzzyMatchSuccess(config.lang.common_fuzzy_matched.format(source=k, target=key)) if default is None and analyser.is_raise_exception: - raise ParamsUnmatched(lang_config.args_key_missing.format(target=may_arg, key=key)) + raise ParamsUnmatched(config.lang.args_key_missing.format(target=may_arg, key=key)) option_dict[key] = None if default is Empty else default continue may_arg = _kwarg[0] if may_arg == '': - may_arg, _str = analyser.next_data(seps) + may_arg, _str = analyser.popitem(seps) if _str: - analyser.reduce_data(may_arg) + analyser.pushback(may_arg) if default is None and analyser.is_raise_exception: - raise ParamsUnmatched(lang_config.args_type_error.format(target=may_arg.__class__)) + raise ParamsUnmatched(config.lang.args_type_error.format(target=may_arg.__class__)) option_dict[key] = None if default is Empty else default continue - if may_arg in analyser.param_ids: - analyser.reduce_data(may_arg) - if default is None: - if optional: - continue - raise ArgumentMissing(lang_config.args_missing.format(key=key)) - else: - option_dict[key] = None if default is Empty else default - elif isinstance(value, BasePattern): + if isinstance(value, BasePattern): if value.__class__ is MultiArg: multi_arg_handler(analyser, may_arg, key, value, default, nargs, seps, option_dict) # type: ignore else: - res, state = value.validate(may_arg, default) + res, state = value.invalidate(may_arg, default) if value.anti else value.validate(may_arg, default) if state != "V": - analyser.reduce_data(may_arg) + analyser.pushback(may_arg) if state == "E": if optional: continue raise res option_dict[key] = res elif value is AllParam: - rest_data = analyser.recover_raw_data() - if not rest_data: - rest_data = [may_arg] - elif isinstance(rest_data[0], str): - rest_data[0] = may_arg + seps.copy().pop() + rest_data[0] - else: - rest_data.insert(0, may_arg) - option_dict[key] = rest_data + analyser.pushback(may_arg) + option_dict[key] = analyser.release() + analyser.current_index = analyser.ndata + analyser.content_index = 0 return option_dict elif may_arg == value: option_dict[key] = may_arg + elif default is None: + if optional: + continue + raise ParamsUnmatched(config.lang.args_error.format(target=may_arg)) else: - if default is None: - if optional: - continue - raise ParamsUnmatched(lang_config.args_error.format(target=may_arg)) option_dict[key] = None if default is Empty else default if opt_args.var_keyword: kwargs = option_dict[opt_args.var_keyword] @@ -189,39 +174,28 @@ def analyse_args( return option_dict -def analyse_params( - analyser: Analyser, - params: Dict[str, Union[List[Option], Sentence, Subcommand]] +def analyse_unmatch_params( + params: Iterable[Union[List[Option], Sentence, Subcommand]], + text: str, + is_fuzzy_match: bool = False ): - _text, _str = analyser.next_data(analyser.separators, pop=False) - if not _str: - return Ellipsis - if not _text: - return _text - if param := params.get(_text, None): - return param - for p in params: - _p = params[p] - if isinstance(_p, List): + for _p in params: + if isinstance(_p, list): res = [] for _o in _p: - if not _o.is_compact: - _may_param, _ = split_once(_text, _o.separators) - if _may_param in _o.aliases: - res.append(_o) - continue - if analyser.alconna.is_fuzzy_match and levenshtein_norm(_may_param, p) >= 0.6: - raise FuzzyMatchSuccess(lang_config.common_fuzzy_matched.format(source=_may_param, target=p)) - elif any(map(lambda x: _text.startswith(x), _o.aliases)): + _may_param = split_once(text, tuple(_o.separators))[0] + if _may_param in _o.aliases or any(map(lambda x: _may_param.startswith(x), _o.aliases)): res.append(_o) + continue + if is_fuzzy_match and levenshtein_norm(_may_param, _o.name) >= config.fuzzy_threshold: + raise FuzzyMatchSuccess(config.lang.common_fuzzy_matched.format(source=_may_param, target=_o.name)) if res: return res else: - _may_param, _ = split_once(_text, _p.separators) - if _may_param == _p.name: + if (_may_param := split_once(text, tuple(_p.separators))[0]) == _p.name: return _p - if analyser.alconna.is_fuzzy_match and levenshtein_norm(_may_param, p) >= 0.6: - raise FuzzyMatchSuccess(lang_config.common_fuzzy_matched.format(source=_may_param, target=p)) + if is_fuzzy_match and levenshtein_norm(_may_param, _p.name) >= config.fuzzy_threshold: + raise FuzzyMatchSuccess(config.lang.common_fuzzy_matched.format(source=_may_param, target=_p.name)) def analyse_option( @@ -235,20 +209,19 @@ def analyse_option( analyser: 使用的分析器 param: 目标Option """ - if param.requires: - if analyser.sentences != param.requires: - raise ParamsUnmatched(f"{param.name}'s required is not '{' '.join(analyser.sentences)}'") - analyser.sentences = [] + if param.requires and analyser.sentences != param.requires: + raise ParamsUnmatched(f"{param.name}'s required is not '{' '.join(analyser.sentences)}'") + analyser.sentences = [] if param.is_compact: - name, _ = analyser.next_data() + name, _ = analyser.popitem() for al in param.aliases: - if name.startswith(al): - analyser.reduce_data(name.lstrip(al), replace=True) + if mat := re.fullmatch(f"{al}(?P.*?)", name): + analyser.pushback(mat.groupdict()['rest'], replace=True) break else: raise ParamsUnmatched(f"{name} dose not matched with {param.name}") else: - name, _ = analyser.next_data(param.separators) + name, _ = analyser.popitem(param.separators) if name not in param.aliases: # 先匹配选项名称 raise ParamsUnmatched(f"{name} dose not matched with {param.name}") name = param.dest @@ -271,18 +244,16 @@ def analyse_subcommand( analyser: 使用的分析器 param: 目标Subcommand """ - if param.requires: - if analyser.sentences != param.requires: - raise ParamsUnmatched(f"{param.name}'s required is not '{' '.join(analyser.sentences)}'") - analyser.sentences = [] + if param.requires and analyser.sentences != param.requires: + raise ParamsUnmatched(f"{param.name}'s required is not '{' '.join(analyser.sentences)}'") + analyser.sentences = [] if param.is_compact: - name, _ = analyser.next_data() - if name.startswith(param.name): - analyser.reduce_data(name.lstrip(param.name), replace=True) - else: + name, _ = analyser.popitem() + if not name.startswith(param.name): raise ParamsUnmatched(f"{name} dose not matched with {param.name}") + analyser.pushback(name.lstrip(param.name), replace=True) else: - name, _ = analyser.next_data(param.separators) + name, _ = analyser.popitem(param.separators) if name != param.name: # 先匹配选项名称 raise ParamsUnmatched(f"{name} dose not matched with {param.name}") name = param.dest @@ -292,16 +263,20 @@ def analyse_subcommand( return name, res args = False - subcommand = res['options'] - need_args = param.nargs > 0 for _ in param.sub_part_len: - sub_param = analyse_params(analyser, param.sub_params) # type: ignore - if sub_param and isinstance(sub_param, List): - for p in sub_param: + _text, _str = analyser.popitem(param.separators, move=False) + _param = _param if (_param := (param.sub_params.get(_text) if _str and _text else Ellipsis)) else ( + analyse_unmatch_params(param.sub_params.values(), _text, analyser.alconna.is_fuzzy_match) + ) + if (not _param or _param is Ellipsis) and not args: + res['args'] = analyse_args(analyser, param.args, param.nargs) + args = True + elif isinstance(_param, List): + for p in _param: _current_index = analyser.current_index _content_index = analyser.content_index try: - subcommand.setdefault(*analyse_option(analyser, p)) + res['options'].setdefault(*analyse_option(analyser, p)) break except Exception as e: exc = e @@ -311,11 +286,8 @@ def analyse_subcommand( else: raise exc # type: ignore # noqa - elif not args: - res['args'] = analyse_args(analyser, param.args, param.nargs) - args = True - if need_args and not args: - raise ArgumentMissing(lang_config.subcommand_args_missing.format(name=name)) + if not args and param.nargs > 0: + raise ArgumentMissing(config.lang.subcommand_args_missing.format(name=name)) return name, res @@ -331,38 +303,47 @@ def analyse_header( head_match: 当命令头内写有正则表达式并且匹配成功的话, 返回匹配结果 """ command = analyser.command_header - separators = analyser.separators - head_text, _str = analyser.next_data(separators) - if isinstance(command, Pattern): - if _str and (_head_find := command.fullmatch(head_text)): - analyser.head_matched = True - return _head_find.groupdict() or True + head_text, _str = analyser.popitem() + if isinstance(command, TPattern) and _str and (_head_find := command.fullmatch(head_text)): + analyser.head_matched = True + return _head_find.groupdict() or True + elif isinstance(command, BasePattern) and (_head_find := command.validate(head_text, Empty)[0]): + analyser.head_matched = True + return _head_find or True else: - may_command, _m_str = analyser.next_data(separators) - if _m_str and not _str: - if isinstance(command, List): - for _command in command: - if (_head_find := _command[1].fullmatch(may_command)) and head_text == _command[0]: - analyser.head_matched = True - return _head_find.groupdict() or True - elif isinstance(command[0], list): - if (_head_find := command[1].fullmatch(may_command)) and head_text in command[0]: # type: ignore + may_command, _m_str = analyser.popitem() + if isinstance(command, List) and _m_str and not _str: + for _command in command: + if (_head_find := _command[1].fullmatch(may_command)) and head_text == _command[0]: analyser.head_matched = True return _head_find.groupdict() or True - else: - if (_command_find := command[1].fullmatch(may_command)) and head_text in command[0][0]: # type: ignore + if isinstance(command, tuple): + if not _str and ( + (isinstance(command[0], list) and head_text in command[0]) or + (isinstance(command[0], tuple) and head_text in command[0][0]) + ): + if isinstance(command[1], TPattern): + if _m_str and (_command_find := command[1].fullmatch(may_command)): + analyser.head_matched = True + return _command_find.groupdict() or True + elif _command_find := command[1].validate(may_command, Empty)[0]: analyser.head_matched = True - return _command_find.groupdict() or True - - elif _str: - pat = re.compile(command[0][1].pattern + command[1].pattern) # type: ignore - if _head_find := pat.fullmatch(head_text): - analyser.reduce_data(may_command) - analyser.head_matched = True - return _head_find.groupdict() or True - elif _m_str and (_command_find := pat.fullmatch(head_text + may_command)): - analyser.head_matched = True - return _command_find.groupdict() or True + return _command_find or True + elif _str and isinstance(command[0][1], TPattern): + if _m_str: + pat = re.compile(command[0][1].pattern + command[1].pattern) # type: ignore + if _head_find := pat.fullmatch(head_text): + analyser.pushback(may_command) + analyser.head_matched = True + return _head_find.groupdict() or True + elif _command_find := pat.fullmatch(head_text + may_command): + analyser.head_matched = True + return _command_find.groupdict() or True + elif isinstance(command[1], BasePattern) and (_head_find := command[0][1].fullmatch(head_text)) and ( + _command_find := command[1].validate(may_command, Empty)[0] + ): + analyser.head_matched = True + return _command_find or True if not analyser.head_matched: if _str and analyser.alconna.is_fuzzy_match: @@ -370,20 +351,20 @@ def analyse_header( if analyser.alconna.headers and analyser.alconna.headers != [""]: for i in analyser.alconna.headers: if isinstance(i, str): - headers_text.append(i + analyser.alconna.command) + headers_text.append(f"{i}{analyser.alconna.command}") else: headers_text.extend((f"{i}", analyser.alconna.command)) elif analyser.alconna.command: headers_text.append(analyser.alconna.command) - if isinstance(command, Pattern): + if isinstance(command, (TPattern, BasePattern)): source = head_text else: source = head_text + analyser.separators.copy().pop() + str(may_command) # type: ignore # noqa - if command_manager.get_command(source): + if source == analyser.alconna.command: analyser.head_matched = False - raise ParamsUnmatched(lang_config.header_error.format(target=head_text)) + raise ParamsUnmatched(config.lang.header_error.format(target=head_text)) for ht in headers_text: - if levenshtein_norm(source, ht) >= 0.6: + if levenshtein_norm(source, ht) >= config.fuzzy_threshold: analyser.head_matched = True - raise FuzzyMatchSuccess(lang_config.common_fuzzy_matched.format(target=source, source=ht)) - raise ParamsUnmatched(lang_config.header_error.format(target=head_text)) + raise FuzzyMatchSuccess(config.lang.common_fuzzy_matched.format(target=source, source=ht)) + raise ParamsUnmatched(config.lang.header_error.format(target=head_text)) diff --git a/src/arclet/alconna/analysis/special.py b/src/arclet/alconna/analysis/special.py index b35121b6..5d31e67d 100644 --- a/src/arclet/alconna/analysis/special.py +++ b/src/arclet/alconna/analysis/special.py @@ -1,40 +1,29 @@ -from ..components.output import output_send +from ..components.output import output_manager from ..base import ShortcutOption from .parts import analyse_option from .analyser import Analyser def handle_help(analyser: Analyser): - _help_param = analyser.recover_raw_data() - _help_param[0] = _help_param[0].replace("--help", "", 1).replace("-h", "", 1).lstrip() + analyser.current_index, analyser.content_index = analyser.head_pos + _help_param = [str(i) for i in analyser.release() if i not in {"-h", "--help"}] def _get_help(): formatter = analyser.alconna.formatter_type(analyser.alconna) return formatter.format_node(_help_param) - output_send(analyser.alconna.name, _get_help).handle( - {}, is_raise_exception=analyser.is_raise_exception - ) + output_manager.get(analyser.alconna.name, _get_help).handle(is_raise_exception=analyser.is_raise_exception) return analyser.export(fail=True) def handle_shortcut(analyser: Analyser): - def _shortcut(sct: str, command: str, expiration: int, delete: bool): - return analyser.alconna.shortcut( - sct, None if command == "_" else analyser.converter(command), delete, expiration - ) - - _, opt_v = analyse_option(analyser, ShortcutOption) + opt_v = analyse_option(analyser, ShortcutOption)[1]['args'] try: - msg = _shortcut( - opt_v['args']['name'], opt_v['args']['command'], - opt_v['args']['expiration'], True if opt_v['args'].get('delete') else False + msg = analyser.alconna.shortcut( + opt_v['name'], None if opt_v['command'] == "_" else analyser.converter(opt_v['command']), + bool(opt_v.get('delete')), opt_v['expiration'] ) - output_send( - analyser.alconna.name, lambda: msg - ).handle({}, is_raise_exception=analyser.is_raise_exception) + output_manager.get(analyser.alconna.name, lambda: msg).handle(is_raise_exception=analyser.is_raise_exception) except Exception as e: - output_send(analyser.alconna.name, lambda: str(e)).handle( - {}, is_raise_exception=analyser.is_raise_exception - ) + output_manager.get(analyser.alconna.name, lambda: str(e)).handle(is_raise_exception=analyser.is_raise_exception) return analyser.export(fail=True) diff --git a/src/arclet/alconna/arpamar.py b/src/arclet/alconna/arpamar.py index 36c1d69d..ad8c5106 100644 --- a/src/arclet/alconna/arpamar.py +++ b/src/arclet/alconna/arpamar.py @@ -1,6 +1,6 @@ from typing import Union, Dict, List, Any, Optional, TYPE_CHECKING, Type, TypeVar, Tuple, overload from .typing import DataCollection -from .lang import lang_config +from .config import config from .base import SubcommandResult, OptionResult from .exceptions import BehaveCancelled from .components.behavior import T_ABehavior, requirement_handler @@ -17,23 +17,16 @@ class Arpamar: """ 亚帕玛尔(Arpamar), Alconna的珍藏宝书 - Example: - - `Arpamar.main_args`: 当 Alconna 写入了 main_argument 时,该参数返回对应的解析出来的值 - - `Arpamar.header`: 当 Alconna 的 command 内写有正则表达式时,该参数返回对应的匹配值 - - `Arpamar.find`: 判断 Arpamar 内是否有对应的属性 - - `Arpamar.query`: 返回 Arpamar 中指定的属性 - - `Arpamar.matched`: 返回命令是否匹配成功 - """ def __init__(self, alc: "Alconna"): self.source: "Alconna" = alc - self.origin: Union[str, DataCollection] = '' + self.origin: DataCollection[Union[str, Any]] = '' self.matched: bool = False self.head_matched: bool = False self.error_data: List[Union[str, Any]] = [] @@ -118,10 +111,7 @@ def encapsulate_result( self.other_args = {**self.other_args, **vv['args']} def execute(self, behaviors: Optional[List[T_ABehavior]] = None): - behaviors = [ - *self.source.behaviors, - *(behaviors or []) - ] + behaviors = [*self.source.behaviors, *(behaviors or [])] for behavior in behaviors: for b in requirement_handler(behavior): try: @@ -130,10 +120,7 @@ def execute(self, behaviors: Optional[List[T_ABehavior]] = None): continue return self - def __require__( - self, - parts: List[str] - ) -> Tuple[Optional[Union[Dict[str, Any], OptionResult, SubcommandResult]], str]: + def __require__(self, parts: List[str]) -> Tuple[Union[Dict[str, Any], OptionResult, SubcommandResult, None], str]: """如果能够返回, 除开基本信息, 一定返回该path所在的dict""" if len(parts) == 1: part = parts[0] @@ -152,7 +139,7 @@ def __require__( return None, part prefix = parts.pop(0) # parts[0] if prefix in {"options", "subcommands"} and prefix in self.components: - raise RuntimeError(lang_config.arpamar_ambiguous_name.format(target=prefix)) + raise RuntimeError(config.lang.arpamar_ambiguous_name.format(target=prefix)) def _r_opt(_p: str, _s: List[str], _opts: Dict[str, OptionResult]): if _p == "options": @@ -186,12 +173,12 @@ def _r_opt(_p: str, _s: List[str], _opts: Dict[str, OptionResult]): if end in _cache['args']: return _cache['args'], end if end == "options" and end in _cache['options']: - raise RuntimeError(lang_config.arpamar_ambiguous_name.format(target=f"{prefix}.{end}")) + raise RuntimeError(config.lang.arpamar_ambiguous_name.format(target=f"{prefix}.{end}")) if end == "options" or end in _cache['options']: return _r_opt(end, parts, _cache['options']) return None, prefix - def query(self, path: str, default: Any = None) -> Union[Dict[str, Any], Any, None]: + def query(self, path: str, default: T = None) -> Union[Dict[str, Any], T, None]: """根据path查询值""" parts = path.split('.') cache, endpoint = self.__require__(parts) @@ -212,11 +199,14 @@ def update(self, path: str, value: Any): else: cache[endpoint] = value - def query_with(self, arg_type: Type[T], name: Optional[str] = None) -> Optional[Union[Dict[str, T], T]]: + def query_with( + self, arg_type: Type[T], + name: Optional[str] = None, + default: Any = None + ) -> Optional[Union[Dict[str, T], T]]: """根据类型查询参数""" if name: - res = self.query(name) - return res if isinstance(res, arg_type) else None + return res if isinstance(res := self.query(name, default), arg_type) else None return {k: v for k, v in self.all_matched_args.items() if isinstance(v, arg_type)} def find(self, name: str) -> bool: diff --git a/src/arclet/alconna/base.py b/src/arclet/alconna/base.py index 4d4f7885..4b4bf0af 100644 --- a/src/arclet/alconna/base.py +++ b/src/arclet/alconna/base.py @@ -2,26 +2,26 @@ import re import inspect -from copy import copy +from copy import deepcopy from enum import Enum +from contextlib import suppress from dataclasses import dataclass, field -from typing import Union, Tuple, Dict, Iterable, Callable, Any, Optional, Sequence, List, Literal, TypedDict, Set +from typing import Union, Tuple, Dict, Iterable, Callable, Any, Optional, Sequence, List, Literal, TypedDict, \ + Set, FrozenSet -from .exceptions import InvalidParam, NullTextMessage -from .typing import BasePattern, Empty, DataUnit, AllParam, AnyOne, MultiArg, UnionArg, argument_type_validator -from .lang import lang_config +from .exceptions import InvalidParam, NullMessage +from .typing import BasePattern, Empty, AllParam, AnyOne, MultiArg, UnionArg, args_type_parser, pattern_map +from .config import config from .components.action import ArgAction - TAValue = Union[BasePattern, AllParam.__class__, type] -TADefault = Union[Any, DataUnit, Empty] +TADefault = Union[Any, object, Empty] class ArgFlag(str, Enum): """ 参数标记 """ - VAR_POSITIONAL = "S" VAR_KEYWORD = "W" OPTIONAL = 'O' @@ -32,22 +32,17 @@ class ArgFlag(str, Enum): class ArgUnit(TypedDict): - """ - 参数单元 - """ - + """参数单元 """ value: TAValue """参数值""" - default: TADefault """默认值""" - + notice: Optional[str] + """参数提示""" optional: bool """是否可选""" - kwonly: bool """是否键值对参数""" - hidden: bool """是否隐藏类型参数""" @@ -61,14 +56,14 @@ def __init__(cls, name, bases, attrs): cls.selecting = False def __getattr__(cls, name): - if name == 'shape': + if name in ('shape', '__test__', '_pytestfixturefunction'): return super().__getattribute__(name) cls.last_key = name cls.selecting = True return cls def __getitem__(self, item): - if isinstance(item, slice): + if isinstance(item, slice) or isinstance(item, tuple) and list(filter(lambda x: isinstance(x, slice), item)): raise InvalidParam(f"{self.__name__} 现在不支持切片; 应从 Args[a:b:c, x:y:z] 变为 Args[a,b,c][x,y,z]") if not isinstance(item, tuple): if self.selecting: @@ -110,26 +105,27 @@ def from_string_list(cls, args: List[List[str]], custom_types: Dict) -> "Args": """ _args = cls() for arg in args: - _le = len(arg) - if _le == 0: - raise NullTextMessage + if (_le := len(arg)) == 0: + raise NullMessage default = arg[2].strip(" ") if _le > 2 else None value = AllParam if arg[0].startswith("...") else ( - AnyOne if arg[0].startswith("..") else ( - arg[1].strip(" ") if _le > 1 else arg[0].lstrip(".-")) + AnyOne if arg[0].startswith("..") else (arg[1].strip(" ") if _le > 1 else arg[0].lstrip(".-")) ) name = arg[0].replace("...", "").replace("..", "") if value not in (AllParam, AnyOne): if custom_types and custom_types.get(value) and not inspect.isclass(custom_types[value]): - raise InvalidParam(lang_config.common_custom_type_error.format(target=custom_types[value])) - try: - value = eval(value, custom_types) # type: ignore - if default: - default = value(default) - except (NameError, ValueError, TypeError): - pass + raise InvalidParam(config.lang.common_custom_type_error.format(target=custom_types[value])) + with suppress(NameError, ValueError, TypeError): + if pattern_map.get(value, None): + value = pattern_map[value] + if default: + default = value.origin(default) + else: + value = eval(value, custom_types) # type: ignore + if default: + default = value(default) _args.add_argument(name, value=value, default=default) return _args @@ -191,15 +187,11 @@ def __init__( self.var_positional = None self.var_keyword = None self.optional_count = 0 - if isinstance(separators, str): - self.separators = {separators} - else: - self.separators = set(separators) + self.separators = {separators} if isinstance(separators, str) else set(separators) self.argument = { # type: ignore - k: { - "value": argument_type_validator(v), - "default": None, 'optional': False, 'hidden': False, 'kwonly': False - } for k, v in kwargs.items() + k: {"value": args_type_parser(v), "default": None, 'notice': None, + 'optional': False, 'hidden': False, 'kwonly': False} + for k, v in kwargs.items() } for arg in (args or []): self.__check_var__(arg) @@ -213,7 +205,7 @@ def add_argument(self, name: str, *, value: Any, default: Any = None, flags: Opt if name in self.argument: return self if flags: - name += ";" + "|".join(flags) + name += ";" + "".join(flags) self.__check_var__([name, value, default]) return self @@ -231,71 +223,75 @@ def separate(self, *separator: str): def __check_var__(self, val: Sequence): if not val: - raise InvalidParam(lang_config.args_name_empty) + raise InvalidParam(config.lang.args_name_empty) name, value, default = val[0], val[1] if len(val) > 1 else val[0], val[2] if len(val) > 2 else None if not isinstance(name, str): - raise InvalidParam(lang_config.args_name_error) + raise InvalidParam(config.lang.args_name_error) if not name.strip(): - raise InvalidParam(lang_config.args_name_empty) - _value = argument_type_validator(value, self.extra) + raise InvalidParam(config.lang.args_name_empty) + _value = args_type_parser(value, self.extra) if isinstance(_value, UnionArg) and _value.optional: default = Empty if default is None else default if default in ("...", Ellipsis): default = Empty if _value is Empty: - raise InvalidParam(lang_config.args_value_error.format(target=name)) - name, arg = self.__handle_flags__(name, value, _value, default) - self.argument[name] = arg - - def __handle_flags__(self, name, value, _value, default): - slot = {'value': _value, 'default': default, 'optional': False, 'hidden': False, 'kwonly': False} - if res := re.match(r"^.+?;(?P[^;]+?)$", name): - flags = res.group("flag").split("|") + raise InvalidParam(config.lang.args_value_error.format(target=name)) + slot: ArgUnit = { + 'value': _value, 'default': default, 'notice': None, + 'optional': False, 'hidden': False, 'kwonly': False + } + if res := re.match(r"^.+?#(?P[^;#]+)", name): + slot['notice'] = res.group("notice") + name = name.replace(f"#{res.group('notice')}", "") + if res := re.match(r"^.+?;(?P[^;#]+)", name): + flags = res.group("flag") name = name.replace(f";{res.group('flag')}", "") _limit = False for flag in flags: if flag == ArgFlag.FORCE and not _limit: - slot['value'] = ( - BasePattern(value, alias=f"\'{value}\'") if isinstance(value, str) else BasePattern.of(value) - ) + self.__handle_force__(slot, value) _limit = True if flag == ArgFlag.ANTI and not _limit: - if isinstance(_value, UnionArg): - slot['value'].reverse() - elif _value not in (AnyOne, AllParam): - slot['value'] = copy(_value).reverse() + if slot['value'] not in (AnyOne, AllParam): + slot['value'] = deepcopy(_value).reverse() # type: ignore _limit = True if flag == ArgFlag.VAR_KEYWORD and not _limit: if self.var_keyword: - raise InvalidParam(lang_config.args_duplicate_kwargs) - if _value not in (AnyOne, AllParam): - slot['value'] = MultiArg(_value, flag='kwargs') + raise InvalidParam(config.lang.args_duplicate_kwargs) + if _value is not AllParam: + slot['value'] = MultiArg(_value, flag='kwargs') # type: ignore self.var_keyword = name _limit = True if flag == ArgFlag.VAR_POSITIONAL and not _limit: if self.var_positional: - raise InvalidParam(lang_config.args_duplicate_varargs) - if _value not in (AnyOne, AllParam): - slot['value'] = MultiArg(_value) + raise InvalidParam(config.lang.args_duplicate_varargs) + if _value is not AllParam: + slot['value'] = MultiArg(_value) # type: ignore self.var_positional = name if flag.isdigit() and not _limit: if self.var_positional: - raise InvalidParam(lang_config.args_duplicate_varargs) - if _value not in (AnyOne, AllParam): - slot['value'] = MultiArg(_value, array_length=int(flag)) + raise InvalidParam(config.lang.args_duplicate_varargs) + if _value is not AllParam: + slot['value'] = MultiArg(_value, length=int(flag)) # type: ignore self.var_positional = name if flag == ArgFlag.OPTIONAL: if self.var_keyword or self.var_positional: - raise InvalidParam(lang_config.args_exclude_mutable_args) + raise InvalidParam(config.lang.args_exclude_mutable_args) slot['optional'] = True self.optional_count += 1 if flag == ArgFlag.HIDDEN: slot['hidden'] = True if flag == ArgFlag.KWONLY: if self.var_keyword or self.var_positional: - raise InvalidParam(lang_config.args_exclude_mutable_args) + raise InvalidParam(config.lang.args_exclude_mutable_args) slot['kwonly'] = True - return name, slot + self.argument[name] = slot + + @staticmethod + def __handle_force__(slot: ArgUnit, value): + slot['value'] = ( + BasePattern(value, alias=f"\'{value}\'") if isinstance(value, str) else BasePattern.of(value) + ) def __len__(self): return len(self.argument) @@ -314,12 +310,9 @@ def __setattr__(self, key, value): return self def __getitem__(self, item) -> Union["Args", Tuple[TAValue, TADefault]]: - if isinstance(item, str): - if self.argument.get(item): - return self.argument[item]['value'], self.argument[item]['default'] - else: - raise KeyError(lang_config.args_key_not_found) - if isinstance(item, slice): + if isinstance(item, str) and self.argument.get(item): + return self.argument[item]['value'], self.argument[item]['default'] + if isinstance(item, slice) or isinstance(item, tuple) and list(filter(lambda x: isinstance(x, slice), item)): raise InvalidParam(f"{self.__name__} 现在不支持切片; 应从 Args[a:b:c, x:y:z] 变为 Args[a,b,c][x,y,z]") if not isinstance(item, tuple): self.__check_var__([str(item), item]) @@ -399,17 +392,16 @@ def __init__( help_text(str): 命令帮助信息 """ if not name: - raise InvalidParam(lang_config.node_name_empty) + raise InvalidParam(config.lang.node_name_empty) if re.match(r"^[`~?/.,<>;\':\"|!@#$%^&*()_+=\[\]}{]+.*$", name): - raise InvalidParam(lang_config.node_name_error) + raise InvalidParam(config.lang.node_name_error) _parts = name.split(" ") self.name = _parts[-1] - self.requires = _parts[:-1] if not requires else ( - requires if isinstance(requires, (list, tuple, set)) else (requires,) - ) - self.args = Args() if not args else args if isinstance(args, Args) else Args.from_string_list( + self.requires = (requires if isinstance(requires, (list, tuple, set)) else (requires,)) \ + if requires else _parts[:-1] + self.args = (args if isinstance(args, Args) else Args.from_string_list( [re.split("[:=]", p) for p in re.split(r"\s*,\s*", args)], {} - ) + )) if args else Args() self.action = ArgAction.__validator__(action, self.args) self.separators = {' '} if separators is None else ( {separators} if isinstance(separators, str) else set(separators) @@ -438,7 +430,8 @@ def __eq__(self, other): class Option(CommandNode): """命令选项, 可以使用别名""" - aliases: List[str] + aliases: FrozenSet[str] + priority: int def __init__( self, @@ -450,17 +443,19 @@ def __init__( separators: Optional[Union[str, Sequence[str], Set[str]]] = None, help_text: Optional[str] = None, requires: Optional[Union[str, Sequence[str], Set[str]]] = None, - + priority: int = 0 ): - self.aliases = alias or [] + aliases = alias or [] parts = name.split(" ") name, rest = parts[-1], parts[:-1] if "|" in name: aliases = name.split('|') aliases.sort(key=len, reverse=True) name = aliases[0] - self.aliases.extend(aliases[1:]) - self.aliases.insert(0, name) + aliases.extend(aliases[1:]) + aliases.insert(0, name) + self.aliases = frozenset(aliases) + self.priority = priority super().__init__( " ".join(rest) + (" " if rest else "") + name, args, dest, action, separators, help_text, requires ) @@ -506,9 +501,12 @@ class SubcommandResult(TypedDict): options: Dict[str, OptionResult] +class StrMounter(List[str]): + pass + + HelpOption = Option("--help|-h", help_text="显示帮助信息") ShortcutOption = Option( - '--shortcut|-SCT', - Args["delete;O", "delete"]["name", str]["command", str, "_"]["expiration;K", int, 0], + '--shortcut|-SCT', Args["delete;O", "delete"]["name", str]["command", str, "_"]["expiration;K", int, 0], help_text='设置快捷命令' ) diff --git a/src/arclet/alconna/builtin/actions.py b/src/arclet/alconna/builtin/actions.py index db2d2f1e..2bcc8262 100644 --- a/src/arclet/alconna/builtin/actions.py +++ b/src/arclet/alconna/builtin/actions.py @@ -5,7 +5,7 @@ from arclet.alconna.components.action import ArgAction from arclet.alconna.components.behavior import ArpamarBehavior from arclet.alconna.exceptions import BehaveCancelled, OutBoundsBehavior -from arclet.alconna.lang import lang_config +from arclet.alconna.config import config class _StoreValue(ArgAction): @@ -51,13 +51,13 @@ def operate(self, interface: "Arpamar"): raise BehaveCancelled if option and subcommand is None: options = interface.query("options", {}) - options[option] = {"value": value, "args": {}} + options.setdefault(option, {"value": value, "args": {}}) if subcommand and option is None: subcommands = interface.query("subcommands", {}) - subcommands[subcommand] = {"value": value, "args": {}, "options": {}} + subcommands.setdefault(subcommand, {"value": value, "args": {}, "options": {}}) if option and subcommand: sub_options = interface.query(f"{subcommand}.options", {}) - sub_options[option] = {"value": value, "args": {}} + sub_options.setdefault(option, {"value": value, "args": {}}) return _SetDefault() @@ -74,12 +74,13 @@ def exclusion(target_path: str, other_path: str): class _EXCLUSION(ArpamarBehavior): def operate(self, interface: "Arpamar"): if interface.query(target_path) and interface.query(other_path): + interface.matched = False if interface.source.is_raise_exception: raise OutBoundsBehavior( - lang_config.behavior_exclude_matched.format(target=target_path, other=other_path) + config.lang.behavior_exclude_matched.format(target=target_path, other=other_path) ) interface.error_info = OutBoundsBehavior( - lang_config.behavior_exclude_matched.format(target=target_path, other=other_path) + config.lang.behavior_exclude_matched.format(target=target_path, other=other_path) ) return _EXCLUSION() @@ -101,9 +102,9 @@ def operate(self, interface: "Arpamar"): current_time = datetime.now() if (current_time - self.last_time).total_seconds() < seconds: if interface.source.is_raise_exception: - raise OutBoundsBehavior(lang_config.behavior_cooldown_matched) + raise OutBoundsBehavior(config.lang.behavior_cooldown_matched) interface.matched = False - interface.error_info = OutBoundsBehavior(lang_config.behavior_cooldown_matched) + interface.error_info = OutBoundsBehavior(config.lang.behavior_cooldown_matched) else: self.last_time = current_time @@ -125,14 +126,11 @@ def operate(self, interface: "Arpamar"): for target in targets: if not interface.query(target): interface.matched = False - interface.error_info = OutBoundsBehavior(lang_config.behavior_inclusion_matched) + interface.error_info = OutBoundsBehavior(config.lang.behavior_inclusion_matched) break else: - all_count = len(targets) - sum( - 1 for target in targets if interface.require(target) - ) - + all_count = len(targets) - sum(1 for target in targets if interface.require(target)) if all_count > 0: interface.matched = False - interface.error_info = OutBoundsBehavior(lang_config.behavior_inclusion_matched) + interface.error_info = OutBoundsBehavior(config.lang.behavior_inclusion_matched) return _Inclusion() diff --git a/src/arclet/alconna/builtin/analyser.py b/src/arclet/alconna/builtin/analyser.py index 40b434bd..df88e207 100644 --- a/src/arclet/alconna/builtin/analyser.py +++ b/src/arclet/alconna/builtin/analyser.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Any from arclet.alconna.base import Subcommand, Sentence from arclet.alconna.arpamar import Arpamar @@ -7,88 +7,87 @@ from arclet.alconna.analysis.analyser import Analyser from arclet.alconna.analysis.special import handle_help, handle_shortcut from arclet.alconna.analysis.parts import analyse_args, analyse_option, analyse_subcommand, analyse_header, \ - analyse_params + analyse_unmatch_params from arclet.alconna.exceptions import ParamsUnmatched, ArgumentMissing, FuzzyMatchSuccess -from arclet.alconna.lang import lang_config -from arclet.alconna.components.output import output_send +from arclet.alconna.config import config +from arclet.alconna.components.output import output_manager class DefaultCommandAnalyser(Analyser): """ 内建的默认分析器 - """ filter_out = ["Source", "File", "Quote"] - def analyse(self, message: Union[str, DataCollection, None] = None) -> Arpamar: + def analyse(self, message: Union[DataCollection[Union[str, Any]], None] = None) -> Arpamar: if command_manager.is_disable(self.alconna): return self.export(fail=True) if self.ndata == 0 and not self.temporary_data.get('fail'): if not message: - raise ValueError(lang_config.analyser_handle_null_message.format(target=message)) - self.process_message(message) + raise ValueError(config.lang.analyser_handle_null_message.format(target=message)) + self.process(message) if self.temporary_data.get('fail'): - self.reset() return self.export(fail=True, exception=self.temporary_data.get('exception')) if (res := command_manager.get_record(self.temp_token)) and self.temp_token in self.used_tokens: self.reset() return res try: self.header = analyse_header(self) + self.head_pos = self.current_index, self.content_index except ParamsUnmatched as e: self.current_index = 0 self.content_index = 0 try: - _res = command_manager.find_shortcut(self.next_data(self.separators, pop=False)[0], self.alconna) + _res = command_manager.find_shortcut(self.popitem(move=False)[0], self.alconna) self.reset() if isinstance(_res, Arpamar): return _res - self.process_message(_res) - return self.analyse() + return self.process(_res).analyse() except ValueError: return self.export(fail=True, exception=e) except FuzzyMatchSuccess as Fuzzy: - output_send(self.alconna.name, lambda: str(Fuzzy)).handle({}, is_raise_exception=self.is_raise_exception) + output_manager.get(self.alconna.name, lambda: str(Fuzzy)).handle(is_raise_exception=self.is_raise_exception) return self.export(fail=True) for _ in self.part_len: - _param = analyse_params(self, self.command_params) try: - if not _param or _param is Ellipsis: - if not self.main_args: - self.main_args = analyse_args(self, self.self_args, self.alconna.nargs) + _text, _str = self.popitem(move=False) + _param = _param if (_param := (self.command_params.get(_text) if _str and _text else Ellipsis)) else ( + None if self.default_separate else analyse_unmatch_params( + self.command_params.values(), _text, self.alconna.is_fuzzy_match + ) + ) + if (not _param or _param is Ellipsis) and not self.main_args: + self.main_args = analyse_args(self, self.self_args, self.alconna.nargs) elif isinstance(_param, list): for opt in _param: if opt.name == "--help": return handle_help(self) if opt.name == "--shortcut": return handle_shortcut(self) - _current_index = self.current_index - _content_index = self.content_index + _current_index, _content_index = self.current_index, self.content_index try: opt_n, opt_v = analyse_option(self, opt) self.options[opt_n] = opt_v break except Exception as e: exc = e - self.current_index = _current_index - self.content_index = _content_index + self.current_index, self.content_index = _current_index, _content_index continue else: raise exc # noqa elif isinstance(_param, Subcommand): - sub_n, sub_v = analyse_subcommand(self, _param) - self.subcommands[sub_n] = sub_v + self.subcommands.setdefault(*analyse_subcommand(self, _param)) elif isinstance(_param, Sentence): - self.sentences.append(self.next_data(self.separators)[0]) - except FuzzyMatchSuccess as Fuzzy: - output_send( - self.alconna.name, lambda: str(Fuzzy) - ).handle({}, is_raise_exception=self.is_raise_exception) + self.sentences.append(self.popitem()[0]) + except FuzzyMatchSuccess as e: + output_manager.get(self.alconna.name, lambda: str(e)).handle(is_raise_exception=self.is_raise_exception) return self.export(fail=True) except (ParamsUnmatched, ArgumentMissing): + if self.release()[-1] in ("--help", "-h"): + return handle_help(self) if self.is_raise_exception: raise return self.export(fail=True) @@ -99,16 +98,14 @@ def analyse(self, message: Union[str, DataCollection, None] = None) -> Arpamar: if self.default_main_only and not self.main_args: self.main_args = analyse_args(self, self.self_args, self.alconna.nargs) - if self.current_index == self.ndata and (not self.need_main_args or (self.need_main_args and self.main_args)): + if self.current_index == self.ndata and (not self.need_main_args or self.main_args): return self.export() - data_len = self.rest_count(self.separators) + data_len = len(self.release()) if data_len > 0: - exc = ParamsUnmatched( - lang_config.analyser_param_unmatched.format(target=self.next_data(self.separators, pop=False)[0]) - ) + exc = ParamsUnmatched(config.lang.analyser_param_unmatched.format(target=self.popitem(move=False)[0])) else: - exc = ArgumentMissing(lang_config.analyser_param_missing) + exc = ArgumentMissing(config.lang.analyser_param_missing) if self.is_raise_exception: raise exc return self.export(fail=True, exception=exc) diff --git a/src/arclet/alconna/builtin/checker.py b/src/arclet/alconna/builtin/checker.py new file mode 100644 index 00000000..8f7c7397 --- /dev/null +++ b/src/arclet/alconna/builtin/checker.py @@ -0,0 +1,37 @@ +from typing_extensions import ParamSpec +from typing import TypeVar, Callable, Optional +from functools import wraps +from arclet.alconna.base import Args +from arclet.alconna.analysis.base import analyse_args + +T = TypeVar("T") +P = ParamSpec("P") + + +def simple_type(raise_exception: bool = False): + def deco(func: Callable[P, T]) -> Callable[P, Optional[T]]: + _args, _ = Args.from_callable(func) + + @wraps(func) + def __wrapper__(*args: P.args, **kwargs: P.kwargs): + param = list(args) + for k, v in kwargs.items(): + param.extend([f"{k}=", v]) + if not (result := analyse_args(_args, param, raise_exception)): + return None + res_args = [] + res_kwargs = {} + varargs = [] + if '__kwargs__' in result: + res_kwargs, kw_key = result.pop('__kwargs__') + result.pop(kw_key) + if '__varargs__' in result: + varargs, var_key = result.pop('__varargs__') + result.pop(var_key) + res_args.extend(iter(result.values())) + res_args.extend(varargs) + return func(*res_args, **res_kwargs) # type: ignore + + return __wrapper__ + + return deco diff --git a/src/arclet/alconna/builtin/construct.py b/src/arclet/alconna/builtin/construct.py index ea0702a9..b0212523 100644 --- a/src/arclet/alconna/builtin/construct.py +++ b/src/arclet/alconna/builtin/construct.py @@ -2,7 +2,7 @@ import sys import re import inspect -from functools import partial +from functools import partial, wraps from types import FunctionType, MethodType, ModuleType from typing import Dict, Any, Optional, Callable, Union, TypeVar, List, Type, FrozenSet, Literal, get_args, Tuple, \ Iterable, cast @@ -10,8 +10,7 @@ from arclet.alconna.core import Alconna from arclet.alconna.base import Args, TAValue, ArgAction, Option, Subcommand, ArgFlag from arclet.alconna.util import split, split_once -from arclet.alconna.lang import lang_config -from arclet.alconna.manager import command_manager +from arclet.alconna.config import config as global_config from .actions import store_value @@ -65,25 +64,26 @@ def set_parser(self, parser_func: PARSER_TYPE): self.parser_func = parser_func return self - def __call__(self, message: Union[str, DataCollection]) -> Any: + def __call__(self, message: DataCollection[Union[str, Any]]) -> Any: if not self.exec_target: - raise RuntimeError(lang_config.construct_decorate_error) + raise RuntimeError(global_config.lang.construct_decorate_error) result = self.command.parse(message) if result.matched: self.parser_func( self.exec_target, result.all_matched_args, self.local_args, - command_manager.loop + global_config.loop ) + return result def from_commandline(self): """从命令行解析参数""" if not self.command: - raise RuntimeError(lang_config.construct_decorate_error) + raise RuntimeError(global_config.lang.construct_decorate_error) args = sys.argv[1:] args.insert(0, self.command.command) - self.__call__(" ".join(args)) + return self.__call__(" ".join(args)) F = TypeVar("F", bound=Callable[..., Any]) @@ -111,26 +111,22 @@ class AlconnaDecorate: Attributes: namespace (str): 命令的命名空间 - loop (AbstractEventLoop): 事件循环 """ namespace: str building: bool __storage: Dict[str, Any] default_parser: PARSER_TYPE - def __init__(self, namespace: str = "Alconna", loop: Optional[AbstractEventLoop] = None): + def __init__(self, namespace: str = "Alconna"): """ 初始化构造器 Args: namespace (str): 命令的命名空间 - loop (AbstractEventLoop): 事件循环 """ self.namespace = namespace self.building = False self.__storage = {"options": []} - if loop: - command_manager.loop = loop self.default_parser = default_parser def build_command(self, name: Optional[str] = None) -> Callable[[F], ALCCommand]: @@ -148,10 +144,9 @@ def wrapper(func: Callable[..., Any]) -> ALCCommand: command_name = name or self.__storage['func'].__name__ help_string = self.__storage.get('func').__doc__ command = Alconna( - command=command_name, + command_name, self.__storage.get("main_args"), options=self.__storage.get("options"), namespace=self.namespace, - main_args=self.__storage.get("main_args"), help_text=help_string or command_name ) self.building = False @@ -178,14 +173,12 @@ def option( sep (str): 参数分隔符 """ if not self.building: - raise RuntimeError(lang_config.construct_decorate_error) + raise RuntimeError(global_config.lang.construct_decorate_error) def wrapper(func: FC) -> FC: if not self.__storage.get('func'): self.__storage['func'] = func - self.__storage['options'].append( - Option(name, args=args, action=action, separators=sep, help_text=help or name) - ) + self.__storage['options'].append(Option(name, args, action=action, separators=sep, help_text=help or name)) return func return wrapper @@ -198,7 +191,7 @@ def arguments(self, args: Args) -> Callable[[FC], FC]: args (Args): 参数 """ if not self.building: - raise RuntimeError(lang_config.construct_decorate_error) + raise RuntimeError(global_config.lang.construct_decorate_error) def wrapper(func: FC) -> FC: if not self.__storage.get('func'): @@ -219,11 +212,6 @@ def set_default_parser(self, parser_func: PARSER_TYPE): return self -# ---------------------------------------- -# format -# ---------------------------------------- - - def _from_format( format_string: str, format_args: Optional[Dict[str, Union[TAValue, Args, Option, List[Option]]]] = None, @@ -270,7 +258,7 @@ def _from_format( _name, _requires = _string_stack[-1], _string_stack[:-1] if isinstance(value, Option): options.append(Subcommand(_name, [value], requires=_requires)) - elif isinstance(value, List): + elif isinstance(value, list): options.append(Subcommand(_name, value, requires=_requires)) elif isinstance(value, Args): options.append(Option(_name, args=value, requires=_requires)) @@ -281,7 +269,7 @@ def _from_format( if i == 0: if isinstance(value, Args): main_args.__merge__(value) - elif not isinstance(value, Option) and not isinstance(value, List): + elif not isinstance(value, Option) and not isinstance(value, list): main_args.__merge__(Args(**{key: value})) elif isinstance(value, Option): options.append(value) @@ -297,7 +285,7 @@ def _from_format( if _string_stack: if _key_ref > 1: options[-1].args.__merge__(_arg) - options[-1].nargs += 1 + options[-1].nargs = len(options[-1].args.argument) else: options.append(Option(_string_stack[-1], _arg, requires=_string_stack[:-1])) _string_stack.clear() @@ -306,17 +294,7 @@ def _from_format( return Alconna(command=command, options=options, main_args=main_args) -# ---------------------------------------- -# koishi-like -# ---------------------------------------- - - -def _from_string( - command: str, - *option: str, - custom_types: Optional[Dict[str, Type]] = None, - sep: str = " " -) -> "Alconna": +def _from_string(command: str, *option: str, sep: str = " ") -> "Alconna": """ 以纯字符串的形式构造Alconna的简易方式, 或者说是koishi-like的方式 @@ -342,12 +320,9 @@ def _from_string( res = re.split("[:=]", p) res[0] = f"{res[0]};O" args.append(res) - if not (help_string := re.findall(r"#(.+)", others)): + if not (help_string := re.findall(r"(?: )#(.+)$", others)): # noqa help_string = headers - if not custom_types: - custom_types = Alconna.custom_types.copy() - else: - custom_types.update(Alconna.custom_types) + custom_types = Alconna.custom_types.copy() custom_types.update(getattr(inspect.getmodule(inspect.stack()[1][0]), "__dict__", {})) _args = Args.from_string_list(args, custom_types.copy()) for opt in option: @@ -357,15 +332,14 @@ def _from_string( res = re.split("[:=]", p) res[0] = f"{res[0]};O" opt_args.append(res) - _opt_args = Args.from_string_list(opt_args, custom_types.copy()) - opt_action_value = re.findall(r"&(.+?)(?:#.+?)?$", opt_others) - if not (opt_help_string := re.findall(r"#(.+)$", opt_others)): + _typs = custom_types.copy() + _opt_args = Args.from_string_list(opt_args, _typs) + opt_action_value = re.findall(r"&(.+?)(?: #.+?)?$", opt_others) + if not (opt_help_string := re.findall(r"(?: )#(.+)$", opt_others)): # noqa opt_help_string = [opt_head] + _options.append(Option(opt_head, args=_opt_args)) if opt_action_value: - val = eval(opt_action_value[0].rstrip(), {"true": True, "false": False}) - _options.append(Option(opt_head, args=_opt_args, action=store_value(val))) - else: - _options.append(Option(opt_head, args=_opt_args)) + _options[-1].action = store_value(eval(opt_action_value[0].rstrip(), {"true": True, "false": False})) _options[-1].help_text = opt_help_string[0] return Alconna(headers=headers, main_args=_args, options=_options, help_text=help_string[0], is_fuzzy_match=True) @@ -393,16 +367,9 @@ def visit_config(obj: Any, config_keys: Iterable[str]): _contents = re.split(r"\s*=\s*", line.strip()) if len(_contents) == 2 and _contents[0] in config_keys: result[_contents[0]] = eval(_contents[1]) - elif config := inspect.getmembers( - obj, - predicate=lambda x: inspect.isclass(x) - and x.__name__.endswith("Config"), - ): + elif config := inspect.getmembers(obj, lambda x: inspect.isclass(x) and x.__name__.endswith("Config")): config = config[0][1] - configs = list(filter(lambda x: not x.startswith("_"), dir(config))) - for key in config_keys: - if key in configs: - result[key] = getattr(config, key) + result = {k: getattr(config, k) for k in config_keys if k in dir(config)} return result @@ -419,16 +386,19 @@ def _instance_action(self, option_dict, varargs, kwargs): setattr(self.instance, key, value) return option_dict - def _inject_instance(self, target: Callable): - return partial(target, self.instance) - def _get_instance(self): return self.instance + def _inject_instance(self, target: Callable): + @wraps(target) + def __wrapper(*args, **kwargs): + return target(self._get_instance(), *args, **kwargs) + return __wrapper + def _parse_action(self, message): ... - def parse(self, message: Union[str, DataCollection], duplication: Optional[Any] = None, + def parse(self, message: DataCollection[Union[str, Any]], duplication: Optional[Any] = None, static: bool = True): # noqa message = self._parse_action(message) or message return super(AlconnaMounter, self).parse(message, duplication=duplication, static=static) @@ -440,7 +410,7 @@ def __init__(self, func: Union[FunctionType, MethodType], config: Optional[dict] config = config or visit_config(func, self.config_keys) func_name = func.__name__ if func_name.startswith("_"): - raise ValueError(lang_config.construct_function_name_error) + raise ValueError(global_config.lang.construct_function_name_error) _args, method = Args.from_callable(func, extra=config.get("extra", "ignore")) if method and isinstance(func, MethodType): self.instance = func.__self__ @@ -459,19 +429,15 @@ def __init__(self, func: Union[FunctionType, MethodType], config: Optional[dict] def visit_subcommand(obj: Any): result = [] subcommands: List[Tuple[str, Type]] = inspect.getmembers( - obj, predicate=lambda x: inspect.isclass(x) and not x.__name__.endswith("Config") + obj, lambda x: inspect.isclass(x) and not x.__name__.endswith("Config") ) class _MountSubcommand(Subcommand): sub_instance: object - for cls_name, subcommand_cls in subcommands: - if cls_name.startswith("_"): - continue + for cls_name, subcommand_cls in filter(lambda x: not x[0].startswith("_"), subcommands): init = inspect.getfullargspec(subcommand_cls.__init__) - members = inspect.getmembers( - subcommand_cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) - ) + members = inspect.getmembers(subcommand_cls, lambda x: inspect.isfunction(x) or inspect.ismethod(x)) config = visit_config(subcommand_cls, ["command", "description"]) _options = [] sub_help_text = subcommand_cls.__doc__ or subcommand_cls.__init__.__doc__ or cls_name @@ -479,60 +445,42 @@ class _MountSubcommand(Subcommand): if len(init.args + init.kwonlyargs) > 1: sub_args = Args.from_callable(subcommand_cls.__init__, extra='ignore')[0] sub = _MountSubcommand( - config.get("command", cls_name), - help_text=config.get("description", sub_help_text), - args=sub_args + config.get("command", cls_name), help_text=config.get("description", sub_help_text), args=sub_args ) sub.sub_instance = subcommand_cls - def _instance_action(option_dict, varargs, kwargs): - if not sub.sub_instance: - sub.sub_instance = subcommand_cls(*option_dict.values(), *varargs, **kwargs) - else: - for key, value in option_dict.items(): - setattr(sub.sub_instance, key, value) - return option_dict - class _InstanceAction(ArgAction): - def handle(self, option_dict, varargs=None, kwargs=None, is_raise_exception=False): - return _instance_action(option_dict, varargs, kwargs) - - class _TargetAction(ArgAction): - origin: Callable - - def __init__(self, target: Callable): - self.origin = target - super().__init__(target) def handle(self, option_dict, varargs=None, kwargs=None, is_raise_exception=False): - self.action = partial(self.origin, sub.sub_instance) - return super().handle(option_dict, varargs, kwargs, is_raise_exception) + if not sub.sub_instance: + sub.sub_instance = subcommand_cls(*option_dict.values(), *varargs, **kwargs) + else: + for key, value in option_dict.items(): + setattr(sub.sub_instance, key, value) + return option_dict - for name, func in members: - if name.startswith("_"): - continue - help_text = func.__doc__ or name - _opt_args, method = Args.from_callable(func, extra='ignore') - if method: - _options.append(Option(name, args=_opt_args, action=_TargetAction(func), help_text=help_text)) - else: - _options.append(Option(name, args=_opt_args, action=ArgAction(func), help_text=help_text)) - sub.options = _options sub.action = _InstanceAction(lambda: None) - result.append(sub) else: sub = _MountSubcommand(config.get("command", cls_name), help_text=config.get("description", sub_help_text)) sub.sub_instance = subcommand_cls() - for name, func in members: - if name.startswith("_"): - continue - help_text = func.__doc__ or name - _opt_args, method = Args.from_callable(func, extra='ignore') - if method: - func = partial(func, sub.sub_instance) - _options.append(Option(name, args=_opt_args, action=ArgAction(func), help_text=help_text)) - sub.options = _options - result.append(sub) + + def _get_sub_instance(_sub): + return _sub.sub_instance + + def _inject_sub_instance(target: Callable): + @wraps(target) + def __wrapper(*args, **kwargs): + return target(_get_sub_instance, *args, **kwargs) + return __wrapper + + for name, func in filter(lambda x: not x[0].startswith("_"), members): + help_text = func.__doc__ or name + _opt_args, method = Args.from_callable(func, extra='ignore') + if method: + func = _inject_sub_instance(func) + _options.append(Option(name, _opt_args, action=ArgAction(func), help_text=help_text)) + sub.options = _options + result.append(sub) return result @@ -543,9 +491,7 @@ def __init__(self, mount_cls: Type, config: Optional[dict] = None): self.instance: mount_cls = None config = config or visit_config(mount_cls, self.config_keys) init = inspect.getfullargspec(mount_cls.__init__) - members = inspect.getmembers( - mount_cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) - ) + members = inspect.getmembers(mount_cls, lambda x: inspect.isfunction(x) or inspect.ismethod(x)) _options = [] if config.get('get_subcommand', False): subcommands = visit_subcommand(mount_cls) @@ -561,53 +507,32 @@ class _InstanceAction(ArgAction): def handle(self, option_dict, varargs=None, kwargs=None, is_raise_exception=False): return instance_handle(option_dict, varargs, kwargs) - inject = self._inject_instance - - class _TargetAction(ArgAction): - origin: Callable - - def __init__(self, target: Callable): - self.origin = target - super().__init__(target) - - def handle(self, option_dict, varargs=None, kwargs=None, is_raise_exception=False): - self.action = inject(self.origin) - return super().handle(option_dict, varargs, kwargs, is_raise_exception) - main_action = _InstanceAction(lambda: None) - for name, func in members: - if name.startswith("_"): - continue + for name, func in filter(lambda x: not x[0].startswith("_"), members): help_text = func.__doc__ or name _opt_args, method = Args.from_callable(func, extra=config.get("extra", "ignore")) if method: - _options.append(Option(name, args=_opt_args, action=_TargetAction(func), help_text=help_text)) - else: - _options.append(Option(name, args=_opt_args, action=ArgAction(func), help_text=help_text)) + func = self._inject_instance(func) + _options.append(Option(name, _opt_args, action=ArgAction(func), help_text=help_text)) super().__init__( - headers=config.get('headers', None), + config.get('command', mount_cls.__name__), main_args, config.get('headers', None), _options, namespace=config.get('namespace', None), - command=config.get('command', mount_cls.__name__), - main_args=main_args, - options=_options, help_text=config.get('description', main_help_text), is_raise_exception=config.get('raise_exception', True), action=main_action, ) else: self.instance = mount_cls() - for name, func in members: - if name.startswith("_"): - continue + for name, func in filter(lambda x: not x[0].startswith("_"), members): help_text = func.__doc__ or name _opt_args, method = Args.from_callable(func, extra=config.get("extra", "ignore")) if method: - func = partial(func, self.instance) + func = self._inject_instance(func) _options.append(Option(name, args=_opt_args, action=ArgAction(func), help_text=help_text)) super().__init__( + config.get('command', mount_cls.__name__), headers=config.get('headers', None), namespace=config.get('namespace', None), - command=config.get('command', mount_cls.__name__), options=_options, help_text=config.get('description', main_help_text), is_raise_exception=config.get('raise_exception', True), @@ -627,9 +552,7 @@ def __init__(self, module: ModuleType, config: Optional[dict] = None): self.instance = module config = config or visit_config(module, self.config_keys) _options = [] - members = inspect.getmembers( - module, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) - ) + members = inspect.getmembers(module, lambda x: inspect.isfunction(x) or inspect.ismethod(x)) for name, func in members: if name.startswith("_") or func.__name__.startswith("_"): continue @@ -639,9 +562,9 @@ def __init__(self, module: ModuleType, config: Optional[dict] = None): func = partial(func, func.__self__) _options.append(Option(name, args=_opt_args, action=ArgAction(func), help_text=help_text)) super().__init__( + config.get('command', module.__name__), headers=config.get('headers', None), namespace=config.get('namespace', None), - command=config.get('command', module.__name__), options=_options, help_text=config.get("description", module.__doc__ or module.__name__), is_raise_exception=config.get("raise_exception", True) @@ -664,17 +587,13 @@ def __init__(self, obj: object, config: Optional[dict] = None): config = config or visit_config(obj, self.config_keys) obj_name = obj.__class__.__name__ init = inspect.getfullargspec(obj.__init__) - members = inspect.getmembers( - obj, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) - ) + members = inspect.getmembers(obj, lambda x: inspect.isfunction(x) or inspect.ismethod(x)) _options = [] if config.get('get_subcommand', False): subcommands = visit_subcommand(obj) _options.extend(subcommands) main_help_text = obj.__doc__ or obj.__init__.__doc__ or obj_name - for name, func in members: - if name.startswith("_"): - continue + for name, func in filter(lambda x: not x[0].startswith("_"), members): help_text = func.__doc__ or name _opt_args, _ = Args.from_callable(func, extra=config.get("extra", "ignore")) _options.append(Option(name, args=_opt_args, action=ArgAction(func), help_text=help_text)) @@ -693,10 +612,7 @@ def handle(self, option_dict, varargs=None, kwargs=None, is_raise_exception=Fals main_action = _InstanceAction(lambda: None) super().__init__( - headers=config.get('headers', None), - command=config.get('command', obj_name), - main_args=main_args, - options=_options, + config.get('command', obj_name), main_args, config.get('headers', None), _options, help_text=config.get("description", main_help_text), is_raise_exception=config.get("raise_exception", True), action=main_action, @@ -704,8 +620,8 @@ def handle(self, option_dict, varargs=None, kwargs=None, is_raise_exception=Fals ) else: super().__init__( + config.get('command', obj_name), headers=config.get('headers', None), - command=config.get('command', obj_name), options=_options, namespace=config.get('namespace', None), help_text=config.get("description", main_help_text), @@ -752,9 +668,7 @@ def delegate(cls: Type) -> Alconna: _main_args = None _options = [] _headers = [] - for name, attr in attrs: - if name.startswith("_"): - continue + for name, attr in filter(lambda x: not x[0].startswith("_"), attrs): if isinstance(attr, Args): _main_args = attr elif isinstance(attr, (Option, Subcommand)): @@ -775,10 +689,8 @@ def _argument( action: Optional[Union[ArgAction, Callable]] = None, ): """类似于 argparse.ArgumentParser.add_argument() 的方法""" - opt = Option(name=name, alias=list(alias), dest=dest, help_text=description, action=action) - opt.args.add_argument( - name=name.strip('-'), value=value, default=default, flags=[] if required else [ArgFlag.OPTIONAL] - ) + opt = Option(name, alias=list(alias), dest=dest, help_text=description, action=action) + opt.args.add_argument(name.strip('-'), value=value, default=default, flags=[] if required else [ArgFlag.OPTIONAL]) opt.nargs += 1 return opt diff --git a/src/arclet/alconna/builtin/formatter.py b/src/arclet/alconna/builtin/formatter.py index 247feb80..a913b2f6 100644 --- a/src/arclet/alconna/builtin/formatter.py +++ b/src/arclet/alconna/builtin/formatter.py @@ -19,13 +19,10 @@ def format(self, trace: Trace) -> str: def param(self, name: str, parameter: ArgUnit) -> str: arg = ("[" if parameter['optional'] else "<") + name if not parameter['hidden']: - _sep = "@" if parameter['kwonly'] else ":" if parameter['value'] is AllParam: return f"<...{name}>" - if isinstance(parameter['value'], BasePattern) and parameter['value'].pattern == name: - pass - else: - arg += f"{_sep}{parameter['value']}" + if not isinstance(parameter['value'], BasePattern) or parameter['value'].pattern != name: + arg += f"{'@' if parameter['kwonly'] else ':'}{parameter['value']}" if parameter['default'] is Empty: arg += " = None" elif parameter['default'] is not None: @@ -35,49 +32,46 @@ def param(self, name: str, parameter: ArgUnit) -> str: def parameters(self, args: Args) -> str: param_texts = [self.param(k, param) for k, param in args.argument.items()] if len(args.separators) == 1: - separator = args.separators.copy().pop() - return separator.join(param_texts) - return " ".join(param_texts) + " splitBy:" + "/".join(args.separators) + separator = tuple(args.separators)[0] + res = separator.join(param_texts) + else: + res = " ".join(param_texts) + " splitBy:" + "/".join(args.separators) + notice = [(k, param['notice']) for k, param in args.argument.items() if param['notice']] + if not notice: + return res + return res + "\n## 注释\n " + "\n ".join(f"{v[0]}: {v[1]}" for v in notice) def header(self, root: Dict[str, Any], separators: Set[str]) -> str: help_string = ("\n" + root['description']) if root.get('description') else "" - if usage := re.findall(r".*Usage:(.+?);", help_string, flags=re.S): - help_string = help_string.replace(f"Usage:{usage[0]};", "") - usage = '\n用法:\n' + usage[0] - else: - usage = "" - if example := re.findall(r".*Example:(.+?);", help_string, flags=re.S): - help_string = help_string.replace(f"Example:{example[0]};", "") - example = '\n使用示例:\n' + example[0] - else: - example = "" - headers = root.get('headers', ['']) - command = root.get('name', '') - headers = f"[{''.join(map(str, headers))}]" if headers != [''] else "" - cmd = f"{headers}{command}" - sep = separators.copy().pop() - command_string = cmd or (root['name'] + sep) + usage = "" + if res := re.findall(r".*Usage:(.+?);", help_string, flags=re.S): + help_string = help_string.replace(f"Usage:{res[0]};", "") + usage = '\n用法:\n' + res[0] + example = "" + if res := re.findall(r".*Example:(.+?);", help_string, flags=re.S): + help_string = help_string.replace(f"Example:{res[0]};", "") + example = '\n使用示例:\n' + res[0] + headers = f"[{''.join(map(str, headers))}]" if (headers := root.get('header', [''])) != [''] else "" + cmd = f"{headers}{root.get('name', '')}" + command_string = cmd or (root['name'] + tuple(separators)[0]) return f"{command_string} %s{help_string}{usage}\n%s{example}" def part(self, node: Union[Subcommand, Option]) -> str: if isinstance(node, Subcommand): - sep = node.separators.copy().pop() name = " ".join(node.requires) + (' ' if node.requires else '') + node.name option_string = "".join([self.part(i).replace("\n", "\n ") for i in node.options]) option_help = "## 该子命令内可用的选项有:\n " if option_string else "" return ( f"# {node.help_text}\n" - f" {name}{sep}" + f" {name}{tuple(node.separators)[0]}" f"{self.parameters(node.args)}\n" f"{option_help}{option_string}" ) elif isinstance(node, Option): - sep = node.separators.copy().pop() - alias_text = ", ".join(node.aliases) - alias_text = " ".join(node.requires) + (' ' if node.requires else '') + alias_text + alias_text = " ".join(node.requires) + (' ' if node.requires else '') + ", ".join(node.aliases) return ( f"# {node.help_text}\n" - f" {alias_text}{sep}" + f" {alias_text}{tuple(node.separators)[0]}" f"{self.parameters(node.args)}\n" ) else: @@ -120,13 +114,10 @@ def param(self, name: str, parameter: ArgUnit) -> str: # FOO[str], BAR= arg = ("[" if parameter['optional'] else "") + name.upper() if not parameter['hidden']: - _sep = "=[%s]" if parameter['kwonly'] else "[%s]" if parameter['value'] is AllParam: return f"{name.upper()}..." - if isinstance(parameter['value'], BasePattern) and parameter['value'].pattern == name: - pass - else: - arg += _sep % f"{parameter['value']}" + if not isinstance(parameter['value'], BasePattern) or parameter['value'].pattern != name: + arg += f"=[{parameter['value']}]" if parameter['kwonly'] else f"[{parameter['value']}]" if parameter['default'] is Empty: arg += "=None" elif parameter['default'] is not None: @@ -136,27 +127,28 @@ def param(self, name: str, parameter: ArgUnit) -> str: def parameters(self, args: Args) -> str: param_texts = [self.param(k, param) for k, param in args.argument.items()] if len(args.separators) == 1: - separator = args.separators.copy().pop() - return separator.join(param_texts) - return " ".join(param_texts) + ", USED SPLIT:" + "/".join(args.separators) + separator = tuple(args.separators)[0] + res = separator.join(param_texts) + else: + res = " ".join(param_texts) + ", USED SPLIT:" + "/".join(args.separators) + notice = [(k, param['notice']) for k, param in args.argument.items() if param['notice']] + if not notice: + return res + return res + "\n 内容:\n " + "\n ".join(f"{v[0]}: {v[1]}" for v in notice) def header(self, root: Dict[str, Any], separators: Set[str]) -> str: help_string = ("\n描述: " + root['description'] + "\n") if root.get('description') else "" - if usage := re.findall(r".*Usage:(.+?);", help_string, flags=re.S): - help_string = help_string.replace(f"Usage:{usage[0]};", "") - usage = '\n用法:' + usage[0] + '\n' - else: - usage = "" - if example := re.findall(r".*Example:(.+?);", help_string, flags=re.S): - help_string = help_string.replace(f"Example:{example[0]};", "") - example = '\n样例:' + example[0] + '\n' - else: - example = "" - headers = root.get('headers', ['']) - command = root.get('name', '') - header_text = f"/{''.join(map(str, headers))}/" if headers != [''] else "" - cmd = f"{header_text}{command}" - sep = separators.copy().pop() + usage = "" + if res := re.findall(r".*Usage:(.+?);", help_string, flags=re.S): + help_string = help_string.replace(f"Usage:{res[0]};", "") + usage = '\n用法:' + res[0] + '\n' + example = "" + if res := re.findall(r".*Example:(.+?);", help_string, flags=re.S): + help_string = help_string.replace(f"Example:{res[0]};", "") + example = '\n样例:' + res[0] + '\n' + header_text = f"/{''.join(map(str, headers))}/" if (headers := root.get('header', [''])) != [''] else "" + cmd = f"{header_text}{root.get('name', '')}" + sep = tuple(separators)[0] command_string = cmd or (root['name'] + sep) return f"\n命令: {command_string}{help_string}{usage}%s\n%s{example}" @@ -168,17 +160,12 @@ def body(self, parts: List[Union[Option, Subcommand]]) -> str: options = [] opt_description = [] for opt in filter(lambda x: isinstance(x, Option) and x.name != "--shortcut", parts): - alias_text = ", ".join(opt.aliases) - alias_text = " ".join(opt.requires) + (' ' if opt.requires else '') + alias_text - args = self.parameters(opt.args) - sep = opt.separators.copy().pop() - options.append(f" {alias_text}{sep}{args}") + alias_text = " ".join(opt.requires) + (' ' if opt.requires else '') + ", ".join(opt.aliases) + options.append(f" {alias_text}{tuple(opt.separators)[0]}{self.parameters(opt.args)}") opt_description.append(opt.help_text) if options: max_len = max(map(lambda x: len(x), options)) - option_string = "\n".join( - [f"{i.ljust(max_len)} {j}" for i, j in zip(options, opt_description)] - ) + option_string = "\n".join(f"{i.ljust(max_len)} {j}" for i, j in zip(options, opt_description)) subcommand_string = "" subcommands = [] sub_description = [] @@ -186,17 +173,11 @@ def body(self, parts: List[Union[Option, Subcommand]]) -> str: name = " ".join(sub.requires) + (' ' if sub.requires else '') + sub.name sub_topic = " ".join(f"[{i.name}]" for i in sub.options) # type: ignore args = self.parameters(sub.args) - sep = sub.separators.copy().pop() - subcommands.append(f" {name} {sep.join([args, sub_topic])}") + subcommands.append(f" {name} {tuple(sub.separators)[0].join([args, sub_topic])}") sub_description.append(sub.help_text) if subcommands: max_len = max(map(lambda x: len(x), subcommands)) - subcommand_string = "\n".join( - [f"{i.ljust(max_len)} {j}" for i, j in zip(subcommands, sub_description)] - ) + subcommand_string = "\n".join(f"{i.ljust(max_len)} {j}" for i, j in zip(subcommands, sub_description)) option_help = "选项:\n" if option_string else "" subcommand_help = "子命令:\n" if subcommand_string else "" - return ( - f"{subcommand_help}{subcommand_string}\n" - f"{option_help}{option_string}\n" - ) + return f"{subcommand_help}{subcommand_string}\n{option_help}{option_string}\n" diff --git a/src/arclet/alconna/builtin/pattern.py b/src/arclet/alconna/builtin/pattern.py index d1dcea88..7937672e 100644 --- a/src/arclet/alconna/builtin/pattern.py +++ b/src/arclet/alconna/builtin/pattern.py @@ -1,149 +1,76 @@ import inspect -from types import LambdaType -from typing import Type, Tuple, Callable, Literal, TypeVar, Dict, Any, get_args - -from arclet.alconna import lang_config, ParamsUnmatched +import re +from typing import Type, Tuple, Callable, Literal, TypeVar, Any, Union +from arclet.alconna import config, ParamsUnmatched, Args +from arclet.alconna.analysis.base import analyse_args from arclet.alconna.typing import BasePattern, Empty, pattern_map, PatternModel, set_converter TOrigin = TypeVar("TOrigin") class ObjectPattern(BasePattern): - def __init__( self, origin: Type[TOrigin], limit: Tuple[str, ...] = (), - head: str = "", - flag: Literal["http", "part", "json"] = "part", + flag: Literal["urlget", "part", "json"] = "part", **suppliers: Callable ): - """ - 将传入的对象类型转换为接收序列号参数解析后实例化的对象 - - Args: - origin: 原始对象 - limit: 指定该对象初始化时需要的参数 - head: 是否需要匹配一个头部 - flag: 匹配类型 - suppliers: 对象属性的匹配方法 - """ - self.origin = origin - self._require_map: Dict[str, Callable] = {} - self._supplement_map: Dict[str, Callable] = {} - self._transform_map: Dict[str, Callable] = {} - self._params: Dict[str, Any] = {} - _re_pattern = "" - _re_patterns = [] - sig = inspect.signature(origin.__init__) - for param in sig.parameters.values(): + self._args = Args() + self._names = [] + for param in inspect.signature(origin.__init__).parameters.values(): name = param.name + anno = param.annotation + default = param.default if name in ("self", "cls"): continue if limit and name not in limit: continue if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): continue - self._params[name] = None - if name in suppliers: + if anno is Empty: + anno = pattern_map[str] + elif inspect.isclass(anno) and issubclass(anno, str): + anno = pattern_map[str] + elif inspect.isclass(anno) and issubclass(anno, int): + anno = pattern_map[int] + if name in suppliers and inspect.isclass(anno): _s_sig = inspect.signature(suppliers[name]) - if _s_sig.return_annotation in get_args(param.annotation): - if len(_s_sig.parameters) == 0 or ( - len(_s_sig.parameters) == 1 and inspect.ismethod(suppliers[name]) - ): - self._supplement_map[name] = suppliers[name] - elif len(_s_sig.parameters) == 1 or ( - len(_s_sig.parameters) == 2 and inspect.ismethod(suppliers[name]) - ): - self._require_map[name] = suppliers[name] - if flag == "http": - _re_patterns.append(f"{name}=(?P<{name}>.+?)") # & - elif flag == "json": - _re_patterns.append(f"\\'{name}\\':\\'(?P<{name}>.+?)\\'") # , - elif flag == "part": - _re_patterns.append(f"(?P<{name}>.+?)") # ; - else: - raise TypeError( - lang_config.types_supplier_params_error.format(target=name, origin=origin.__name__) - ) - elif isinstance(suppliers[name], LambdaType): - if len(_s_sig.parameters) == 0: - self._supplement_map[name] = suppliers[name] - elif len(_s_sig.parameters) == 1: - self._require_map[name] = suppliers[name] - if flag == "http": - _re_patterns.append(f"{name}=(?P<{name}>.+?)") # & - elif flag == "json": - _re_patterns.append(f"\\'{name}\\':\\'(?P<{name}>.+?)\\'") # , - elif flag == "part": - _re_patterns.append(f"(?P<{name}>.+?)") # ; - else: - raise TypeError( - lang_config.types_supplier_params_error.format(target=name, origin=origin.__name__) - ) + if len(_s_sig.parameters) == 1 or (len(_s_sig.parameters) == 2 and inspect.ismethod(suppliers[name])): + anno = BasePattern( + model=PatternModel.TYPE_CONVERT, origin=anno, converter=lambda x: suppliers[name](x) + ) + elif len(_s_sig.parameters) == 0 or (len(_s_sig.parameters) == 1 and inspect.ismethod(suppliers[name])): + default = suppliers[name]() else: - raise TypeError(lang_config.types_supplier_return_error.format( - target=name, origin=origin.__name__, source=param.annotation - )) - elif param.default not in (Empty, None, Ellipsis): - self._params[name] = param.default - else: - if not (args := get_args(param.annotation)): - args = (param.annotation,) - for anno in args: - pat: BasePattern = pattern_map.get(anno, None) - if pat is not None: - break - else: - pat = param.annotation - if param.annotation is Empty: - pat = pattern_map[str] - elif inspect.isclass(param.annotation) and issubclass(param.annotation, str): - pat = pattern_map[str] - elif inspect.isclass(param.annotation) and issubclass(param.annotation, int): - pat = pattern_map[int] - if pat is None: - raise TypeError(lang_config.types_supplier_missing.format(target=name, origin=origin.__name__)) - - if isinstance(pat, ObjectPattern): - raise TypeError(lang_config.types_type_error.format(target=pat)) - self._require_map[name] = pat.match - if pat.model == PatternModel.REGEX_CONVERT: - self._transform_map[name] = pat.converter - if flag == "http": - _re_patterns.append(f"{name}=(?P<{name}>{pat.pattern.strip('()')})") # & - elif flag == "part": - _re_patterns.append(f"(?P<{name}>{pat.pattern.strip('()')})") # ; - elif flag == "json": - _re_patterns.append(f"\\'{name}\\':\\'(?P<{name}>{pat.pattern.strip('()')})\\'") # , - if _re_patterns: - if flag == "http": - _re_pattern = (rf"{head}\?" if head else "") + "&".join(_re_patterns) - elif flag == "json": - _re_pattern = (f"{head}:" if head else "") + "{" + ",".join(_re_patterns) + "}" - elif flag == "part": - _re_pattern = (f"{head};" if head else "") + ";".join(_re_patterns) + raise TypeError( + config.lang.types_supplier_params_error.format(target=name, origin=origin.__name__) + ) + self._names.append(name) + self._args.add_argument(name, value=anno, default=default) + self.flag = flag + if flag == 'part': + self._re_pattern = re.compile(";".join(f"(?P<{i}>.+?)" for i in self._names)) + elif flag == 'urlget': + self._re_pattern = re.compile("&".join(f"{i}=(?P<{i}>.+?)" for i in self._names)) + elif flag == 'json': + self._re_pattern = re.compile(r"\{" + ",".join(f"\\'{i}\\':\\'(?P<{i}>.+?)\\'" for i in self._names) + "}") else: - _re_pattern = f"{head}" if head else f"{self.origin.__name__}" - - super().__init__( - _re_pattern, - model=PatternModel.REGEX_MATCH, origin_type=self.origin, alias=head or self.origin.__name__, - ) + raise TypeError(config.lang.types_type_error.format(target=flag)) + super().__init__(model=PatternModel.TYPE_CONVERT, origin=origin, alias=origin.__name__) set_converter(self) - def match(self, text: str): - if matched := self.regex_pattern.fullmatch(text): - args = matched.groupdict() - for k in self._require_map: - if k in args: - self._params[k] = self._require_map[k](args[k]) - if self._transform_map.get(k, None): - self._params[k] = self._transform_map[k](self._params[k]) - for k in self._supplement_map: - self._params[k] = self._supplement_map[k]() - return self.origin(**self._params) - raise ParamsUnmatched(lang_config.args_error.format(target=text)) + def match(self, input_: Union[str, Any]) -> TOrigin: + if isinstance(input_, self.origin): + return input_ # type: ignore + elif not isinstance(input_, str): + raise ParamsUnmatched(config.lang.args_type_error.format(target=input_.__class__)) + if not (mat := self._re_pattern.fullmatch(input_)): + raise ParamsUnmatched(config.lang.args_error.format(target=input_)) + res = analyse_args(self._args, list(mat.groupdict().values()), raise_exception=False) + if not res: + raise ParamsUnmatched(config.lang.args_error.format(target=input_)) + return self.origin(**res) def __call__(self, *args, **kwargs): return self.origin(*args, **kwargs) diff --git a/src/arclet/alconna/components/action.py b/src/arclet/alconna/components/action.py index 3f09f298..f2299f1d 100644 --- a/src/arclet/alconna/components/action.py +++ b/src/arclet/alconna/components/action.py @@ -2,10 +2,10 @@ from types import LambdaType from typing import Optional, Dict, List, Callable, Any, Sequence, TYPE_CHECKING, Union -from ..typing import AnyOne, AllParam, argument_type_validator -from ..lang import lang_config +from ..typing import AnyOne, AllParam, args_type_parser +from ..config import config from ..exceptions import InvalidParam -from ..manager import command_manager +from ..util import is_async from .behavior import ArpamarBehavior if TYPE_CHECKING: @@ -23,12 +23,7 @@ class ArgAction: action: Callable[..., Any] def __init__(self, action: Callable): - """ - ArgAction的构造函数 - - Args: - action: (...) -> Sequence - """ + """ArgAction的构造函数""" self.action = action def handle( @@ -50,8 +45,8 @@ def handle( varargs = varargs or [] kwargs = kwargs or {} try: - if inspect.iscoroutinefunction(self.action): - loop = command_manager.loop + if is_async(self.action): + loop = config.loop if loop.is_running(): loop.create_task(self.action(*option_dict.values(), *varargs, **kwargs)) return option_dict @@ -88,7 +83,7 @@ def __validator__(action: Union[Callable, "ArgAction", None], args: "Args"): if name not in ["self", "cls", "option_dict", "exception_in_time"] ] if len(argument) != len(args.argument): - raise InvalidParam(lang_config.action_length_error) + raise InvalidParam(config.lang.action_length_error) if not isinstance(action, LambdaType): for i, k in enumerate(args.argument): anno = argument[i][1] @@ -97,9 +92,9 @@ def __validator__(action: Union[Callable, "ArgAction", None], args: "Args"): value = args.argument[k]['value'] if value in (AnyOne, AllParam): continue - if value != argument_type_validator(anno, args.extra): - raise InvalidParam(lang_config.action_args_error.format( - target=argument[i][0], key=k, source=value.origin_type # type: ignore + if value != args_type_parser(anno, args.extra): + raise InvalidParam(config.lang.action_args_error.format( + target=argument[i][0], key=k, source=value.origin # type: ignore )) return ArgAction(action) diff --git a/src/arclet/alconna/components/duplication.py b/src/arclet/alconna/components/duplication.py index f86346ea..aea27e34 100644 --- a/src/arclet/alconna/components/duplication.py +++ b/src/arclet/alconna/components/duplication.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, cast, Optional from inspect import isclass -from ..lang import lang_config +from ..config import config from .stub import BaseStub, ArgsStub, SubcommandStub, OptionStub, Subcommand, Option if TYPE_CHECKING: @@ -53,7 +53,7 @@ def __init__(self, alconna: 'Alconna'): setattr(self, key, OptionStub(option)) # type: ignore self.__stubs__["options"].append(key) else: - raise TypeError(lang_config.duplication_stub_type_error.format(target=value)) + raise TypeError(config.lang.duplication_stub_type_error.format(target=value)) def __repr__(self): return f'<{self.__class__.__name__} with {self.__stubs__}>' @@ -72,15 +72,11 @@ def generate_duplication(command: "Alconna") -> AlconnaDuplication: options = filter(lambda x: isinstance(x, Option), command.options) subcommands = filter(lambda x: isinstance(x, Subcommand), command.options) return cast(AlconnaDuplication, type( - command.name.replace("ALCONNA::", "") + 'Interface', - (AlconnaDuplication,), - { + command.name.strip("/\\.-:") + 'Interface', + (AlconnaDuplication,), { "__annotations__": { **{"args": ArgsStub}, - **{ - opt.dest: OptionStub for opt in options - if opt.name.lstrip('-') not in ("help", "shortcut") - }, + **{opt.dest: OptionStub for opt in options if opt.name.lstrip('-') not in ("help", "shortcut")}, **{sub.dest: SubcommandStub for sub in subcommands}, } } diff --git a/src/arclet/alconna/components/output.py b/src/arclet/alconna/components/output.py index d4b32e10..9df79e49 100644 --- a/src/arclet/alconna/components/output.py +++ b/src/arclet/alconna/components/output.py @@ -8,10 +8,22 @@ from ..base import Option, Subcommand, Args, ArgUnit +class OutputAction(ArgAction): + output_text_call: Callable[[], str] + + def __init__(self, send_action, out_call, command=None): + super().__init__(send_action) + self.output_text_call = out_call + self.command = command + + def handle(self, option_dict=None, varargs=None, kwargs=None, is_raise_exception=False): + return super().handle({"help": self.output_text_call()}, varargs, kwargs, is_raise_exception) + + class OutputActionManager(metaclass=Singleton): """帮助信息""" cache: Dict[str, Callable] - outputs: Dict[str, "OutputAction"] + outputs: Dict[str, OutputAction] send_action: Callable[[str], Union[Any, Coroutine]] def __init__(self): @@ -26,11 +38,19 @@ def _clr(mgr: 'OutputActionManager'): finalize(self, _clr, self) - def set_send_action( - self, - action: Callable[[str], Any], - command: Optional[str] = None - ): + def get(self, command: str, output_call: Callable[[], str]) -> OutputAction: + """获取发送帮助信息的 action""" + if command not in self.outputs: + self.outputs[command] = OutputAction(self.send_action, output_call, command) + else: + self.outputs[command].output_text_call = output_call + + if command in self.cache: + self.outputs[command].action = self.cache[command] + del self.cache[command] + return self.outputs[command] + + def set_action(self, action: Callable[[str], Any], command: Optional[str] = None): """修改help_send_action""" if command is None: self.send_action = action @@ -43,33 +63,39 @@ def set_send_action( output_manager = OutputActionManager() -class OutputAction(ArgAction): - output_text_call: Callable[[], str] - - def __init__(self, out_call, command=None): - super().__init__(output_manager.send_action) - self.output_text_call = out_call - self.command = command +if TYPE_CHECKING: + from ..core import Alconna, AlconnaGroup - def handle(self, option_dict, varargs=None, kwargs=None, is_raise_exception=False): - return super().handle({"help": self.output_text_call()}, varargs, kwargs, is_raise_exception) +def resolve_requires(options: List[Union[Option, Subcommand]]): + reqs: Dict[str, Union[dict, Union[Option, Subcommand]]] = {} -def output_send(command: str, output_call: Callable[[], str]) -> OutputAction: - """帮助信息的发送 action""" - if command not in output_manager.outputs: - output_manager.outputs[command] = OutputAction(output_call, command) - else: - output_manager.outputs[command].output_text_call = output_call + def _u(target, source): + for k in source: + if k not in target or isinstance(target[k], (Option, Subcommand)): + target.update(source) + break + _u(target[k], source[k]) - if command in output_manager.cache: - output_manager.outputs[command].action = output_manager.cache[command] - del output_manager.cache[command] - return output_manager.outputs[command] + for opt in options: + if not opt.requires: + reqs.setdefault(opt.name, opt) + [reqs.setdefault(i, opt) for i in opt.aliases] if isinstance(opt, Option) else None + else: + _reqs = _cache = {} + for req in opt.requires: + if not _reqs: + _reqs[req] = {} + _cache = _reqs[req] + else: + _cache[req] = {} + _cache = _cache[req] + _cache[opt.name] = opt # type: ignore + [_cache.setdefault(i, opt) for i in opt.aliases] if isinstance(opt, Option) else None # type: ignore + _u(reqs, _reqs) -if TYPE_CHECKING: - from ..core import Alconna, AlconnaGroup + return reqs @dataclass @@ -90,12 +116,16 @@ class AbstractTextFormatter(metaclass=ABCMeta): 该格式化器负责将传入的命令节点字典解析并生成帮助文档字符串 """ + def __init__(self, base: Union['Alconna', 'AlconnaGroup']): self.data = [] def _handle(command: 'Alconna'): + hds = command.headers.copy() + if command.name in hds: + hds.remove(command.name) # type: ignore return Trace( - {'name': command.name, 'header': command.headers, 'description': command.help_text}, + {'name': command.name, 'header': hds or [''], 'description': command.help_text}, command.args, command.separators, command.options ) @@ -105,57 +135,67 @@ def _handle(command: 'Alconna'): else: self.data.append(_handle(cmd)) # type: ignore - def format_node(self, end: Optional[List[str]] = None): + def format_node(self, end: Optional[list] = None): """ 格式化命令节点 """ - end = end # TODO: 依据end确定起始位置 - res = '' - for trace in self.data: - res += self.format(trace) + '\n' - return res + + def _handle(trace: Trace): + if not end or end == ['']: + return self.format(trace) + _cache = resolve_requires(trace.body) + _parts = [] + for text in end: + if text in _cache: + _cache = _cache[text] + _parts.append(text) + if not isinstance(_cache, dict): + break + else: + return self.format(trace) + if isinstance(_cache, dict): + return self.format(Trace( + {"name": _parts[-1], 'header': [''], 'description': _parts[-1]}, Args(), trace.separators, + [Option(k, requires=_parts) if isinstance(i, dict) else i for k, i in _cache.items()] + )) + if isinstance(_cache, Option): + _hdr = [i for i in _cache.aliases if i != _cache.name] + return self.format(Trace( + {"name": _cache.name, "header": _hdr or [""], "description": _cache.help_text}, _cache.args, + _cache.separators, [] + )) + if isinstance(_cache, Subcommand): + return self.format(Trace( + {"name": _cache.name, "header": [""], "description": _cache.help_text}, _cache.args, + _cache.separators, _cache.options # type: ignore + )) + return self.format(trace) + + return "\n".join(map(_handle, self.data)) @abstractmethod def format(self, trace: Trace) -> str: - """ - help text的生成入口 - """ - pass + """help text的生成入口""" @abstractmethod - def param(self, name: str, parameter: ArgUnit) -> str: - """ - 对单个参数的描述 - """ - pass + def param(self, name: str, parameter: ArgUnit) -> str: + """对单个参数的描述""" @abstractmethod def parameters(self, args: Args) -> str: - """ - 参数列表的描述 - """ - pass + """参数列表的描述""" @abstractmethod def header(self, root: Dict[str, Any], separators: Set[str]) -> str: - """ - 头部节点的描述 - """ - pass + """头部节点的描述""" @abstractmethod def part(self, node: Union[Subcommand, Option]) -> str: - """ - 每个子节点的描述 - """ - pass + """每个子节点的描述""" @abstractmethod def body(self, parts: List[Union[Option, Subcommand]]) -> str: - """ - 子节点列表的描述 - """ - pass + """子节点列表的描述""" -__all__ = ["AbstractTextFormatter", "output_send", "output_manager", "Trace"] +__all__ = ["AbstractTextFormatter", "output_manager", "Trace"] diff --git a/src/arclet/alconna/components/stub.py b/src/arclet/alconna/components/stub.py index cccd9102..d973f301 100644 --- a/src/arclet/alconna/components/stub.py +++ b/src/arclet/alconna/components/stub.py @@ -4,7 +4,7 @@ from ..typing import BasePattern, AllParam from ..base import Args, Option, Subcommand, OptionResult, SubcommandResult -from ..lang import lang_config +from ..config import config T = TypeVar('T') T_Origin = TypeVar('T_Origin') @@ -48,7 +48,7 @@ def __init__(self, args: Args): if value['value'] is AllParam: self.__annotations__[key] = Any elif isinstance(value['value'], BasePattern): - self.__annotations__[key] = value['value'].origin_type + self.__annotations__[key] = value['value'].origin else: self.__annotations__[key] = value['value'] setattr(self, key, value['default']) @@ -96,7 +96,7 @@ def __getitem__(self, item): elif isinstance(item, int): return list(self.value.values())[item] else: - raise TypeError(lang_config.stub_key_error.format(target=item)) + raise TypeError(config.lang.stub_key_error.format(target=item)) class OptionStub(BaseStub[Option]): diff --git a/src/arclet/alconna/lang.py b/src/arclet/alconna/config.py similarity index 73% rename from src/arclet/alconna/lang.py rename to src/arclet/alconna/config.py index 18bdd55c..f581b678 100644 --- a/src/arclet/alconna/lang.py +++ b/src/arclet/alconna/config.py @@ -1,6 +1,7 @@ +import asyncio import json from pathlib import Path -from typing import Union, Dict, Final, Optional +from typing import Union, Dict, Final, Optional, Set class _LangConfig: @@ -51,7 +52,25 @@ def __getattr__(self, item: str) -> str: return self.__config[item] -lang_config = _LangConfig() -load_lang_file = lang_config.reload +class _AlconnaConfig: + lang: _LangConfig = _LangConfig() + loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + namespace: str = "Alconna" + fuzzy_threshold: float = 0.6 + separators: Set[str] = {" "} + fuzzy_match: bool = False + raise_exception: bool = False + command_max_count: int = 200 + message_max_cache: int = 100 + enable_message_cache: bool = True -__all__ = ['lang_config', 'load_lang_file'] + @classmethod + def set_loop(cls, loop: asyncio.AbstractEventLoop) -> None: + """设置事件循环""" + cls.loop = loop + + +config = _AlconnaConfig() +load_lang_file = config.lang.reload + +__all__ = ['config', 'load_lang_file'] diff --git a/src/arclet/alconna/core.py b/src/arclet/alconna/core.py index ae2bdc4a..2d1bbb62 100644 --- a/src/arclet/alconna/core.py +++ b/src/arclet/alconna/core.py @@ -1,17 +1,17 @@ """Alconna 主体""" import sys from typing import Dict, List, Optional, Union, Type, Callable, Tuple, TypeVar, overload, TYPE_CHECKING, TypedDict, \ - Iterable + Iterable, Any -from .lang import lang_config +from .config import config from .analysis.base import compile from .base import CommandNode, Args, ArgAction, Option, Subcommand, HelpOption, ShortcutOption -from .typing import DataCollection +from .typing import DataCollection, BasePattern from .manager import command_manager from .arpamar import Arpamar from .components.action import ActionHandler from .components.output import AbstractTextFormatter -from .components.behavior import ArpamarBehavior, T_ABehavior +from .components.behavior import T_ABehavior from .components.duplication import AlconnaDuplication from .builtin.formatter import DefaultTextFormatter from .builtin.analyser import DefaultCommandAnalyser @@ -30,16 +30,16 @@ class _Actions(TypedDict): class AlconnaGroup(CommandNode): _group = True + commands: List['Alconna'] def __init__( self, name: str, *commands: "Alconna", namespace: Optional[str] = None, - ): self.commands = list(commands) - self.namespace = namespace + self.namespace = namespace or config.namespace name = command_manager.sign + name super().__init__(name, ) self.name.replace(command_manager.sign, '') @@ -75,22 +75,31 @@ def __union__(self, other: Union["AlconnaGroup", "Alconna"]): self.commands.append(other) return self + def reset_namespace(self, namespace: str): + """重新设置命名空间""" + command_manager.delete(self) + self.namespace = namespace + command_manager.register(self) + return self + def __iter__(self): yield from self.commands def __getitem__(self, item: str): try: return next(filter(lambda x: x.name == item, self.commands)) - except StopIteration: - raise KeyError(item) + except StopIteration as e: + raise KeyError(item) from e + + def __add__(self, other): + return self.__union__(other) - def parse(self, message: Union[str, DataCollection]): + def parse(self, message: DataCollection[Union[str, Any]]): res = None for command in self.commands: if (res := command.parse(message)).matched: return res - else: - return res + return res class Alconna(CommandNode): @@ -130,35 +139,30 @@ class Alconna(CommandNode): """ _group = False headers: Union[List[Union[str, object]], List[Tuple[object, str]]] - command: str + command: Union[str, type, BasePattern] options: List[Union[Option, Subcommand]] analyser_type: Type["Analyser"] formatter_type: Type[AbstractTextFormatter] namespace: str - behaviors: List[Union[ArpamarBehavior, Type[ArpamarBehavior]]] + behaviors: List[T_ABehavior] action_list: _Actions local_args = {} custom_types = {} global_headers: Union[List[Union[str, object]], List[Tuple[object, str]]] = [""] - global_behaviors: List[Union[ArpamarBehavior, Type[ArpamarBehavior]]] = [] + global_behaviors: List[T_ABehavior] = [] global_analyser_type: Type["Analyser"] = DefaultCommandAnalyser # type: ignore global_formatter_type: Type[AbstractTextFormatter] = DefaultTextFormatter # type: ignore - global_separators = {" "} - global_fuzzy_match: bool = False - global_raise_exception: bool = False @classmethod def config( cls, *, headers: Optional[Union[List[Union[str, object]], List[Tuple[object, str]]]] = None, - behaviors: Optional[List[Union[ArpamarBehavior, Type[ArpamarBehavior]]]] = None, + behaviors: Optional[List[T_ABehavior]] = None, analyser_type: Optional[Type["Analyser"]] = None, formatter_type: Optional[Type[AbstractTextFormatter]] = None, - separator: Optional[str] = None, - fuzzy_match: bool = False, - raise_exception: bool = False, + separator: Optional[str] = None ): """ 配置 Alconna 的默认属性 @@ -172,14 +176,12 @@ def config( if formatter_type is not None: cls.global_formatter_type = formatter_type if separator is not None: - cls.global_separators = {separator} - cls.global_fuzzy_match = fuzzy_match - cls.global_raise_exception = raise_exception + config.separators = {separator} return cls def __init__( self, - command: Optional[str] = None, + command: Optional[Union[str, type, BasePattern]] = None, main_args: Union[Args, str, None] = None, headers: Optional[Union[List[Union[str, object]], List[Tuple[object, str]]]] = None, options: Optional[List[Union[Option, Subcommand]]] = None, @@ -220,20 +222,20 @@ def __init__( f"{command_manager.sign}{command or self.headers[0]}", main_args, action=action, - separators=separators or self.__class__.global_separators.copy(), # type: ignore + separators=separators or config.separators.copy(), # type: ignore help_text=help_text or "Unknown Information" ) self.action_list = {"options": {}, "subcommands": {}, "main": None} - self.namespace = namespace or command_manager.default_namespace + self.namespace = namespace or config.namespace self.options.extend([HelpOption, ShortcutOption]) self.analyser_type = analyser_type or self.__class__.global_analyser_type self.behaviors = behaviors or self.__class__.global_behaviors.copy() self.behaviors.insert(0, ActionHandler()) self.formatter_type = formatter_type or self.__class__.global_formatter_type - self.is_fuzzy_match = is_fuzzy_match or self.__class__.global_fuzzy_match - self.is_raise_exception = is_raise_exception or self.__class__.global_raise_exception + self.is_fuzzy_match = is_fuzzy_match or config.fuzzy_match + self.is_raise_exception = is_raise_exception or config.raise_exception - command_manager.register(compile(self)) + command_manager.register(self) self.name = self.name.replace(command_manager.sign, "") def __union__(self, other: Union["Alconna", AlconnaGroup]) -> AlconnaGroup: @@ -254,7 +256,7 @@ def reset_namespace(self, namespace: str): """重新设置命名空间""" command_manager.delete(self) self.namespace = namespace - command_manager.register(compile(self)) + command_manager.register(self) return self def reset_behaviors(self, behaviors: List[T_ABehavior]): @@ -283,24 +285,21 @@ def shortcut( try: if delete: command_manager.delete_shortcut(short_key, self) - return lang_config.shortcut_delete_success.format( - shortcut=short_key, target=self.path.split(".")[-1]) + return config.lang.shortcut_delete_success.format(shortcut=short_key, target=self.path.split(".")[-1]) if command: command_manager.add_shortcut(self, short_key, command, expiration) - return lang_config.shortcut_add_success.format( - shortcut=short_key, target=self.path.split(".")[-1]) + return config.lang.shortcut_add_success.format(shortcut=short_key, target=self.path.split(".")[-1]) elif cmd := command_manager.recent_message: alc = command_manager.last_using if alc and alc == self: command_manager.add_shortcut(self, short_key, cmd, expiration) - return lang_config.shortcut_add_success.format( - shortcut=short_key, target=self.path.split(".")[-1]) + return config.lang.shortcut_add_success.format(shortcut=short_key, target=self.path.split(".")[-1]) raise ValueError( - lang_config.shortcut_recent_command_error.format( + config.lang.shortcut_recent_command_error.format( target=self.path, source=getattr(alc, "path", "Unknown Source")) ) else: - raise ValueError(lang_config.shortcut_no_recent_command) + raise ValueError(config.lang.shortcut_no_recent_command) except Exception as e: if self.is_raise_exception: raise e @@ -309,13 +308,12 @@ def shortcut( def __repr__(self): return f"<{self.namespace}::{self.name} with {len(self.options)} options; args={self.args}>" - def add_option( - self, - name: str, - *alias: str, - args: Optional[Args] = None, - sep: str = " ", - help_text: Optional[str] = None, + def add( + self, + name: str, *alias: str, + args: Optional[Args] = None, + sep: str = " ", + help_text: Optional[str] = None, ): """链式注册一个 Option""" command_manager.delete(self) @@ -323,46 +321,28 @@ def add_option( name, requires = names[-1], names[:-1] opt = Option(name, args, list(alias), separators=sep, help_text=help_text, requires=requires) self.options.append(opt) - command_manager.register(compile(self)) - return self - - def set_action(self, action: Union[Callable, str, ArgAction], custom_types: Optional[Dict[str, Type]] = None): # type: ignore - """设置针对main_args的action""" - if isinstance(action, str): - ns = {} - exec(action, getattr(self, "custom_types", custom_types), ns) - action: Callable = ns.popitem()[1] - self.action = ArgAction.__validator__(action, self.args) + command_manager.register(self) return self @overload def parse( - self, - message: Union[str, DataCollection], - duplication: Type[T_Duplication], - static: bool = True, - + self, message: DataCollection[Union[str, Any]], duplication: Type[T_Duplication], static: bool = True, ) -> T_Duplication: ... @overload def parse( - self, - message: Union[str, DataCollection], - duplication=None, - static: bool = True + self, message: DataCollection[Union[str, Any]], duplication=None, static: bool = True ) -> Arpamar: ... def parse( - self, - message: Union[str, DataCollection], - duplication: Optional[Type[T_Duplication]] = None, + self, message: DataCollection[Union[str, Any]], duplication: Optional[Type[T_Duplication]] = None, static: bool = True, ): """命令分析功能, 传入字符串或消息链, 返回一个特定的数据集合类""" analyser = command_manager.require(self) if static else compile(self) - analyser.process_message(message) + analyser.process(message) arp = analyser.analyse() if arp.matched: arp.execute() @@ -379,15 +359,17 @@ def __rtruediv__(self, other): return self def __rshift__(self, other): + if isinstance(other, Alconna): + return self.__union__(other) if isinstance(other, Option): command_manager.delete(self) self.options.append(other) - command_manager.register(compile(self)) + command_manager.register(self) elif isinstance(other, str): command_manager.delete(self) _part = other.split("/") self.options.append(Option(_part[0], _part[1] if len(_part) > 1 else None)) - command_manager.register(compile(self)) + command_manager.register(self) return self def __add__(self, other): diff --git a/src/arclet/alconna/exceptions.py b/src/arclet/alconna/exceptions.py index 90c5dcdf..c98853a2 100644 --- a/src/arclet/alconna/exceptions.py +++ b/src/arclet/alconna/exceptions.py @@ -13,8 +13,8 @@ class InvalidParam(Exception): """构造 alconna 时某个传入的参数不正确""" -class NullTextMessage(Exception): - """传入了不含有 text 的消息""" +class NullMessage(Exception): + """传入了无法解析的消息""" class UnexpectedElement(Exception): diff --git a/src/arclet/alconna/manager.py b/src/arclet/alconna/manager.py index f986833d..7c6f1724 100644 --- a/src/arclet/alconna/manager.py +++ b/src/arclet/alconna/manager.py @@ -1,20 +1,20 @@ """Alconna 负责记录命令的部分""" -import asyncio import weakref from datetime import datetime -from typing import TYPE_CHECKING, Dict, Optional, Union, List, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Union, List, Tuple, Any import shelve +import contextlib from .exceptions import ExceedMaxCount from .util import Singleton, LruCache from .typing import DataCollection -from .lang import lang_config +from .config import config if TYPE_CHECKING: + from .analysis.analyser import Analyser from .core import Alconna, AlconnaGroup from .arpamar import Arpamar - from .analysis.analyser import Analyser class CommandManager(metaclass=Singleton): @@ -25,7 +25,6 @@ class CommandManager(metaclass=Singleton): """ sign: str - default_namespace: str current_count: int max_count: int @@ -33,44 +32,34 @@ class CommandManager(metaclass=Singleton): __analysers: Dict['Alconna', 'Analyser'] __abandons: List["Alconna"] __record: LruCache[int, "Arpamar"] - __shortcuts: LruCache[str, Union['Arpamar', Union[str, DataCollection]]] + __shortcuts: LruCache[str, Union['Arpamar', DataCollection[Union[str, Any]]]] def __init__(self): - self.loop = asyncio.new_event_loop() self.cache_path = f"{__file__.replace('manager.py', '')}manager_cache.db" - self.default_namespace = "Alconna" self.sign = "ALCONNA::" - self.max_count = 200 + self.max_count = config.command_max_count self.current_count = 0 self.__commands = {} self.__analysers = {} self.__abandons = [] self.__shortcuts = LruCache() - self.__record = LruCache(100) + self.__record = LruCache(config.message_max_cache) weakref.finalize(self, self.__del__) - def __del__(self): # td: save to file - try: + def __del__(self): + with contextlib.suppress(AttributeError): self.__commands.clear() self.__abandons.clear() self.__record.clear() self.__shortcuts.clear() Singleton.remove(self.__class__) - except AttributeError: - pass - - def set_loop(self, loop: asyncio.AbstractEventLoop) -> None: - """设置事件循环""" - self.loop = loop def load_cache(self) -> None: """加载缓存""" - try: + with contextlib.suppress(FileNotFoundError, KeyError): with shelve.open(self.cache_path) as db: - self.__shortcuts.update(db["shortcuts"]) # type: ignore - except (FileNotFoundError, KeyError): - pass + self.__shortcuts = db["shortcuts"] # type: ignore def dump_cache(self) -> None: """保存缓存""" @@ -82,41 +71,53 @@ def get_loaded_namespaces(self): """获取所有命名空间""" return list(self.__commands.keys()) - def _command_part(self, command: str) -> Tuple[str, str]: + @staticmethod + def _command_part(command: str) -> Tuple[str, str]: """获取命令的组成部分""" - command_parts = command.split(".")[-2:] + command_parts = command.split(".", maxsplit=1)[-2:] if len(command_parts) != 2: - command_parts.insert(0, self.default_namespace) + command_parts.insert(0, config.namespace) return command_parts[0], command_parts[1] - def register(self, delegate: "Analyser") -> None: + def register(self, command: Union["Alconna", "AlconnaGroup"]) -> None: """注册命令解析器, 会同时记录解析器对应的命令""" + from .analysis.base import compile if self.current_count >= self.max_count: raise ExceedMaxCount - self.__analysers[delegate.alconna] = delegate - if delegate.alconna.namespace not in self.__commands: - self.__commands[delegate.alconna.namespace] = {} - cid = delegate.alconna.name.replace(self.sign, "") - if _cmd := self.__commands[delegate.alconna.namespace].get(cid): - if _cmd == delegate.alconna: + if not command._group: # noqa + self.__analysers[command] = compile(command) # type: ignore + else: + for cmd in command.commands: # type: ignore + self.__analysers[cmd] = compile(cmd) + if command.namespace not in self.__commands: + self.__commands[command.namespace] = {} + cid = command.name.replace(self.sign, "") + if _cmd := self.__commands[command.namespace].get(cid): + if _cmd == command: return - _cmd.__union__(delegate.alconna) + _cmd.__union__(command) else: - self.__commands[delegate.alconna.namespace][cid] = delegate.alconna + self.__commands[command.namespace][cid] = command self.current_count += 1 def require(self, command: "Alconna") -> "Analyser": """获取命令解析器""" try: return self.__analysers[command] - except KeyError: + except KeyError as e: namespace, name = self._command_part(command.path) - raise ValueError(lang_config.manager_undefined_command.format(target=f"{namespace}.{name}")) + raise ValueError(config.lang.manager_undefined_command.format(target=f"{namespace}.{name}")) from e def delete(self, command: Union["Alconna", 'AlconnaGroup', str]) -> None: """删除命令""" namespace, name = self._command_part(command if isinstance(command, str) else command.path) try: + base = self.__commands[namespace][name] + if base._group: # noqa + for cmd in base.commands: # type: ignore + del self.__analysers[cmd] + else: + del self.__analysers[base] # type: ignore del self.__commands[namespace][name] self.current_count -= 1 finally: @@ -133,7 +134,7 @@ def set_enable(self, command: Union["Alconna", str]) -> None: if isinstance(command, str): namespace, name = self._command_part(command) if namespace not in self.__commands or name not in self.__commands[namespace]: - raise ValueError(lang_config.manager_undefined_command.format(target=command)) + raise ValueError(config.lang.manager_undefined_command.format(target=command)) temp = [cmd for cmd in self.__abandons if cmd.path == f"{namespace}.{name}"] for cmd in temp: self.__abandons.remove(cmd) @@ -144,7 +145,7 @@ def add_shortcut( self, target: Union["Alconna", str], shortcut: str, - source: Union["Arpamar", Union[str, DataCollection]], + source: Union["Arpamar", DataCollection[Union[str, Any]]], expiration: int = 0, ) -> None: """添加快捷命令""" @@ -152,12 +153,12 @@ def add_shortcut( namespace, name = self._command_part(target if isinstance(target, str) else target.path) try: _ = self.__commands[namespace][name] - except KeyError: - raise ValueError(lang_config.manager_undefined_command.format(target=f"{namespace}.{name}")) + except KeyError as e: + raise ValueError(config.lang.manager_undefined_command.format(target=f"{namespace}.{name}")) from e if isinstance(source, Arpamar) and source.matched or not isinstance(source, Arpamar): self.__shortcuts.set(f"{namespace}.{name}::{shortcut}", source, expiration) else: - raise ValueError(lang_config.manager_incorrect_shortcut.format(target=f"{shortcut}")) + raise ValueError(config.lang.manager_incorrect_shortcut.format(target=f"{shortcut}")) def find_shortcut(self, shortcut: str, target: Optional[Union["Alconna", str]] = None): """查找快捷命令""" @@ -165,22 +166,22 @@ def find_shortcut(self, shortcut: str, target: Optional[Union["Alconna", str]] = namespace, name = self._command_part(target if isinstance(target, str) else target.path) try: _ = self.__commands[namespace][name] - except KeyError: - raise ValueError(lang_config.manager_undefined_command.format(target=f"{namespace}.{name}")) + except KeyError as e: + raise ValueError(config.lang.manager_undefined_command.format(target=f"{namespace}.{name}")) from e try: return self.__shortcuts[f"{namespace}.{name}::{shortcut}"] - except KeyError: + except KeyError as e: raise ValueError( - lang_config.manager_target_command_error.format(target=f"{namespace}.{name}", shortcut=shortcut) - ) + config.lang.manager_target_command_error.format(target=f"{namespace}.{name}", shortcut=shortcut) + ) from e else: for key in self.__shortcuts: if key.split("::")[1] == shortcut: return self.__shortcuts.get(key) - raise ValueError(lang_config.manager_undefined_shortcut.format(target=f"{shortcut}")) + raise ValueError(config.lang.manager_undefined_shortcut.format(target=f"{shortcut}")) - def update_shortcut(self, random: bool = False): - return self.__shortcuts.update() if random else self.__shortcuts.update_all() + def update_shortcut(self): + return self.__shortcuts.update_all() def delete_shortcut(self, shortcut: str, target: Optional[Union["Alconna", str]] = None): """删除快捷命令""" @@ -198,10 +199,10 @@ def set_disable(self, command: Union["Alconna", str]) -> None: if isinstance(command, str): namespace, name = self._command_part(command) if namespace not in self.__commands or name not in self.__commands[namespace]: - raise ValueError(lang_config.manager_undefined_command.format(target=f"{namespace}.{name}")) + raise ValueError(config.lang.manager_undefined_command.format(target=f"{namespace}.{name}")) cmd = self.__commands[namespace][name] return ( - self.__abandons.extend(cmd.commands) # type: ignore + self.__abandons.extend(cmd.commands) # type: ignore if hasattr(cmd, 'commands') else self.__abandons.append(cmd) # type: ignore ) self.__abandons.append(command) @@ -213,15 +214,15 @@ def get_command(self, command: str) -> Union["Alconna", "AlconnaGroup", None]: return None return self.__commands[namespace][name] - def get_commands(self, namespace: Optional[str] = None) -> List[Union["Alconna", "AlconnaGroup"]]: + def get_commands(self, namespace: str = '') -> List[Union["Alconna", "AlconnaGroup"]]: """获取命令列表""" - if namespace is None: + if not namespace: return [ana for namespace in self.__commands for ana in self.__commands[namespace].values()] if namespace not in self.__commands: return [] - return [ana for ana in self.__commands[namespace].values()] + return list(self.__commands[namespace].values()) - def broadcast(self, message: Union[str, DataCollection], namespace: Optional[str] = None) -> Optional['Arpamar']: + def broadcast(self, message: DataCollection[Union[str, Any]], namespace: str = '') -> Optional['Arpamar']: """将一段命令广播给当前空间内的所有命令""" for cmd in self.get_commands(namespace): if (res := cmd.parse(message)) and res.matched: @@ -249,19 +250,19 @@ def all_command_help( max_length: 单个页面展示的最大长度 page: 当前页码 """ - header = header or lang_config.manager_help_header - pages = pages or lang_config.manager_help_pages - footer = footer or lang_config.manager_help_footer + header = header or config.lang.manager_help_header + pages = pages or config.lang.manager_help_pages + footer = footer or config.lang.manager_help_footer cmds = self.get_commands(namespace) if max_length < 1: command_string = "\n".join( - f" - {cmd.name} : {cmd.help_text.replace('Usage', ';').replace('Example', ';').split(';')[0]}" - for cmd in cmds - ) if not show_index else "\n".join( f" {str(index).rjust(len(str(len(cmds))), '0')} {slot.name} : " + slot.help_text.replace('Usage', ';').replace('Example', ';').split(';')[0] for index, slot in enumerate(cmds) + ) if show_index else "\n".join( + f" - {cmd.name} : {cmd.help_text.replace('Usage', ';').replace('Example', ';').split(';')[0]}" + for cmd in cmds ) else: max_page = len(cmds) // max_length + 1 @@ -269,14 +270,14 @@ def all_command_help( page = 1 header += "\t" + pages.format(current=page, total=max_page) command_string = "\n".join( - f" - {cmd.name} : {cmd.help_text.replace('Usage', ';').replace('Example', ';').split(';')[0]}" - for cmd in cmds[(page - 1) * max_length: page * max_length] - ) if not show_index else "\n".join( f" {str(index).rjust(len(str(page * max_length)), '0')} {cmd.name} : " f"{cmd.help_text.replace('Usage', ';').replace('Example', ';').split(';')[0]}" for index, cmd in enumerate( cmds[(page - 1) * max_length: page * max_length], start=(page - 1) * max_length ) + ) if show_index else "\n".join( + f" - {cmd.name} : {cmd.help_text.replace('Usage', ';').replace('Example', ';').split(';')[0]}" + for cmd in cmds[(page - 1) * max_length: page * max_length] ) return f"{header}\n{command_string}\n{footer}" @@ -286,12 +287,7 @@ def command_help(self, command: str) -> Optional[str]: if cmd := self.get_command(f"{command_parts[0]}.{command_parts[1]}"): return cmd.get_help() - def record( - self, - token: int, - message: Union[str, DataCollection], - result: "Arpamar" - ): + def record(self, token: int, message: DataCollection[Union[str, Any]], result: "Arpamar"): result.origin = message self.__record.set(token, result) @@ -301,7 +297,7 @@ def get_record(self, token: int) -> Optional["Arpamar"]: return self.__record.get(token) @property - def recent_message(self) -> Optional[Union[str, DataCollection]]: + def recent_message(self) -> Optional[DataCollection[Union[str, Any]]]: if rct := self.__record.recent: return rct.origin diff --git a/src/arclet/alconna/typing.py b/src/arclet/alconna/typing.py index 6b6fa6e0..a3387d18 100644 --- a/src/arclet/alconna/typing.py +++ b/src/arclet/alconna/typing.py @@ -1,114 +1,94 @@ """Alconna 参数相关""" import re +import sre_compile import inspect -from copy import copy -from collections.abc import ( - Iterable as ABCIterable, - Sequence as ABCSequence, - Set as ABCSet, - MutableSet as ABCMutableSet, - MutableSequence as ABCMutableSequence, - MutableMapping as ABCMutableMapping, - Mapping as ABCMapping, -) +from datetime import datetime +from copy import deepcopy +from collections.abc import Sequence as ABCSeq, Set as ABCSet, \ + MutableSet as ABCMuSet, MutableSequence as ABCMuSeq, MutableMapping as ABCMuMap, Mapping as ABCMap +from contextlib import suppress from functools import lru_cache from pathlib import Path from enum import IntEnum -from typing import TypeVar, Type, Callable, Optional, Protocol, Any, Pattern, Union, Sequence, \ - List, Dict, get_args, Literal, Tuple, get_origin, Iterable, Generic +from types import FunctionType, LambdaType, MethodType +from typing import TypeVar, Type, Callable, Optional, Protocol, Any, Pattern, Union, List, Dict, \ + Literal, Tuple, Iterable, Generic, Iterator, runtime_checkable try: - from typing import Annotated # type: ignore + from typing import Annotated, get_args, get_origin # type: ignore except ImportError: - from typing_extensions import Annotated + from typing_extensions import Annotated, get_args, get_origin from .exceptions import ParamsUnmatched -from .lang import lang_config +from .config import config from .util import generic_isinstance DataUnit = TypeVar("DataUnit", covariant=True) GenericAlias = type(List[int]) -AnnotatedAlias = type(Annotated[int, lambda x: x > 0]) +TPattern: Type[Pattern] = type(sre_compile.compile('', 0)) +@runtime_checkable class DataCollection(Protocol[DataUnit]): """数据集合协议""" - - def __str__(self) -> str: - ... - - def __iter__(self) -> DataUnit: - ... - - def __len__(self) -> int: - ... + def __str__(self) -> str: ... + def __iter__(self) -> Iterator[DataUnit]: ... + def __len__(self) -> int: ... class PatternModel(IntEnum): """ 参数表达式匹配模式 """ - REGEX_CONVERT = 3 """正则匹配并转换""" - TYPE_CONVERT = 2 """传入值直接转换""" - REGEX_MATCH = 1 """正则匹配""" - KEEP = 0 """保持传入值""" class _All: """泛匹配""" - - __slots__ = () - def __repr__(self): return "AllParam" AllParam = _All() Empty = inspect.Signature.empty - TOrigin = TypeVar("TOrigin") class BasePattern(Generic[TOrigin]): - """ - 对参数类型值的包装 - """ - - regex_pattern: Pattern + """对参数类型值的包装""" + regex_pattern: TPattern # type: ignore pattern: str model: PatternModel converter: Callable[[Union[str, Any]], TOrigin] - validator: Callable[[TOrigin], bool] + validators: List[Callable[[TOrigin], bool]] anti: bool - origin_type: Type[TOrigin] + origin: Type[TOrigin] accepts: Optional[List[Type]] alias: Optional[str] previous: Optional["BasePattern"] __slots__ = ( - "regex_pattern", "pattern", "model", "converter", "anti", - "origin_type", "accepts", "alias", "previous", "validator" + "regex_pattern", "pattern", "model", "converter", "anti", "origin", "accepts", "alias", "previous", "validators" ) def __init__( self, pattern: str = "(.+?)", model: PatternModel = PatternModel.REGEX_MATCH, - origin_type: Type[TOrigin] = str, + origin: Type[TOrigin] = str, converter: Optional[Callable[[Union[str, Any]], TOrigin]] = None, alias: Optional[str] = None, previous: Optional["BasePattern"] = None, accepts: Optional[List[Type]] = None, - validator: Optional[Callable[[TOrigin], bool]] = None, + validators: Optional[List[Callable[[TOrigin], bool]]] = None, anti: bool = False ): """ @@ -117,31 +97,29 @@ def __init__( self.pattern = pattern self.regex_pattern = re.compile(f"^{pattern}$") self.model = model - self.origin_type = origin_type + self.origin = origin self.alias = alias self.previous = previous self.accepts = accepts - if converter: - self.converter = converter - elif model == PatternModel.TYPE_CONVERT: - self.converter = lambda x: origin_type(x) - else: - self.converter = lambda x: eval(x) - self.validator = validator or (lambda x: True) + self.converter = converter or (lambda x: origin(x) if model == PatternModel.TYPE_CONVERT else eval(x)) + self.validators = validators or [] self.anti = anti def __repr__(self): if self.model == PatternModel.KEEP: - return ('|'.join(x.__name__ for x in self.accepts)) if self.accepts else 'Any' + return self.alias or (('|'.join(x.__name__ for x in self.accepts)) if self.accepts else 'Any') + name = self.alias or self.origin.__name__ if self.model == PatternModel.REGEX_MATCH: text = self.alias or self.pattern elif self.model == PatternModel.REGEX_CONVERT: - text = self.alias or self.origin_type.__name__ + text = name else: - text = f"{(('|'.join(x.__name__ for x in self.accepts)) + ' -> ') if self.accepts else ''}" \ - f"{self.alias or self.origin_type.__name__}" + text = (('|'.join(x.__name__ for x in self.accepts) + ' -> ') if self.accepts else '') + name return f"{(f'{self.previous.__repr__()}, ' if self.previous else '')}{'!' if self.anti else ''}{text}" + def __str__(self): + return self.__repr__() + def __hash__(self): return hash(self.__repr__()) @@ -150,10 +128,13 @@ def __eq__(self, other): @classmethod def of(cls, unit: Type[TOrigin]): - """ - 提供原来 TAValue 中的 Type[DataUnit] 类型的构造方法 - """ - return cls(origin_type=unit, accepts=[unit], model=PatternModel.KEEP, alias=unit.__name__) + """提供原来 TAValue 中的 Type[DataUnit] 类型的构造方法""" + return cls('', PatternModel.KEEP, unit, alias=unit.__name__, accepts=[unit]) + + @classmethod + def on(cls, obj: TOrigin): + """提供原来 TAValue 中的 DataUnit 类型的构造方法""" + return cls('', PatternModel.KEEP, type(obj), alias=str(obj), validators=[lambda x: x == obj]) def reverse(self): self.anti = not self.anti @@ -163,45 +144,48 @@ def match(self, input_: Union[str, Any]) -> TOrigin: """ 对传入的参数进行匹配, 如果匹配成功, 则返回转换后的值, 否则返回None """ - if self.model > 1 and generic_isinstance(input_, self.origin_type): + if self.model > 0 and self.origin not in (str, Any) and generic_isinstance(input_, self.origin): return input_ # type: ignore if self.accepts and not isinstance(input_, tuple(self.accepts)): - if not self.previous: - raise ParamsUnmatched(lang_config.args_type_error.format(target=input_.__class__)) - input_ = self.previous.match(input_) + if not self.previous or not isinstance(input_ := self.previous.match(input_), tuple(self.accepts)): + raise ParamsUnmatched(config.lang.args_type_error.format(target=input_.__class__)) if self.model == PatternModel.KEEP: return input_ # type: ignore if self.model == PatternModel.TYPE_CONVERT: res = self.converter(input_) - if not generic_isinstance(res, self.origin_type): - raise ParamsUnmatched(lang_config.args_error.format(target=input_)) + if not generic_isinstance(res, self.origin) or (not res and self.origin == Any): + raise ParamsUnmatched(config.lang.args_error.format(target=input_)) return res if not isinstance(input_, str): - raise ParamsUnmatched(lang_config.args_type_error.format(target=type(input_))) + if not self.previous or not isinstance(input_ := self.previous.match(input_), str): + raise ParamsUnmatched(config.lang.args_type_error.format(target=type(input_))) if r := self.regex_pattern.findall(input_): return self.converter(r[0]) if self.model == PatternModel.REGEX_CONVERT else r[0] - raise ParamsUnmatched(lang_config.args_error.format(target=input_)) + raise ParamsUnmatched(config.lang.args_error.format(target=input_)) def validate(self, input_: Union[str, Any], default: Optional[Any] = None) -> Tuple[Any, Literal["V", "E", "D"]]: - if not self.anti: - try: - res = self.match(input_) - if self.validator(res): - return res, "V" - raise ParamsUnmatched(lang_config.args_error.format(target=input_)) - except Exception as e: - if default is None: - return e, "E" - return None if default is Empty else default, "D" + try: + res = self.match(input_) + for i in self.validators: + if not i(res): + raise ParamsUnmatched(config.lang.args_error.format(target=input_)) + return res, "V" + except Exception as e: + if default is None: + return e, "E" + return None if default is Empty else default, "D" + + def invalidate(self, input_: Union[str, Any], default: Optional[Any] = None) -> Tuple[Any, Literal["V", "E", "D"]]: try: res = self.match(input_) except ParamsUnmatched: return input_, "V" else: - if not self.validator(res): - return input_, "E" + for i in self.validators: + if not i(res): + return input_, "E" if default is None: - return ParamsUnmatched(lang_config.args_error.format(target=input_)), "E" + return ParamsUnmatched(config.lang.args_error.format(target=input_)), "E" return None if default is Empty else default, "D" @@ -214,6 +198,10 @@ def validate(self, input_: Union[str, Any], default: Optional[Any] = None) -> Tu _Url = BasePattern(r"[\w]+://[^/\s?#]+[^\s?#]+(?:\?[^\s#]*)?(?:#[^\s]*)?", alias="url") _HexLike = BasePattern(r"((?:0x)?[0-9a-fA-F]+)", PatternModel.REGEX_CONVERT, int, lambda x: int(x, 16), "hex") _HexColor = BasePattern(r"(#[0-9a-fA-F]{6})", PatternModel.REGEX_CONVERT, str, lambda x: x[1:], "color") +_Datetime = BasePattern( + model=PatternModel.TYPE_CONVERT, origin=datetime, alias='datetime', accepts=[str, int], + converter=lambda x: datetime.fromtimestamp(x) if isinstance(x, int) else datetime.fromisoformat(x) +) class MultiArg(BasePattern): @@ -221,45 +209,30 @@ class MultiArg(BasePattern): flag: str array_length: Optional[int] - def __init__( - self, - base: BasePattern, - flag: Literal['args', 'kwargs'] = 'args', - array_length: Optional[int] = None, - ): - alias_content = base.alias or base.origin_type.__name__ + def __init__(self, base: BasePattern, flag: Literal['args', 'kwargs'] = 'args', length: Optional[int] = None): self.flag = flag - self.array_length = array_length + self.array_length = length if flag == 'args': - _t = Tuple[base.origin_type, ...] - alias = f"*{alias_content}[:{array_length}]" if array_length else f"*{alias_content}" + _t = Tuple[base.origin, ...] + alias = f"*{base}[:{length}]" if length else f"*{base}" else: - _t = Dict[str, base.origin_type] - alias = f"**{alias_content}[:{array_length}]" if array_length else f"**{alias_content}" + _t = Dict[str, base.origin] + alias = f"**{base}[:{length}]" if length else f"**{base}" super().__init__( - base.pattern, base.model, _t, - alias=alias, converter=base.converter, previous=base.previous, accepts=base.accepts + base.pattern, base.model, _t, base.converter, alias, base.previous, base.accepts, base.validators ) - def __repr__(self): - ctn = super().__repr__() - if self.flag == 'args': - return f"{ctn}[{self.array_length}]" if self.array_length else f"({ctn}, ...)" - elif self.flag == 'kwargs': - return f"{{KEY={ctn}, ...}}" - class UnionArg(BasePattern): """多类型参数的匹配""" optional: bool - arg_value: Sequence[Union[BasePattern, object, str]] + arg_value: List[Union[BasePattern, object, str]] for_validate: List[BasePattern] for_equal: List[Union[str, object]] - def __init__(self, base: Sequence[Union[BasePattern, object, str]], anti: bool = False): - self.arg_value = base + def __init__(self, base: Iterable[Union[BasePattern, object, str]], anti: bool = False): + self.arg_value = list(base) self.optional = False - self.for_validate = [] self.for_equal = [] @@ -277,37 +250,16 @@ def __init__(self, base: Sequence[Union[BasePattern, object, str]], anti: bool = def match(self, text: Union[str, Any]): if not text: text = None - if self.anti: - validate = False - equal = text in self.for_equal + if text not in self.for_equal: for pat in self.for_validate: - try: - pat.match(text) - validate = True - break - except ParamsUnmatched: - continue - if validate or equal: - raise ParamsUnmatched(lang_config.args_error.format(target=text)) - return text - not_match = True - not_equal = text not in self.for_equal - if not_equal: - for pat in self.for_validate: - try: - text = pat.match(text) - not_match = False - break - except (ParamsUnmatched, TypeError): - continue - if not_match and not_equal: - raise ParamsUnmatched(lang_config.args_error.format(target=text)) + res, v = pat.validate(text) + if v == 'V': + return res + raise ParamsUnmatched(config.lang.args_error.format(target=text)) return text def __repr__(self): - return ("!" if self.anti else "") + ("|".join( - [repr(a) for a in self.for_validate] + [repr(a) for a in self.for_equal] - )) + return ("!" if self.anti else "") + ("|".join(repr(a) for a in (*self.for_validate, *self.for_equal))) class SequenceArg(BasePattern): @@ -319,27 +271,25 @@ def __init__(self, base: BasePattern, form: str = "list"): if base is AnyOne: base = _String self.form = form - alias_content = base.alias or base.origin_type.__name__ self.arg_value = base if form == "list": - super().__init__(r"\[(.+?)\]", PatternModel.REGEX_MATCH, list, alias=f"List[{alias_content}]") + super().__init__(r"\[(.+?)\]", PatternModel.REGEX_MATCH, list, alias=f"list[{base}]") elif form == "tuple": - super().__init__(r"\((.+?)\)", PatternModel.REGEX_MATCH, tuple, alias=f"Tuple[{alias_content}]") + super().__init__(r"\((.+?)\)", PatternModel.REGEX_MATCH, tuple, alias=f"tuple[{base}]") elif form == "set": - super().__init__(r"\{(.+?)\}", PatternModel.REGEX_MATCH, set, alias=f"Set[{alias_content}]") + super().__init__(r"\{(.+?)\}", PatternModel.REGEX_MATCH, set, alias=f"set[{base}]") else: - raise ValueError(lang_config.types_sequence_form_error.format(target=form)) + raise ValueError(config.lang.types_sequence_form_error.format(target=form)) def match(self, text: Union[str, Any]): _res = super().match(text) - sequence = re.split(r"\s*,\s*", _res) if isinstance(_res, str) else _res result = [] - for s in sequence: + for s in (re.split(r"\s*,\s*", _res) if isinstance(_res, str) else _res): try: result.append(self.arg_value.match(s)) - except ParamsUnmatched: - raise ParamsUnmatched(f"{s} is not matched with {self.arg_value}") - return self.origin_type(result) + except ParamsUnmatched as e: + raise ParamsUnmatched(f"{s} is not matched with {self.arg_value}") from e + return self.origin(result) def __repr__(self): return f"{self.form}[{self.arg_value}]" @@ -353,10 +303,7 @@ class MappingArg(BasePattern): def __init__(self, arg_key: BasePattern, arg_value: BasePattern): self.arg_key = arg_key self.arg_value = arg_value - - alias_content = f"{self.arg_key.alias or self.arg_key.origin_type.__name__}, " \ - f"{self.arg_value.alias or self.arg_value.origin_type.__name__}" - super().__init__(r"\{(.+?)\}", PatternModel.REGEX_MATCH, dict, alias=f"Dict[{alias_content}]") + super().__init__(r"\{(.+?)\}", PatternModel.REGEX_MATCH, dict, alias=f"dict[{self.arg_key}, {self.arg_value}]") def match(self, text: Union[str, Any]): _res = super().match(text) @@ -366,37 +313,32 @@ def _generator_items(res: Union[str, Dict]): if isinstance(res, dict): return res.items() for m in re.split(r"\s*,\s*", res): - _k, _v = re.split(r"\s*[:=]\s*", m) - yield _k, _v + yield re.split(r"\s*[:=]\s*", m) for k, v in _generator_items(_res): try: real_key = self.arg_key.match(k) - except ParamsUnmatched: - raise ParamsUnmatched(f"{k} is not matched with {self.arg_key}") + except ParamsUnmatched as e: + raise ParamsUnmatched(f"{k} is not matched with {self.arg_key}") from e try: arg_find = self.arg_value.match(v) - except ParamsUnmatched: - raise ParamsUnmatched(f"{v} is not matched with {self.arg_value}") + except ParamsUnmatched as e: + raise ParamsUnmatched(f"{v} is not matched with {self.arg_value}") from e result[real_key] = arg_find return result def __repr__(self): - return f"dict[{self.arg_key.origin_type.__name__}, {self.arg_value}]" + return f"dict[{self.arg_key.origin.__name__}, {self.arg_value}]" pattern_map = { Any: AnyOne, Ellipsis: AnyOne, object: AnyOne, "email": _Email, "color": _HexColor, - "hex": _HexLike, "ip": _IP, "url": _Url, "...": AnyOne, "*": AllParam, "": Empty + "hex": _HexLike, "ip": _IP, "url": _Url, "...": AnyOne, "*": AllParam, "": Empty, "datetime": _Datetime } -def set_converter( - target: BasePattern, - alias: Optional[str] = None, - cover: bool = False -): +def set_converter(target: BasePattern, alias: Optional[str] = None, cover: bool = False, data: Optional[dict] = None): """ 增加 Alconna 内使用的类型转换器 @@ -404,131 +346,123 @@ def set_converter( target: 设置的表达式 alias: 目标类型的别名 cover: 是否覆盖已有的转换器 + data: BasePattern的存储字典 """ - for k in (alias, target.alias, target.origin_type): - if k not in pattern_map or cover: - pattern_map[k] = target + data = data or pattern_map + for k in (alias, target.alias, target.origin): + if not k: + continue + if k not in data or cover: + data[k] = target else: - al_pat = pattern_map[k] - pattern_map[k] = UnionArg([*al_pat.arg_value, target]) if isinstance(al_pat, UnionArg) else ( + al_pat = data[k] + data[k] = UnionArg([*al_pat.arg_value, target]) if isinstance(al_pat, UnionArg) else ( UnionArg([al_pat, target]) ) def set_converters( patterns: Union[Iterable[BasePattern], Dict[str, BasePattern]], - cover: bool = False + cover: bool = False, data: Optional[dict] = None ): for arg_pattern in patterns: if isinstance(patterns, Dict): - set_converter(patterns[arg_pattern], alias=arg_pattern, cover=cover) # type: ignore + set_converter(patterns[arg_pattern], alias=arg_pattern, cover=cover, data=data) # type: ignore else: - set_converter(arg_pattern, cover=cover) # type: ignore - + set_converter(arg_pattern, cover=cover, data=data) # type: ignore -def remove_converter(origin_type: type, alias: Optional[str] = None): - """ - :param origin_type: - :param alias: - :return: - """ - if alias and (al_pat := pattern_map.get(alias)): +def remove_converter(origin_type: type, alias: Optional[str] = None, data: Optional[dict] = None): + data = data or pattern_map + if alias and (al_pat := data.get(alias)): if isinstance(al_pat, UnionArg): - pattern_map[alias] = UnionArg(list(filter(lambda x: x.alias != alias, al_pat.arg_value))) # type: ignore + data[alias] = UnionArg(filter(lambda x: x.alias != alias, al_pat.arg_value)) # type: ignore else: - del pattern_map[alias] - elif al_pat := pattern_map.get(origin_type): + del data[alias] + elif al_pat := data.get(origin_type): if isinstance(al_pat, UnionArg): - pattern_map[origin_type] = UnionArg( - list(filter(lambda x: x.origin_type != origin_type, al_pat.arg_value)) # type: ignore - ) + data[origin_type] = UnionArg(filter(lambda x: x.origin != origin_type, al_pat.for_validate)) else: - del pattern_map[origin_type] + del data[origin_type] -StrPath = BasePattern(model=PatternModel.TYPE_CONVERT, origin_type=Path, alias="path", accepts=[str]) +StrPath = BasePattern(model=PatternModel.TYPE_CONVERT, origin=Path, alias="path", accepts=[str]) AnyPathFile = BasePattern( - model=PatternModel.TYPE_CONVERT, origin_type=bytes, alias="file", accepts=[Path], previous=StrPath, + model=PatternModel.TYPE_CONVERT, origin=bytes, alias="file", accepts=[Path], previous=StrPath, converter=lambda x: x.read_bytes() if x.exists() and x.is_file() else None # type: ignore ) _Digit = BasePattern(r"(\-?\d+)", PatternModel.REGEX_CONVERT, int, lambda x: int(x), "int") _Float = BasePattern(r"(\-?\d+\.?\d*)", PatternModel.REGEX_CONVERT, float, lambda x: float(x), "float") -_Bool = BasePattern(r"(True|False|true|false)", PatternModel.REGEX_CONVERT, bool, lambda x: x.lower() == "true", "bool") +_Bool = BasePattern(r"(?i:True|False)", PatternModel.REGEX_CONVERT, bool, lambda x: x.lower() == "true", "bool") _List = BasePattern(r"(\[.+?\])", PatternModel.REGEX_CONVERT, list, alias="list") _Tuple = BasePattern(r"(\(.+?\))", PatternModel.REGEX_CONVERT, tuple, alias="tuple") _Set = BasePattern(r"(\{.+?\})", PatternModel.REGEX_CONVERT, set, alias="set") _Dict = BasePattern(r"(\{.+?\})", PatternModel.REGEX_CONVERT, dict, alias="dict") - set_converters([AnyPathFile, _String, _Digit, _Float, _Bool, _List, _Tuple, _Set, _Dict]) -def pattern_gen(name: str, re_pattern: str): - """便捷地设置转换器""" - - def __wrapper(func): - return BasePattern(re_pattern, PatternModel.REGEX_CONVERT, converter=func, alias=name) - - return __wrapper - - -def argument_type_validator(item: Any, extra: str = "allow"): +def args_type_parser(item: Any, extra: str = "allow"): """对 Args 里参数类型的检查, 将一般数据类型转为 Args 使用的类型""" if isinstance(item, (BasePattern, _All)): return item - try: + with suppress(TypeError): if pat := pattern_map.get(item, None): return pat - except TypeError: - pass if not inspect.isclass(item) and isinstance(item, GenericAlias): - if isinstance(item, AnnotatedAlias): - _o = argument_type_validator(item.__origin__, extra) # type: ignore - if not isinstance(_o, BasePattern): + origin = get_origin(item) + if origin is Annotated: + org, meta = get_args(item) + if not isinstance(_o := args_type_parser(org, extra), BasePattern): # type: ignore return _o - _arg = copy(_o) - _arg.validator = lambda x: all(i(x) for i in item.__metadata__) + _arg = deepcopy(_o) + _arg.validators.extend(meta if isinstance(meta, tuple) else [meta]) # type: ignore return _arg - origin = get_origin(item) if origin in (Union, Literal): - _args = list({argument_type_validator(t, extra) for t in get_args(item)}) - return (_args[0] if len(_args) == 1 else UnionArg(_args)) if _args else item - if origin in (dict, ABCMapping, ABCMutableMapping): - arg_key = argument_type_validator(get_args(item)[0], 'ignore') - arg_value = argument_type_validator(get_args(item)[1], 'allow') + _args = {args_type_parser(t, extra) for t in get_args(item)} + return (_args.pop() if len(_args) == 1 else UnionArg(_args)) if _args else item + if origin in (dict, ABCMap, ABCMuMap): + arg_key = args_type_parser(get_args(item)[0], 'ignore') + arg_value = args_type_parser(get_args(item)[1], 'allow') if isinstance(arg_value, list): arg_value = UnionArg(arg_value) return MappingArg(arg_key=arg_key, arg_value=arg_value) - args = argument_type_validator(get_args(item)[0], 'allow') + args = args_type_parser(get_args(item)[0], 'allow') if isinstance(args, list): args = UnionArg(args) - if origin in (ABCMutableSequence, list): + if origin in (ABCMuSeq, list): return SequenceArg(args) - if origin in (ABCSequence, ABCIterable, tuple): + if origin in (ABCSeq, tuple): return SequenceArg(args, form="tuple") - if origin in (ABCMutableSet, ABCSet, set): + if origin in (ABCMuSet, ABCSet, set): return SequenceArg(args, form="set") - return BasePattern("", PatternModel.KEEP, origin, alias=f"{repr(item).split('.')[-1]}", accepts=[origin]) - - if isinstance(item, (list, tuple, set)): - return UnionArg(list(map(argument_type_validator, item))) - if isinstance(item, dict): - return MappingArg( - arg_key=argument_type_validator(list(item.keys())[0], 'ignore'), - arg_value=argument_type_validator(list(item.values())[0], 'allow') + return BasePattern("", 0, origin, alias=f"{repr(item).split('.')[-1]}", accepts=[origin]) # type: ignore + if isinstance(item, (FunctionType, MethodType, LambdaType)): + if len((sig := inspect.signature(item)).parameters) != 1: + raise TypeError(f"{item} can only accept 1 argument") + anno = list(sig.parameters.values())[0].annotation + return BasePattern( + accepts=[] if anno == Empty else list(_) if (_ := get_args(anno)) else [anno], converter=item, + origin=Any if sig.return_annotation == Empty else sig.return_annotation, model=PatternModel.TYPE_CONVERT ) - if item is None or type(None) == item: - return Empty if isinstance(item, str): + if "|" in item: + names = item.split("|") + return UnionArg(args_type_parser(i) for i in names if i) return BasePattern(item, alias=f"\'{item}\'") + if isinstance(item, (list, tuple, set, ABCSeq, ABCMuSeq, ABCSet, ABCMuSet)): # Args[foo, [123, int]] + return UnionArg(map(args_type_parser, item)) + if isinstance(item, (dict, ABCMap, ABCMuMap)): # Args[foo, {'foo': 'bar'}] + return BasePattern(model=PatternModel.TYPE_CONVERT, origin=Any, converter=lambda x: item.get(x, None)) + if item is None or type(None) == item: + return Empty if extra == "ignore": return AnyOne elif extra == "reject": - raise TypeError(lang_config.types_validate_reject.format(target=item)) + raise TypeError(config.lang.types_validate_reject.format(target=item)) if inspect.isclass(item): return BasePattern.of(item) - return item + return BasePattern.on(item) class Bind: @@ -540,25 +474,23 @@ def __new__(cls, *args, **kwargs): @classmethod @lru_cache(maxsize=None) def __class_getitem__(cls, params): - if not isinstance(params, tuple) or len(params) != 2: - raise TypeError( - "Bind[...] should be used with only two arguments (a type and an annotation)." - ) - if not (pattern := pattern_map.get(params[0]) if not isinstance(params[0], BasePattern) else params[0]): - raise ValueError( - "Bind[...] first argument should be a BasePattern." - ) - if not callable(params[1]): - raise TypeError( - "Bind[...] second argument should be a callable." - ) - pattern = copy(pattern) - pattern.validator = params[1] + if not isinstance(params, tuple) or len(params) < 2: + raise TypeError("Bind[...] should be used with only two arguments (a type and an annotation).") + if not (pattern := params[0] if isinstance(params[0], BasePattern) else pattern_map.get(params[0])): + raise ValueError("Bind[...] first argument should be a BasePattern.") + if not all(callable(i) for i in params[1:]): + raise TypeError("Bind[...] second argument should be a callable.") + pattern = deepcopy(pattern) + pattern.validators.extend(params[1:]) return pattern +def set_unit(target: Type[TOrigin], predicate: Callable[..., bool]) -> Annotated[TOrigin, ...]: + return Annotated[target, predicate] + + __all__ = [ - "DataUnit", "DataCollection", "Empty", "AnyOne", "AllParam", "PatternModel", - "BasePattern", "MultiArg", "SequenceArg", "UnionArg", "MappingArg", "Bind", - "pattern_gen", "pattern_map", "set_converter", "set_converters", "remove_converter", "argument_type_validator" + "DataCollection", "Empty", "AnyOne", "AllParam", "PatternModel", "BasePattern", "MultiArg", "UnionArg", "Bind", + "pattern_map", "set_converter", "set_converters", "remove_converter", "args_type_parser", "set_unit", + "SequenceArg", "MappingArg", "TPattern" ] diff --git a/src/arclet/alconna/util.py b/src/arclet/alconna/util.py index 6bbb300e..a653a517 100644 --- a/src/arclet/alconna/util.py +++ b/src/arclet/alconna/util.py @@ -1,14 +1,20 @@ """杂物堆""" import contextlib -import random - +import inspect +from functools import lru_cache from collections import OrderedDict from datetime import datetime, timedelta -from typing import TypeVar, Optional, Dict, Any, Iterator, Generic, Hashable, Tuple, Set, Union, get_origin, get_args +from typing import TypeVar, Optional, Dict, Any, Iterator, Hashable, Tuple, Union, Mapping +from typing_extensions import get_origin, get_args R = TypeVar('R') +@lru_cache(4096) +def is_async(o: Any): + return inspect.iscoroutinefunction(o) or inspect.isawaitable(o) + + class Singleton(type): """单例模式""" _instances = {} @@ -23,12 +29,13 @@ def remove(mcs, cls): mcs._instances.pop(cls, None) -def split_once(text: str, separates: Union[str, Set[str]]): # 相当于另类的pop, 不会改变本来的字符串 +@lru_cache(4096) +def split_once(text: str, separates: Union[str, Tuple[str, ...]]): # 相当于另类的pop, 不会改变本来的字符串 """单次分隔字符串""" out_text = "" quotation = "" is_split = True - separates = separates if isinstance(separates, set) else {separates} + separates = tuple(separates) for char in text: if char in {"'", '"'}: # 遇到引号括起来的部分跳过分隔 if not quotation: @@ -40,11 +47,11 @@ def split_once(text: str, separates: Union[str, Set[str]]): # 相当于另类 if char in separates and is_split: break out_text += char - result = "".join(out_text) - return result, text[len(result) + 1:] + return out_text, text[len(out_text) + 1:] -def split(text: str, separates: Optional[Set[str]] = None): +@lru_cache(4096) +def split(text: str, separates: Optional[Tuple[str, ...]] = None): """尊重引号与转义的字符串切分 Args: @@ -54,59 +61,48 @@ def split(text: str, separates: Optional[Set[str]] = None): Returns: List[str]: 切割后的字符串, 可能含有空格 """ - separates = separates or {" "} - result = [] + separates = separates or (" ",) + result = "" quotation = "" - cache = "" for index, char in enumerate(text): if char in {"'", '"'}: if not quotation: quotation = char if index and text[index - 1] == "\\": - cache += char + result += char elif char == quotation: quotation = "" if index and text[index - 1] == "\\": - cache += char - elif char in {"\n", "\r"} or (not quotation and char in separates and cache): - result.append(cache) - cache = "" - elif char != "\\" and (char not in separates or quotation): - cache += char - if cache: - result.append(cache) - return result + result += char + elif char in {"\n", "\r"} or (not quotation and char in separates): + result += "\0" + elif char != "\\": + result += char + return result.split('\0') def levenshtein_norm(source: str, target: str) -> float: """编辑距离算法, 计算源字符串与目标字符串的相似度, 取值范围[0, 1], 值越大越相似""" - return 1 - float(levenshtein(source, target)) / max(len(source), len(target)) - - -def levenshtein(source: str, target: str) -> int: - """编辑距离算法的具体内容""" - s_range = range(len(source) + 1) - t_range = range(len(target) + 1) + l_s, l_t = len(source), len(target) + s_range, t_range = range(l_s + 1), range(l_t + 1) matrix = [[(i if j == 0 else j) for j in t_range] for i in s_range] for i in s_range[1:]: for j in t_range[1:]: - del_distance = matrix[i - 1][j] + 1 - ins_distance = matrix[i][j - 1] + 1 - sub_trans_cost = 0 if source[i - 1] == target[j - 1] else 1 - sub_distance = matrix[i - 1][j - 1] + sub_trans_cost - matrix[i][j] = min(del_distance, ins_distance, sub_distance) - return matrix[len(source)][len(target)] + sub_distance = matrix[i - 1][j - 1] + (0 if source[i - 1] == target[j - 1] else 1) + matrix[i][j] = min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, sub_distance) + + return 1 - float(matrix[l_s][l_t]) / max(l_s, l_t) _K = TypeVar("_K", bound=Hashable) _V = TypeVar("_V") +_T = TypeVar("_T") -class LruCache(Generic[_K, _V]): +class LruCache(Mapping[_K, _V]): max_size: int cache: OrderedDict - __size: int record: Dict[_K, Tuple[datetime, timedelta]] __slots__ = ("max_size", "cache", "record", "__size") @@ -117,17 +113,16 @@ def __init__(self, max_size: int = -1) -> None: self.record = {} self.__size = 0 - def __getitem__(self, key: _K) -> _V: + def get(self, key: _K, default: Optional[_T] = None) -> Union[_V, _T]: if key in self.cache: self.cache.move_to_end(key) return self.cache[key] - raise KeyError(key) + return default - def get(self, key: _K, default: Any = None) -> _V: - try: - return self[key] - except KeyError: - return default + def __getitem__(self, item): + if res := self.get(item): + return res + raise ValueError def query_time(self, key: _K) -> datetime: if key in self.cache: @@ -146,8 +141,6 @@ def set(self, key: _K, value: Any, expiration: int = 0) -> None: self.record[key] = (datetime.now(), timedelta(seconds=expiration)) def delete(self, key: _K) -> None: - if key not in self.cache: - raise KeyError(key) self.cache.pop(key) self.record.pop(key) @@ -173,13 +166,6 @@ def __iter__(self) -> Iterator[_K]: def __repr__(self) -> str: return repr(self.cache) - def update(self) -> None: - now = datetime.now() - key = random.choice(list(self.cache.keys())) - expire = self.record[key][1] - if expire.total_seconds() > 0 and now > self.record[key][0] + expire: - self.delete(key) - def update_all(self) -> None: now = datetime.now() for key in self.cache.keys(): @@ -189,41 +175,21 @@ def update_all(self) -> None: @property def recent(self) -> Optional[_V]: - if self.cache: - try: - return self.cache[list(self.cache.keys())[-1]] - except KeyError: - return None + with contextlib.suppress(KeyError): + return self.cache[list(self.cache.keys())[-1]] return None - def items(self, size: int = -1) -> Iterator[Tuple[_K, _V]]: - if size == -1: - return iter(self.cache.items()) - try: - return iter(list(self.cache.items())[:-size:-1]) - except IndexError: - return iter(self.cache.items()) + def keys(self): + return self.cache.keys() + def values(self): + return self.values() -def generic_issubclass(cls: type, par: Union[type, Any, Tuple[type, ...]]) -> bool: - """ - 检查 cls 是否是 par 中的一个子类, 支持泛型, Any, Union, GenericAlias - """ - if par is Any: - return True - with contextlib.suppress(TypeError): - if isinstance(par, (type, tuple)): - return issubclass(cls, par) - if issubclass(cls, get_origin(par)): # type: ignore - return True - if get_origin(par) is Union: - return any(generic_issubclass(cls, p) for p in get_args(par)) - if isinstance(par, TypeVar): - if par.__constraints__: - return any(generic_issubclass(cls, p) for p in par.__constraints__) - if par.__bound__: - return generic_issubclass(cls, par.__bound__) - return False + def items(self, size: int = -1) -> Iterator[Tuple[_K, _V]]: + if size > 0: + with contextlib.suppress(IndexError): + return iter(list(self.cache.items())[:-size:-1]) + return iter(self.cache.items()) def generic_isinstance(obj: Any, par: Union[type, Any, Tuple[type, ...]]) -> bool: @@ -235,13 +201,8 @@ def generic_isinstance(obj: Any, par: Union[type, Any, Tuple[type, ...]]) -> boo with contextlib.suppress(TypeError): if isinstance(par, (type, tuple)): return isinstance(obj, par) - if isinstance(obj, get_origin(par)): # type: ignore + if isinstance(obj, get_origin(par)): # type: ignore return True if get_origin(par) is Union: return any(generic_isinstance(obj, p) for p in get_args(par)) - if isinstance(par, TypeVar): - if par.__constraints__: - return any(generic_isinstance(obj, p) for p in par.__constraints__) - if par.__bound__: - return generic_isinstance(obj, par.__bound__) return False diff --git a/test_alconna/analyser_test.py b/test_alconna/analyser_test.py new file mode 100644 index 00000000..5362ce0a --- /dev/null +++ b/test_alconna/analyser_test.py @@ -0,0 +1,70 @@ +from typing import Union +from arclet.alconna.builtin.analyser import DefaultCommandAnalyser +from arclet.alconna import Alconna, Args +from arclet.alconna.typing import set_unit + + +def test_filter_out(): + DefaultCommandAnalyser.filter_out.append("int") + ana = Alconna("ana", Args["foo", str]) + assert ana.parse(["ana", 123, "bar"]).matched is True + assert ana.parse("ana bar").matched is True + DefaultCommandAnalyser.filter_out.remove("int") + assert ana.parse(["ana", 123, "bar"]).matched is False + + +def test_preprocessor(): + DefaultCommandAnalyser.preprocessors["float"] = lambda x: int(x) + ana1 = Alconna("ana1", Args["bar", int]) + assert ana1.parse(["ana1", 123.06]).matched is True + assert ana1.parse(["ana1", 123.06]).bar == 123 + del DefaultCommandAnalyser.preprocessors["float"] + assert ana1.parse(["ana1", 123.06]).matched is False + + +def test_with_set_unit(): + class Segment: + type: str + data: dict + + def __init__(self, type_: str, **data): + self.type = type_ + self.data = data + self.data.setdefault("type", type_) + + def __repr__(self): + data = self.data.copy() + if self.type == "text": + return data.get("text", "") + return f"[{self.type}:{self.data}]" + + @staticmethod + def text(content: str): + return Segment("text", text=content) + + @staticmethod + def face(id_: int, content: str = ''): + return Segment("face", id=id_, content=content) + + @staticmethod + def at(user_id: Union[int, str]): + return Segment("at", qq=str(user_id)) + + DefaultCommandAnalyser.text_sign = "plain" + DefaultCommandAnalyser.preprocessors['Segment'] = lambda x: str(x) if x.type == "text" else None + + face = set_unit(Segment, lambda x: x.type == "face") + at = set_unit(Segment, lambda x: x.type == "at") + + ana2 = Alconna("ana2", Args["foo", at]["bar", face]) + res = ana2.parse([Segment.text("ana2"), Segment.at(123456), Segment.face(103)]) + assert res.matched is True + assert res.foo.data['qq'] == '123456' + + DefaultCommandAnalyser.text_sign = 'text' + del DefaultCommandAnalyser.preprocessors['Segment'] + + +if __name__ == '__main__': + import pytest + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/args_test.py b/test_alconna/args_test.py new file mode 100644 index 00000000..ba88d18b --- /dev/null +++ b/test_alconna/args_test.py @@ -0,0 +1,154 @@ +from typing import Union + +from arclet.alconna import Args +from arclet.alconna.analysis.base import analyse_args +from arclet.alconna.typing import BasePattern, PatternModel, Bind + + +def test_kwargs_create(): + arg = Args(pak=str, upgrade=str) + assert arg == Args.pak[str]["upgrade", str] + assert analyse_args(arg, "arclet-alconna bar") == {"pak": "arclet-alconna", "upgrade": 'bar'} + + +def test_magic_create(): + arg1 = Args.round[float]["test", bool]["aaa", str] + assert len(arg1) == 3 + arg1 = arg1 << Args.perm[str, ...] + ["month", int] + assert len(arg1) == 5 + arg1["foo"] = ["bar", ...] + assert len(arg1) == 6 + arg11: Args = Args.baz[int] + arg11.add_argument("foo", value=int, default=1) + assert len(arg11) == 2 + + +def test_type_convert(): + arg2 = Args.round[float]["test", bool] + assert analyse_args(arg2, "1.2 False") != {'round': '1.2', 'test': 'False'} + assert analyse_args(arg2, "1.2 False") == {'round': 1.2, 'test': False} + assert analyse_args(arg2, "a False", raise_exception=False) != {'round': 'a', 'test': False} + + +def test_regex(): + arg3 = Args.foo["abc[0-9]{3}"] + assert analyse_args(arg3, "abc123") == {"foo": "abc123"} + assert analyse_args(arg3, "abc", raise_exception=False) != {"foo": "abc"} + + +def test_string(): + arg4 = Args["foo"]["bar"] + assert analyse_args(arg4, "foo bar") == {"foo": "foo", "bar": "bar"} + + +def test_default(): + arg5 = Args.foo[int]["de", bool, True] + assert analyse_args(arg5, "123 False") == {"foo": 123, "de": False} + assert analyse_args(arg5, "123") == {"foo": 123, "de": True} + + +def test_separate(): + arg6 = Args.foo[str]["bar", int] / ";" + assert analyse_args(arg6, 'abc;123') == {'foo': 'abc', 'bar': 123} + + +def test_object(): + arg7 = Args.foo[str]["bar", 123] + assert analyse_args(arg7, ['abc', 123]) == {'foo': 'abc', 'bar': 123} + assert analyse_args(arg7, ['abc', 124], raise_exception=False) != {'foo': 'abc', 'bar': 124} + + +def test_multi(): + arg8 = Args().add_argument("multi", value=str, flags="S") + assert analyse_args(arg8, "a b c d").get('multi') == ("a", "b", "c", "d") + arg8_1 = Args().add_argument("kwargs", value=str, flags="W") + assert analyse_args(arg8_1, "a=b c=d").get('kwargs') == {"a": "b", "c": "d"} + + +def test_anti(): + arg9 = Args().add_argument("anti", value=r"(.+?)/(.+?)\.py", flags="A") + assert analyse_args(arg9, "a/b.mp3") == {"anti": "a/b.mp3"} + assert analyse_args(arg9, "a/b.py", raise_exception=False) != {"anti": "a/b.py"} + + +def test_choice(): + arg10 = Args.choice[("a", "b", "c")] + assert analyse_args(arg10, "a") == {"choice": "a"} + assert analyse_args(arg10, "d", raise_exception=False) != {"choice": "d"} + arg10_1 = Args.mapping[{"a": 1, "b": 2, "c": 3}] + assert analyse_args(arg10_1, "a") == {"mapping": 1} + assert analyse_args(arg10_1, "d", raise_exception=False) != {"mapping": "d"} + + +def test_union(): + arg11 = Args.bar[Union[int, float]] + assert analyse_args(arg11, "1.2") == {"bar": 1.2} + assert analyse_args(arg11, "1") == {"bar": 1} + assert analyse_args(arg11, "abc", raise_exception=False) != {"bar": "abc"} + arg11_1 = Args.bar[[int, float, "abc"]] + assert analyse_args(arg11_1, "1.2") == analyse_args(arg11, "1.2") + assert analyse_args(arg11_1, "abc") == {"bar": "abc"} + assert analyse_args(arg11_1, "cba", raise_exception=False) != {"bar": "cba"} + arg11_2 = Args.bar["int|float"] + assert analyse_args(arg11_2, "1.2") == analyse_args(arg11_1, "1.2") + + +def test_force(): + arg12 = Args.foo[str].add_argument("bar", value=bool, flags="F") + assert analyse_args(arg12, ['123', True]) == {'bar': True, 'foo': '123'} + assert analyse_args(arg12, ['123', 'True'], raise_exception=False) != {'bar': True, 'foo': '123'} + + +def test_optional(): + arg13 = Args.foo[str].add_argument("bar", value=int, flags="O") + assert analyse_args(arg13, 'abc 123') == {'foo': 'abc', 'bar': 123} + assert analyse_args(arg13, 'abc') == {'foo': 'abc'} + + +def test_kwonly(): + arg14 = Args.foo[str].add_argument("bar", value=int, flags="K") + assert analyse_args(arg14, 'abc bar=123') == {'foo': 'abc', 'bar': 123} + assert analyse_args(arg14, 'abc 123', raise_exception=False) != {'foo': 'abc', 'bar': 123} + arg14_1 = Args["--width;OK", int, 1280]["--height;OK", int, 960] + assert analyse_args(arg14_1, "--width=960 --height=960") == {"--width": 960, "--height": 960} + + +def test_pattern(): + test_type = BasePattern("(.+?).py", PatternModel.REGEX_CONVERT, list, lambda x: x.split("/"), "test") + arg15 = Args().add_argument("bar", value=test_type) + assert analyse_args(arg15, 'abc.py') == {'bar': ['abc']} + assert analyse_args(arg15, 'abc/def.py') == {'bar': ['abc', 'def']} + assert analyse_args(arg15, 'abc/def.mp3', raise_exception=False) != {'bar': ['abc', 'def']} + + +def test_callable(): + def test(foo: str, bar: int, baz: bool = False): + ... + + arg16, _ = Args.from_callable(test) + assert len(arg16.argument) == 3 + assert analyse_args(arg16, "abc 123 True") == {"foo": "abc", "bar": 123, "baz": True} + + +def test_func_anno(): + from datetime import datetime + + def test(time: Union[int, str]) -> datetime: + return datetime.fromtimestamp(time) if isinstance(time, int) else datetime.fromisoformat(time) + + arg17 = Args["time", test] + assert analyse_args(arg17, "1145-05-14") == {"time": datetime.fromisoformat("1145-05-14")} + + +def test_annotated(): + from typing_extensions import Annotated + + arg18 = Args["foo", Annotated[int, lambda x: x > 0]]["bar", Bind[int, lambda x: x < 0]] + assert analyse_args(arg18, "123 -123") == {"foo": 123, "bar": -123} + assert analyse_args(arg18, "0 0", raise_exception=False) != {"foo": 0, "bar": 0} + + +if __name__ == '__main__': + import pytest + + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/base_test.py b/test_alconna/base_test.py new file mode 100644 index 00000000..5aeeaa49 --- /dev/null +++ b/test_alconna/base_test.py @@ -0,0 +1,74 @@ +from arclet.alconna.base import Option, Subcommand, CommandNode, Args +from arclet.alconna.analysis.base import analyse_option, analyse_subcommand + + +def test_node_create(): + node = CommandNode("foo", Args.bar[int], dest="test") + assert node.name == "foo" + assert node.dest != "foo" + assert node.nargs == 1 + + +def test_string_args(): + node1 = CommandNode("foo", "bar:int") + assert node1.args == Args.bar[int] + + +def test_node_requires(): + node2 = CommandNode("foo", requires=["baz", "qux"]) + assert node2.dest == "baz_qux_foo" + node2_1 = CommandNode("baz qux foo") + assert node2_1.name == "foo" + assert node2_1.requires == ["baz", "qux"] + + +def test_option_aliases(): + opt = Option("test|T|t") + assert opt.aliases == {"test", "T", "t"} + opt_1 = Option("test", alias=["T", "t"]) + assert opt_1.aliases == {"test", "T", "t"} + assert opt == opt_1 + assert opt == Option("T|t|test") + + +def test_option_requires(): + opt1 = Option("foo bar test|T|t") + assert opt1.aliases == {"test", "T", "t"} + assert opt1.requires == ["foo", "bar"] + opt1_1 = Option("foo bar test| T | t") + assert opt1_1.aliases != {"test", "T", "t"} + + +def test_separator(): + opt2 = Option("foo", Args.bar[int], separators="|") + assert analyse_option(opt2, "foo|123") == ("foo", {"args": {"bar": 123}, "value": None}) + opt2_1 = Option("foo", Args.bar[int]).separate("|") + assert opt2 == opt2_1 + + +def test_subcommand(): + sub = Subcommand("test", options=[Option("foo"), Option("bar")]) + assert len(sub.options) == 2 + assert analyse_subcommand(sub, "test foo") == ( + "test", {"value": None, "args": {}, 'options': {"foo": {"args": {}, "value": Ellipsis}}} + ) + + +def test_compact(): + opt3 = Option("foo", Args.bar[int], separators="") + assert opt3.is_compact is True + assert analyse_option(opt3, "foo123") == ("foo", {"args": {"bar": 123}, "value": None}) + + +def test_from_callable(): + def test(bar: int, baz: bool = False): + ... + + opt4 = Option("foo", action=test) + assert len(opt4.args.argument) == 2 + assert analyse_option(opt4, "foo 123 True") == ("foo", {"args": {"bar": 123, "baz": True}, "value": None}) + + +if __name__ == '__main__': + import pytest + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/components_test.py b/test_alconna/components_test.py new file mode 100644 index 00000000..69464df4 --- /dev/null +++ b/test_alconna/components_test.py @@ -0,0 +1,77 @@ +from arclet.alconna import Alconna, Option, Args, Subcommand, Arpamar, ArpamarBehavior, store_value +from arclet.alconna.builtin.actions import set_default, exclusion, cool_down +from arclet.alconna.components.duplication import AlconnaDuplication +from arclet.alconna.components.stub import ArgsStub, OptionStub, SubcommandStub + + +def test_behavior(): + com = Alconna("comp", Args["bar", int]) + Option("foo") + + @com.behaviors.append + class Test(ArpamarBehavior): + requires = [set_default(321, option="foo")] + + @classmethod + def operate(cls, interface: "Arpamar"): + print('\ncom: ') + print(interface.query("options.foo.value")) + interface.matched = False + + assert com.parse("comp 123").matched is False + + +def test_set_defualt(): + com1 = Alconna("comp1") + Option("foo", action=store_value(123)) + Option("bar", action=store_value(234)) + com1.behaviors.append(set_default(321, option="bar")) + assert com1.parse("comp1 bar").query("bar.value") == 234 + assert com1.parse("comp1 foo").query("bar.value") == 321 + + +def test_exclusion(): + com2 = Alconna( + "comp2", + options=[Option("foo"), Option("bar")], + behaviors=[exclusion(target_path="options.foo", other_path="options.bar")] + ) + assert com2.parse("comp2 foo").matched is True + assert com2.parse("comp2 bar").matched is True + assert com2.parse("comp2 foo bar").matched is False + + +def test_cooldown(): + import time + com3 = Alconna("comp3", Args["bar", int], behaviors=[cool_down(0.3)]) + print('') + for i in range(4): + time.sleep(0.2) + print(com3.parse("comp3 {}".format(i))) + + +def test_duplication(): + class Demo(AlconnaDuplication): + testArgs: ArgsStub + bar: OptionStub + sub: SubcommandStub + + com4 = Alconna( + "comp4", Args["foo", int], + options=[ + Option("--bar", Args["bar", str]), + Subcommand("sub", options=[Option("--sub1", Args["baz", str])]) + ] + ) + res = com4.parse("comp4 123 --bar abc sub --sub1 xyz") + assert res.matched is True + duplication = com4.parse("comp4 123 --bar abc sub --sub1 xyz", duplication=Demo) + assert isinstance(duplication, Demo) + assert duplication.testArgs.available is True + assert duplication.testArgs.foo == 123 + assert duplication.bar.available is True + assert duplication.bar.args.bar == 'abc' + assert duplication.sub.available is True + assert duplication.sub.option("sub1").args.first_arg == 'xyz' + + +if __name__ == '__main__': + import pytest + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/config_test.py b/test_alconna/config_test.py new file mode 100644 index 00000000..fae3e4bf --- /dev/null +++ b/test_alconna/config_test.py @@ -0,0 +1,23 @@ +from arclet.alconna.config import config +from arclet.alconna import Alconna, Option + + +def test_config(): + config.separators = {";"} + cfg = Alconna("cfg") + Option("foo") + assert cfg.parse("cfg foo").matched is False + assert cfg.parse("cfg;foo").matched is True + config.separators = {' '} + + +def test_alconna_config(): + Alconna.config(headers=["!"]) + cfg1 = Alconna("cfg1") + assert cfg1.parse("cfg1").matched is False + assert cfg1.parse("!cfg1").matched is True + Alconna.config(headers=['']) + + +if __name__ == '__main__': + import pytest + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/construct_test.py b/test_alconna/construct_test.py new file mode 100644 index 00000000..b35869fe --- /dev/null +++ b/test_alconna/construct_test.py @@ -0,0 +1,123 @@ +from typing import Optional + +from arclet.alconna import AlconnaString, AlconnaFormat, AlconnaFire, AlconnaDecorate, delegate, Args + + +def test_koishi_like(): + con = AlconnaString("con ") + assert con.parse("con https://www.example.com").matched is True + con_1 = AlconnaString("con_1", "--foo [bar:bool]", "--bar &True") + assert con_1.parse("con_1 --bar").query("bar.value") is True + assert con_1.parse("con_1 --foo").query("foo.args") == {"foo": "123"} + con_2 = AlconnaString("[!con_2|/con_2] <...bar>") + assert con_2.parse("!con_2 112 334").matched is True + assert con_2.parse("con_2 112 334").matched is False + + +def test_format_like(): + con1 = AlconnaFormat("con1 {artist} {title:str} singer {name:str}") + print('') + print(repr(con1.get_help())) + assert con1.parse("con1 Nameless MadWorld").artist == "Nameless" + con1_1 = AlconnaFormat("con1_1 user {target}", {"target": str}) + assert con1_1.parse("con1_1 user Nameless").query("user.target") == "Nameless" + con1_2 = AlconnaFormat( + "con1_2 user {target} perm set {perm} {default}", + {"target": str, "perm": str, "default": Args["default", bool, True]}, + ) + print(repr(con1_2.get_help())) + assert con1_2.parse("con1_2 user Nameless perm set Admin.set True").query("perm_set.default") is True + + +def test_fire_like_class(): + class Test: + """测试从类中构建对象""" + + def __init__(self, sender: Optional[str]): + """Constructor""" + self.sender = sender + + def talk(self, name="world"): + """Test Function""" + print(f"Hello {name} from {self.sender}") + + class Repo: + def set(self, name): + print(f"set {name}") + + class SubConfig: + description = "sub-test" + + class Config: + command = "con2" + description = "测试" + extra = "reject" + get_subcommand = True + + con2 = AlconnaFire(Test) + assert con2.parse("con2 Alc talk Repo set hhh").matched is True + assert con2.parse("con2 talk Friend").query("talk.name") == "Friend" + print('') + print(repr(con2.get_help())) + print(con2.instance) + + +def test_fire_like_object(): + class Test: + def __init__(self, action=sum): + self.action = action + + def calculator(self, a, b, c, *nums: int, **kwargs: str): + """calculator""" + print(a, b, c) + print(nums, kwargs) + print(self.action(nums)) + + class Config: + command = "con3" + + con3 = AlconnaFire(Test(sum)) + print('') + print(con3.get_help()) + assert con3.parse("con3 calculator 1 2 3 4 5 d=6 f=7") + + +def test_fire_like_func(): + def test_function(name="world"): + """测试从函数中构建对象""" + + class Config: + command = "con4" + description = "测试" + + print("Hello {}!".format(name)) + + con4 = AlconnaFire(test_function) + assert con4.parse("con4 Friend").matched is True + + +def test_delegate(): + @delegate + class con5: + """hello""" + prefix = "!" + + print(repr(con5.get_help())) + + +def test_click_like(): + con6 = AlconnaDecorate() + + @con6.build_command("con6") + @con6.option("--count", Args["num", int], help="Test Option Count") + @con6.option("--foo", Args["bar", str], help="Test Option Foo") + def hello(bar: str, num: int = 1): + """测试DOC""" + print(bar * num) + + assert hello("con6 --foo John --count 2").matched is True + + +if __name__ == '__main__': + import pytest + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/core_test.py b/test_alconna/core_test.py new file mode 100644 index 00000000..8cbb89d1 --- /dev/null +++ b/test_alconna/core_test.py @@ -0,0 +1,275 @@ +from arclet.alconna.core import AlconnaGroup +from arclet.alconna import Alconna, Args, Option, Subcommand, ArgParserTextFormatter, AllParam + + +def test_alconna_create(): + alc = Alconna( + "core", Args["foo", str], + headers=["!"], options=[ + Option("bar", Args["num", int]) + ] + ) + assert alc.path == "Alconna.core" + assert alc.parse("!core abc bar 123").matched is True + + +def test_alconna_multi_match(): + alc1 = Alconna( + headers=["/", "!"], + command="core1", + options=[ + Subcommand( + "test", [Option("-u", Args["username", str], help_text="输入用户名")], args=Args["test", "Test"], + help_text="测试用例"), + Option("-n|--num", Args["count", int, 123], help_text="输入数字"), + Option("-u", Args(id=int), help_text="输入需要At的用户") + ], + main_args=Args["IP", 'ip'], + help_text="测试指令1" + ) + assert len(alc1.options) == 5 + print('') + print(repr(alc1.get_help())) + res1 = alc1.parse(["/core1 -u", 123, "test Test -u AAA -n 222 127.0.0.1"]) + assert res1.matched is True + assert res1.query("num.count") == 222 + assert res1.query("test.u.username") == "AAA" + res2 = alc1.parse(["/core1 127.0.0.1 -u", 321]) + assert res2.IP == "127.0.0.1" + res3 = alc1.parse("/core1 aa") + assert res3.matched is False + assert res3.head_matched is True + + +def test_special_header(): + alc2 = Alconna("RD{r:int}?=={e:int}") + res = alc2.parse("RD100==36") + assert res.matched is True + assert res.header['r'] == '100' + assert res.header['e'] == '36' + + +def test_formatter(): + alc3 = Alconna( + command="/pip", + options=[ + Subcommand("install", [ + Option("--upgrade", help_text="升级包"), + Option("-i|--index-url", Args["url", "url"]) + ], Args["pak", str], help_text="安装一个包"), + Option("--retries", Args["retries", int], help_text="设置尝试次数"), + Option("-t|--timeout", Args["sec", int], help_text="设置超时时间"), + Option("--exists-action", Args["action", str], help_text="添加行为"), + Option("--trusted-host", Args["host", str], help_text="选择可信赖地址") + ], + help_text="简单的pip指令", + formatter_type=ArgParserTextFormatter + ) + print('') + print(alc3.get_help()) + res = alc3.parse( + "/pip install alconna --upgrade -i https://pypi.douban.com/simple -t 6 --trusted-host pypi.douban.com" + ) + assert res.matched is True + assert res.all_matched_args['sec'] == 6 + assert res.all_matched_args['pak'] == 'alconna' + assert res.all_matched_args['url'] == "https://pypi.douban.com/simple" + assert res.all_matched_args['host'] == 'pypi.douban.com' + + +def test_alconna_special_help(): + alc4 = Alconna( + "Cal", + help_text="计算器 Usage: Cal ; Example: Cal -sum 1 2;", + options=[ + Subcommand( + "-div", + options=[Option("--round|-r", Args.decimal[int], help_text="保留n位小数")], + args=Args(num_a=int, num_b=int), help_text="除法计算" + ) + ], + ) + print('') + print(alc4.get_help()) + res = alc4.parse("Cal -div 12 23 --round 2") + assert res.query("div.args") == {"num_a": 12, "num_b": 23} + + +def test_alconna_chain_option(): + alc5 = Alconna( + "点歌" + ).add( + "歌名", sep=":", args=Args(song_name=str) + ).add( + "歌手", sep=":", args=Args(singer_name=str) + ) + res = alc5.parse("点歌 歌名:Freejia") + assert res.song_name == "Freejia" + + +def test_alconna_multi_header(): + class A: + pass + + a, b = A(), A() + alc6 = Alconna("core6", headers=["/", "!", "."]) + assert alc6.parse("!core6").head_matched is True + assert alc6.parse("#core6").head_matched is False + assert alc6.parse([a]).head_matched is False + alc6_1 = Alconna("core6_1", headers=["/", a]) + assert alc6_1.parse("/core6_1").head_matched is True + assert alc6_1.parse([a, "core6_1"]).head_matched is True + assert alc6_1.parse([b, "core6_1"]).head_matched is False + alc6_2 = Alconna("core6_2", headers=[(a, "/")]) + assert alc6_2.parse([a, "/core6_2"]).head_matched is True + assert alc6_2.parse([a, "core6_2"]).head_matched is False + assert alc6_2.parse("/core6_2").head_matched is False + alc6_3 = Alconna(A) + assert alc6_3.parse([a]).head_matched is True + assert alc6_3.parse([b]).head_matched is True + assert alc6_3.parse('a').head_matched is False + alc6_4 = Alconna(A, headers=["/", b]) + assert alc6_4.parse(["/", a]).head_matched is True + assert alc6_4.parse([b, b]).head_matched is True + assert alc6_4.parse([b, a]).head_matched is True + assert alc6_4.parse([b]).head_matched is False + assert alc6_4.parse([b, "abc"]).head_matched is False + alc6_5 = Alconna(headers=["/dd", "!cd"]) + assert alc6_5.parse("/dd").head_matched is True + assert alc6_5.parse("/dd !cd").matched is False + alc6_6 = Alconna(1234) # type: ignore + assert alc6_6.parse([1234]).head_matched is True + assert alc6_6.parse([4321]).head_matched is False + + +def test_alconna_namespace(): + alc7 = Alconna("core7", namespace="Test") + assert alc7.path == "Test.core7" + alc7_1 = Alconna("core7_1").reset_namespace("Test") + assert alc7_1.path == "Test.core7_1" + alc7_2 = "Test" / Alconna("core7_2") + assert alc7_2.path == "Test.core7_2" + + +def test_alconna_add_option(): + alc8 = Alconna("core8") + Option("foo", Args["foo", str]) >> Option("bar") + assert len(alc8.options) == 4 + alc8_1 = Alconna("core8_1") + "foo/bar:str" >> "baz" + assert len(alc8_1.options) == 4 + + +def test_alconna_action(): + def test(wild, text: str, num: int, boolean: bool = False): + print('wild:', wild) + print('text:', text) + print('num:', num) + print('boolean:', boolean) + + alc9 = Alconna("core9", action=test) + print("") + print("alc9: -----------------------------") + alc9.parse("core9 abc def 123 False") + print("alc9: -----------------------------") + + +def test_alconna_synthesise(): + alc10 = Alconna( + main_args=Args["min", r"(\d+)张"]["max;O", r"(\d+)张"] / "到", + headers=["发涩图", "来点涩图", "来点好康的"], + options=[Option("从", Args["tags;5", str] / ("和", "与"), separators='')], + action=lambda x, y: (int(x), int(y)) + ) + res = alc10.parse("来点涩图 3张到6张 从女仆和能天使与德克萨斯和拉普兰德与莫斯提马") + assert res.matched is True + assert res.min == 3 + assert res.tags == ("女仆", "能天使", "德克萨斯", "拉普兰德", "莫斯提马") + + +def test_simple_override(): + alc11 = Alconna("core11") + Option("foo", Args["bar", str]) + Option("foo") + res = alc11.parse("core11 foo abc") + res1 = alc11.parse("core11 foo") + assert res.matched is True + assert res1.matched is True + + +def test_requires(): + alc12 = Alconna( + "core12", + Args["target", int], + options=[ + Option("user perm set", Args["foo", str], help_text="set user permission"), + Option("user perm del", Args["foo", str], help_text="del user permission"), + Option("group perm set", Args["foo", str], help_text="set group permission"), + Option("group perm del", Args["foo", str], help_text="del group permission"), + Option("test") + ] + ) + + assert alc12.parse("core12 123 user perm set 123").find("user_perm_set") is True + assert alc12.parse("core12 123 user perm del 123").find("user_perm_del") is True + assert alc12.parse("core12 123 group perm set 123").find("group_perm_set") is True + assert alc12.parse("core12 123 group perm del 123 test").find("group_perm_del") is True + print('\n------------------------') + print(alc12.get_help()) + + +def test_wildcard(): + alc13 = Alconna("core13", Args["foo", AllParam]) + assert alc13.parse(["core13 abc def gh", 123, 5.0, "dsdf"]).foo == ['abc', 'def', 'gh', 123, 5.0, 'dsdf'] + + +def test_alconna_group(): + alc14 = AlconnaGroup( + "core14", + Alconna("core14", options=[Option("--foo"), Option("--bar", Args["num", int])]), + Alconna("core14", options=[Option("--baz"), Option("--qux", Args["num", int])]), + ) + assert alc14.parse("core14 --foo --bar 123").matched is True + assert alc14.parse("core14 --baz --qux 123").matched is True + print("\n---------------------------") + print(alc14.get_help()) + + +def test_fuzzy(): + alc15 = Alconna(main_args="foo:str", headers=['!core15'], is_fuzzy_match=True) + assert alc15.parse("core15 foo bar").matched is False + + +def test_shortcut(): + alc16 = Alconna("core16", Args["foo", int], options=[Option("bar")]) + assert alc16.parse("core16 123 bar").matched is True + alc16.shortcut("TEST", "core16 432 bar") + res = alc16.parse("TEST") + assert res.matched is True + assert res.foo == 432 + + +def test_help(): + alc17 = Alconna("core17") + Option("foo", Args["bar", str]) + alc17.parse("core17 --help") + alc17.parse("core17 --help foo") + alc17_1 = Alconna( + "core17_1", options=[Option("foo bar abc baz", Args["qux", int]), Option("foo qux bar", Args["baz", str])] + ) + alc17_1.parse("core17_1 --help") + alc17_1.parse("core17_1 --help aaa") + + +def test_hide_annotation(): + alc18 = Alconna("core18", Args["foo", int]) + print(alc18.get_help()) + alc18_1 = Alconna("core18_1", Args["foo;H", int]) + print(alc18_1.get_help()) + + +def test_args_notice(): + alc19 = Alconna("core19", Args["foo#A TEST;O", int]) + Option("bar", Args["baz#ANOTHER TEST;K", str]) + print('') + print(alc19.get_help()) + + +if __name__ == '__main__': + import pytest + + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/entry_test.py b/test_alconna/entry_test.py new file mode 100644 index 00000000..667ed947 --- /dev/null +++ b/test_alconna/entry_test.py @@ -0,0 +1,13 @@ +from test_alconna.args_test import * +from test_alconna.base_test import * +from test_alconna.type_test import * +from test_alconna.util_test import * +from test_alconna.core_test import * +from test_alconna.construct_test import * +from test_alconna.components_test import * +from test_alconna.config_test import * +from test_alconna.analyser_test import * + +if __name__ == '__main__': + import pytest + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/type_test.py b/test_alconna/type_test.py new file mode 100644 index 00000000..e522e751 --- /dev/null +++ b/test_alconna/type_test.py @@ -0,0 +1,119 @@ +from arclet.alconna.typing import DataCollection, BasePattern, PatternModel, args_type_parser +from arclet.alconna.builtin.pattern import ObjectPattern + + +def test_collection(): + """测试数据集合协议, 要求__str__、__iter__和__len__""" + assert isinstance("abcdefg", DataCollection) + assert isinstance(["abcd", "efg"], DataCollection) + assert isinstance({"a": 1}, DataCollection) + assert isinstance([123, 456, 7.0, {"a": 1}], DataCollection) + assert issubclass(list, DataCollection) + + +def test_pattern_of(): + """测试 BasePattern 的快速创建方法之一, 对类有效""" + pat = BasePattern.of(int) + assert pat.origin == int + assert pat.validate(123)[0] == 123 + assert pat.validate('abc')[1] == 'E' + + +def test_pattern_on(): + """测试 BasePattern 的快速创建方法之一, 对对象有效""" + pat1 = BasePattern.on(123) + assert pat1.origin == int + assert pat1.validate(123)[0] == 123 + assert pat1.validate(124)[1] == 'E' + + +def test_pattern_keep(): + """测试 BasePattern 的保持模式, 不会进行匹配或者类型转换""" + pat2 = BasePattern(model=PatternModel.KEEP) + assert pat2.validate(123)[0] == 123 + assert pat2.validate("abc")[0] == "abc" + + +def test_pattern_regex(): + """测试 BasePattern 的正则匹配模式, 仅正则匹配""" + pat3 = BasePattern("abc[A-Z]+123", PatternModel.REGEX_MATCH) + assert pat3.validate("abcABC123")[0] == "abcABC123" + assert pat3.validate("abcAbc123")[1] == "E" + + +def test_pattern_regex_convert(): + """测试 BasePattern 的正则转换模式, 正则匹配成功后再进行类型转换""" + pat4 = BasePattern(r"\[at:(\d+)\]", PatternModel.REGEX_CONVERT, int, lambda x: int(x)) + assert pat4.validate("[at:123456]")[0] == 123456 + assert pat4.validate("[at:abcdef]")[1] == "E" + assert pat4.validate(123456)[0] == 123456 + + +def test_pattern_type_convert(): + """测试 BasePattern 的类型转换模式, 仅将传入对象变为另一类型的新对象""" + pat5 = BasePattern(model=PatternModel.TYPE_CONVERT, origin=str, converter=lambda x: str(x)) + assert pat5.validate(123)[0] == "123" + assert pat5.validate([4, 5, 6])[0] == "[4, 5, 6]" + + +def test_pattern_accepts(): + """测试 BasePattern 的输入类型筛选, 不在范围内的类型视为非法""" + pat6 = BasePattern(model=PatternModel.TYPE_CONVERT, origin=str, converter=lambda x: x.decode(), accepts=[bytes]) + assert pat6.validate(b'123')[0] == "123" + assert pat6.validate(123)[1] == 'E' + pat6_1 = BasePattern(model=PatternModel.KEEP, accepts=[int, float]) + assert pat6_1.validate(123)[0] == 123 + assert pat6_1.validate('123')[1] == 'E' + + +def test_pattern_previous(): + """测试 BasePattern 的前置表达式, 在传入的对象类型不正确时会尝试用前置表达式进行预处理""" + + class A: + def __repr__(self): + return '123' + + pat7 = BasePattern(model=PatternModel.TYPE_CONVERT, origin=str, converter=lambda x: f"abc[{x}]") + pat7_1 = BasePattern( + r"abc\[(\d+)\]", model=PatternModel.REGEX_CONVERT, origin=int, converter=lambda x: int(x), previous=pat7 + ) + assert pat7_1.validate("abc[123]")[0] == 123 + assert pat7_1.validate(A())[0] == 123 + + +def test_pattern_anti(): + """测试 BasePattern 的反向验证功能""" + pat8 = BasePattern.of(int) + assert pat8.validate(123)[1] == 'V' + assert pat8.invalidate(123)[1] == 'E' + + +def test_pattern_validator(): + """测试 BasePattern 的匹配后验证器, 会对匹配结果进行验证""" + pat9 = BasePattern(model=PatternModel.KEEP, origin=int, validators=[lambda x: x > 0]) + assert pat9.validate(23)[0] == 23 + assert pat9.validate(-23)[1] == 'E' + + +def test_args_parser(): + pat10 = args_type_parser(int) + assert pat10.validate(-321)[1] == 'V' + pat10_1 = args_type_parser(123) + assert pat10_1 == BasePattern.on(123) + + +def test_object_pattern(): + class A: + def __init__(self, username: str, userid: int): + self.name = username + self.id = userid + + pat11 = ObjectPattern(A, flag='urlget') + + assert pat11.validate("username=abcd&userid=123")[1] == 'V' + + +if __name__ == '__main__': + import pytest + + pytest.main([__file__, "-vs"]) diff --git a/test_alconna/util_test.py b/test_alconna/util_test.py new file mode 100644 index 00000000..b27918ca --- /dev/null +++ b/test_alconna/util_test.py @@ -0,0 +1,55 @@ +from arclet.alconna.util import split_once, split, LruCache +from arclet.alconna.builtin.checker import simple_type + + +def test_split_once(): + """测试单次分割函数, 能以引号扩起空格""" + text1 = "rrr b bbbb" + text2 = "\'rrr b\' bbbb" + assert split_once(text1, ' ') == ('rrr', 'b bbbb') + assert split_once(text2, ' ') == ("'rrr b'", 'bbbb') + + +def test_split(): + """测试分割函数, 能以引号扩起空格, 并允许保留引号""" + text1 = "rrr b bbbb" + text2 = "\'rrr b\' bbbb" + text3 = "\\\'rrr b\\\' bbbb" + assert split(text1) == ["rrr", "b", "bbbb"] + assert split(text2) == ["rrr b", "bbbb"] + assert split(text3) == ["'rrr b'", "bbbb"] + + +def test_lru(): + """测试 LRU缓存""" + cache: LruCache[str, str] = LruCache(3) + cache.set("a", "a") + cache.set("b", "b") + cache.set("c", "c") + assert cache.recent == "c" + _ = cache.get("a") + print(f"\n{cache}") + assert cache.recent == "a" + cache.set("d", "d") + assert cache.get("b", Ellipsis) == Ellipsis + + +def test_checker(): + @simple_type() + def hello(num: int): + return num + + assert hello(123) == 123 + assert hello("123") == 123 # type: ignore + + @simple_type() + def test(foo: 'bar'): # type: ignore + return foo + + assert test("bar") == "bar" + assert test("foo") is None + + +if __name__ == '__main__': + import pytest + pytest.main([__file__, "-vs"])