From c0548e3846770f8a37fb5b7a2f7f562e837f3373 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt <3165388245@qq.com> Date: Mon, 16 Oct 2023 17:36:28 +0800 Subject: [PATCH] :art: black --- src/arclet/alconna/_internal/_analyser.py | 44 +++++------ src/arclet/alconna/_internal/_argv.py | 17 +++-- src/arclet/alconna/_internal/_handlers.py | 62 +++++++++------- src/arclet/alconna/_internal/_header.py | 32 ++++---- src/arclet/alconna/_internal/_util.py | 4 +- src/arclet/alconna/action.py | 2 + src/arclet/alconna/args.py | 19 +++-- src/arclet/alconna/argv.py | 8 +- src/arclet/alconna/arparma.py | 42 +++++++---- src/arclet/alconna/base.py | 54 +++++++------- src/arclet/alconna/builtin.py | 51 +++++++++---- src/arclet/alconna/completion.py | 13 ++-- src/arclet/alconna/config.py | 2 + src/arclet/alconna/core.py | 36 ++++----- src/arclet/alconna/duplication.py | 6 +- src/arclet/alconna/formatter.py | 90 +++++++++++------------ src/arclet/alconna/manager.py | 87 +++++++++++----------- src/arclet/alconna/model.py | 8 +- src/arclet/alconna/model.pyi | 8 +- src/arclet/alconna/output.py | 2 + src/arclet/alconna/stub.py | 15 ++-- src/arclet/alconna/typing.py | 11 ++- 22 files changed, 330 insertions(+), 283 deletions(-) diff --git a/src/arclet/alconna/_internal/_analyser.py b/src/arclet/alconna/_internal/_analyser.py index 2f72f60e..11f88743 100644 --- a/src/arclet/alconna/_internal/_analyser.py +++ b/src/arclet/alconna/_internal/_analyser.py @@ -37,11 +37,7 @@ from ..core import Alconna from ._argv import Argv -_SPECIAL = { - "help": handle_help, - "shortcut": handle_shortcut, - "completion": handle_completion -} +_SPECIAL = {"help": handle_help, "shortcut": handle_shortcut, "completion": handle_completion} def _compile_opts(option: Option, data: dict[str, Sentence | Option | list[Option] | SubAnalyser]): @@ -76,7 +72,7 @@ def default_compiler(analyser: SubAnalyser, pids: set[str]): """ for opts in analyser.command.options: if isinstance(opts, Option): - if opts.compact or opts.action.type == 2 or not set(analyser.command.separators).issuperset(opts.separators): + if opts.compact or opts.action.type == 2 or not set(analyser.command.separators).issuperset(opts.separators): # noqa: E501 analyser.compact_params.append(opts) _compile_opts(opts, analyser.compile_params) # type: ignore if opts.default: @@ -100,6 +96,7 @@ def default_compiler(analyser: SubAnalyser, pids: set[str]): @dataclass class SubAnalyser(Generic[TDC]): """子解析器, 用于子命令的解析""" + command: Subcommand """子命令""" default_main_only: bool = field(default=False) @@ -157,9 +154,7 @@ def result(self) -> SubcommandResult: for k, v in self.default_sub_result.items(): if k not in self.subcommands_result: self.subcommands_result[k] = v - res = SubcommandResult( - self.value_result, self.args_result, self.options_result, self.subcommands_result - ) + res = SubcommandResult(self.value_result, self.args_result, self.options_result, self.subcommands_result) self.reset() return res @@ -233,6 +228,7 @@ def get_sub_analyser(self, target: Subcommand) -> SubAnalyser[TDC] | None: class Analyser(SubAnalyser[TDC], Generic[TDC]): """命令解析器""" + command: Alconna """命令实例""" used_tokens: set[int] @@ -252,10 +248,7 @@ def __init__(self, alconna: Alconna[TDC], compiler: TCompile | None = None): self.used_tokens = set() self.command_header = Header.generate(alconna.command, alconna.prefixes, alconna.meta.compact) compiler = compiler or default_compiler - compiler( - self, - command_manager.resolve(self.command).param_ids - ) + compiler(self, command_manager.resolve(self.command).param_ids) def _clr(self): self.used_tokens.clear() @@ -285,15 +278,15 @@ def shortcut( if isinstance(short, Arparma): return short - argv.build(short.get('command', argv.converter(self.command.command or self.command.name))) - if not short.get('fuzzy') and data: + argv.build(short.get("command", argv.converter(self.command.command or self.command.name))) + if not short.get("fuzzy") and data: exc = ParamsUnmatched(lang.require("analyser", "param_unmatched").format(target=data[0])) if self.command.meta.raise_exception: raise exc return self.export(argv, True, exc) - if short.get('fuzzy') and reg and len(trigger) > reg.span()[1]: - argv.addon((trigger[reg.span()[1]:],)) - argv.addon(short.get('args', [])) + if short.get("fuzzy") and reg and len(trigger) > reg.span()[1]: + argv.addon((trigger[reg.span()[1] :],)) + argv.addon(short.get("args", [])) data = _handle_shortcut_data(argv, data) argv.bak_data = argv.raw_data.copy() argv.addon(data) @@ -318,11 +311,7 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]: ParamsUnmatched: 参数不匹配 ArgumentMissing: 参数缺失 """ - if ( - argv.message_cache and - argv.token in self.used_tokens and - (res := command_manager.get_record(argv.token)) - ): + if argv.message_cache and argv.token in self.used_tokens and (res := command_manager.get_record(argv.token)): return res try: self.header_result = analyse_header(self.command_header, argv) @@ -363,7 +352,7 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]: rest = argv.release() if len(rest) > 0: if isinstance(rest[-1], str) and rest[-1] in argv.completion_names: - argv.bak_data[-1] = argv.bak_data[-1][:-len(rest[-1])].rstrip() + argv.bak_data[-1] = argv.bak_data[-1][: -len(rest[-1])].rstrip() return handle_completion(self, argv, rest[-2]) exc = ParamsUnmatched(lang.require("analyser", "param_unmatched").format(target=argv.next(move=False)[0])) else: @@ -388,7 +377,7 @@ def analyse(self, argv: Argv[TDC]) -> Arparma[TDC] | None: except (ParamsUnmatched, ArgumentMissing) as e1: if (rest := argv.release()) and isinstance(rest[-1], str): if rest[-1] in argv.completion_names: - argv.bak_data[-1] = argv.bak_data[-1][:-len(rest[-1])].rstrip() + argv.bak_data[-1] = argv.bak_data[-1][: -len(rest[-1])].rstrip() return handle_completion(self, argv) if handler := argv.special.get(rest[-1]): return _SPECIAL[handler](self, argv) @@ -404,7 +393,10 @@ def analyse(self, argv: Argv[TDC]) -> Arparma[TDC] | None: self.args_result = analyse_args(argv, self.self_args) def export( - self, argv: Argv[TDC], fail: bool = False, exception: BaseException | None = None, + self, + argv: Argv[TDC], + fail: bool = False, + exception: BaseException | None = None, ) -> Arparma[TDC]: """创建 `Arparma` 解析结果, 其一定是一次解析的最后部分 diff --git a/src/arclet/alconna/_internal/_argv.py b/src/arclet/alconna/_internal/_argv.py index 7cde2c69..de9a2eff 100644 --- a/src/arclet/alconna/_internal/_argv.py +++ b/src/arclet/alconna/_internal/_argv.py @@ -16,6 +16,7 @@ @dataclass(repr=True) class Argv(Generic[TDC]): """命令行参数""" + namespace: InitVar[Namespace] = field(default=config.default_namespace) fuzzy_match: bool = field(default=False) """当前命令是否模糊匹配""" @@ -23,7 +24,7 @@ class Argv(Generic[TDC]): """命令元素的预处理器""" to_text: Callable[[Any], str | None] = field(default=lambda x: x if isinstance(x, str) else None) """将命令元素转换为文本, 或者返回None以跳过该元素""" - separators: tuple[str, ...] = field(default=(' ',)) + separators: tuple[str, ...] = field(default=(" ",)) """命令分隔符""" filter_out: list[type] = field(default_factory=list) """需要过滤掉的命令元素""" @@ -60,11 +61,11 @@ def __post_init__(self, namespace: Namespace): self.reset() self.special: dict[str, str] = {} self.special.update( - [(i, "help") for i in namespace.builtin_option_name['help']] + - [(i, "completion") for i in namespace.builtin_option_name['completion']] + - [(i, "shortcut") for i in namespace.builtin_option_name['shortcut']] + [(i, "help") for i in namespace.builtin_option_name["help"]] + + [(i, "completion") for i in namespace.builtin_option_name["completion"]] + + [(i, "shortcut") for i in namespace.builtin_option_name["shortcut"]] ) - self.completion_names = namespace.builtin_option_name['completion'] + self.completion_names = namespace.builtin_option_name["completion"] if __cache := self.__class__._cache.get(self.__class__, {}): self.preprocessors.update(__cache.get("preprocessors") or {}) self.filter_out.extend(__cache.get("filter_out") or []) @@ -218,7 +219,7 @@ def free(self, separate: tuple[str, ...] | None = None): self.bak_data.insert(self.current_index + 1, _rest_text) self.raw_data.insert(self.current_index + 1, _rest_text) self.ndata += 1 - self.bak_data[self.current_index] = self.bak_data[self.current_index][:-len(_current_data)] + self.bak_data[self.current_index] = self.bak_data[self.current_index][: -len(_current_data)] self.raw_data[self.current_index] = "" else: self.bak_data.pop(self.current_index) @@ -235,10 +236,10 @@ def release(self, separate: tuple[str, ...] | None = None, recover: bool = False list[str | Any]: 剩余的数据. """ _result = [] - data = self.bak_data if recover else self.raw_data[self.current_index:] + data = self.bak_data if recover else self.raw_data[self.current_index :] for _data in data: if _data.__class__ is str: - _result.extend(split(_data, separate or (' ',), self.filter_crlf)) + _result.extend(split(_data, separate or (" ",), self.filter_crlf)) else: _result.append(_data) return _result diff --git a/src/arclet/alconna/_internal/_handlers.py b/src/arclet/alconna/_internal/_handlers.py index cc27d75e..4b075e02 100644 --- a/src/arclet/alconna/_internal/_handlers.py +++ b/src/arclet/alconna/_internal/_handlers.py @@ -35,14 +35,15 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any], result: dic return default_val = target.field.default res = value.invalidate(arg, default_val) if value.anti else value.validate(arg, default_val) - if res.flag != 'valid': + if res.flag != "valid": argv.rollback(arg) - if res.flag == 'error': + if res.flag == "error": if target.optional: return raise ParamsUnmatched(target.field.get_unmatch_tips(arg, res.error.args[0])) result[target.name] = res._value # noqa + def step_varpos(argv: Argv, args: Args, result: dict[str, Any]): value, arg = args.argument.var_positional argv.context = arg @@ -60,28 +61,26 @@ def step_varpos(argv: Argv, args: Args, result: dict[str, Any]): if not may_arg or (_str and may_arg in argv.param_ids): argv.rollback(may_arg) break - if ( - _str and kwonly_seps and - split_once(pat.match(may_arg)["name"], kwonly_seps, argv.filter_crlf)[0] in args.argument.keyword_only - ): + if _str and kwonly_seps and split_once(pat.match(may_arg)["name"], kwonly_seps, argv.filter_crlf)[0] in args.argument.keyword_only: # noqa: E501 argv.rollback(may_arg) break if _str and args.argument.var_keyword and args.argument.var_keyword[0].base.sep in may_arg: # type: ignore argv.rollback(may_arg) break - if (res := value.base.exec(may_arg)).flag != 'valid': + if (res := value.base.exec(may_arg)).flag != "valid": argv.rollback(may_arg) break _result.append(res._value) # noqa if not _result: if default_val is not None: _result = default_val if isinstance(default_val, Iterable) else () - elif value.flag == '*': + elif value.flag == "*": _result = () else: raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key))) result[key] = tuple(_result) + def step_varkey(argv: Argv, args: Args, result: dict[str, Any]): value, arg = args.argument.var_keyword argv.context = arg @@ -98,25 +97,26 @@ def step_varkey(argv: Argv, args: Args, result: dict[str, Any]): if not may_arg or (_str and may_arg in argv.param_ids) or not _str: argv.rollback(may_arg) break - if not (_kwarg := re.match(fr'^(-*[^{value.base.sep}]+){value.base.sep}(.*?)$', may_arg)): + if not (_kwarg := re.match(rf"^(-*[^{value.base.sep}]+){value.base.sep}(.*?)$", may_arg)): argv.rollback(may_arg) break key = _kwarg[1] if not (_m_arg := _kwarg[2]): _m_arg, _ = argv.next(arg.separators) - if (res := value.base.base.exec(_m_arg)).flag != 'valid': + if (res := value.base.base.exec(_m_arg)).flag != "valid": argv.rollback(may_arg) break _result[key] = res._value # noqa if not _result: if default_val is not None: _result = default_val if isinstance(default_val, dict) else {} - elif value.flag == '*': + elif value.flag == "*": _result = {} else: raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=name))) result[name] = _result + def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): kwonly_seps = set() for arg in args.argument.keyword_only.values(): @@ -140,7 +140,7 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): if args.argument.var_keyword or (_str and may_arg in argv.param_ids): break for arg in args.argument.keyword_only.values(): - if arg.value.base.exec(may_arg).flag == 'valid': # type: ignore + if arg.value.base.exec(may_arg).flag == "valid": # type: ignore raise ParamsUnmatched(lang.require("args", "key_missing").format(target=may_arg, key=arg.name)) for name in args.argument.keyword_only: if levenshtein(_key, name) >= config.fuzzy_threshold: @@ -165,6 +165,7 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): elif not arg.optional: raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key))) + def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: """ 分析 `Args` 部分 @@ -236,7 +237,7 @@ def handle_option(argv: Argv, opt: Option) -> tuple[str, OptionResult]: if opt.compact: for al in opt.aliases: if mat := re.fullmatch(f"{al}(?P.*?)", name): - argv.rollback(mat['rest'], replace=True) + argv.rollback(mat["rest"], replace=True) error = False break elif opt.action.type == 2: @@ -312,13 +313,13 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv): if param.__class__ is Option: if param.requires and analyser.sentences != param.requires: return lang.require("option", "require_error").format( - source=param.name, target=' '.join(analyser.sentences) + source=param.name, target=" ".join(analyser.sentences) ) analyse_option(analyser, argv, param) else: if param.command.requires and analyser.sentences != param.command.requires: return lang.require("subcommand", "require_error").format( - source=param.command.name, target=' '.join(analyser.sentences) + source=param.command.name, target=" ".join(analyser.sentences) ) try: param.process(argv) @@ -377,7 +378,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non if _param.__class__ is Option: if _param.requires and analyser.sentences != _param.requires: raise ParamsUnmatched( - lang.require("option", "require_error").format(source=_param.name, target=' '.join(analyser.sentences)) + lang.require("option", "require_error").format(source=_param.name, target=" ".join(analyser.sentences)) ) analyse_option(analyser, argv, _param) elif _param.__class__ is list: @@ -386,9 +387,11 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non _data, _index = argv.data_set() try: if opt.requires and analyser.sentences != opt.requires: - raise ParamsUnmatched(lang.require("option", "require_error").format( - source=opt.name, target=' '.join(analyser.sentences) - )) + raise ParamsUnmatched( + lang.require("option", "require_error").format( + source=opt.name, target=" ".join(analyser.sentences) + ) + ) analyser.sentences = [] analyse_option(analyser, argv, opt) _data.clear() @@ -403,7 +406,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non if _param.command.requires and analyser.sentences != _param.command.requires: raise ParamsUnmatched( lang.require("subcommand", "require_error").format( - source=_param.command.name, target=' '.join(analyser.sentences) + source=_param.command.name, target=" ".join(analyser.sentences) ) ) try: @@ -436,14 +439,14 @@ def analyse_header(header: Header, argv: Argv) -> HeadResult: elif content.__class__ is TPattern and (mat := content.fullmatch(head_text)): return HeadResult(head_text, head_text, True, mat.groupdict(), mapping) if header.compact and content.__class__ in (set, TPattern) and (mat := header.compact_pattern.match(head_text)): - argv.rollback(head_text[len(mat[0]):], replace=True) + argv.rollback(head_text[len(mat[0]) :], replace=True) return HeadResult(mat[0], mat[0], True, mat.groupdict(), mapping) if isinstance(content, BasePattern): if (val := content.exec(head_text, Empty)).success: return HeadResult(head_text, val.value, True, fixes=mapping) if header.compact and (val := header.compact_pattern.exec(head_text, Empty)).success: if _str: - argv.rollback(head_text[len(str(val.value)):], replace=True) + argv.rollback(head_text[len(str(val.value)) :], replace=True) return HeadResult(val.value, val.value, True, fixes=mapping) may_cmd, _m_str = argv.next() @@ -490,7 +493,7 @@ def handle_help(analyser: Analyser, argv: Argv): analyser.command.name, lambda: analyser.command.formatter.format_node(_help_param), ) - return analyser.export(argv, True, SpecialOptionTriggered('help')) + return analyser.export(argv, True, SpecialOptionTriggered("help")) _args = Args["action?", "delete|list"]["name?", str]["command", str, "$"] @@ -521,12 +524,13 @@ def handle_shortcut(analyser: Analyser, argv: Argv): output_manager.send(analyser.command.name, lambda: msg) except Exception as e: output_manager.send(analyser.command.name, lambda: str(e)) - return analyser.export(argv, True, SpecialOptionTriggered('shortcut')) + return analyser.export(argv, True, SpecialOptionTriggered("shortcut")) INDEX_SLOT = re.compile(r"\{%(\d+)\}") WILDCARD_SLOT = re.compile(r"\{\*(.*)\}", re.DOTALL) + def _gen_extend(data: list, sep: str): extend = [] for slot in data: @@ -536,6 +540,7 @@ def _gen_extend(data: list, sep: str): extend.append(slot) return extend + def _handle_multi_slot(argv: Argv, unit: str, data: list, index: int, current: int, offset: int): slot = data[index] if not isinstance(slot, str): @@ -547,9 +552,10 @@ def _handle_multi_slot(argv: Argv, unit: str, data: list, index: int, current: i argv.raw_data[current + 2] = right.strip() offset += 1 else: - argv.raw_data[current + offset] = unescape(unit.replace(f"{{%{index}}}", slot)) + argv.raw_data[current + offset] = unescape(unit.replace(f"{{%{index}}}", slot)) return offset + def _handle_shortcut_data(argv: Argv, data: list): data_len = len(data) if not data_len: @@ -573,7 +579,7 @@ def _handle_shortcut_data(argv: Argv, data: list): offset = _handle_multi_slot(argv, unit, data, index, i, offset) record.add(index) elif mat := WILDCARD_SLOT.search(unit): - extend = _gen_extend(data, mat[1] or ' ') + extend = _gen_extend(data, mat[1] or " ") if unit == f"{{*{mat[1]}}}": argv.raw_data.extend(extend) else: @@ -582,6 +588,7 @@ def _handle_shortcut_data(argv: Argv, data: list): break return [unit for i, unit in enumerate(data) if i not in record] + def _handle_shortcut_reg(argv: Argv, groups: tuple[str, ...], gdict: dict[str, str]): for j, unit in enumerate(argv.raw_data): if not isinstance(unit, str): @@ -593,6 +600,7 @@ def _handle_shortcut_reg(argv: Argv, groups: tuple[str, ...], gdict: dict[str, s unit = unit.replace(f"{{{k}}}", v) argv.raw_data[j] = unescape(unit) + def _prompt_unit(analyser: Analyser, argv: Argv, trig: Arg): if comp := trig.field.get_completion(): if isinstance(comp, str): @@ -668,4 +676,4 @@ def handle_completion(analyser: Analyser, argv: Argv, trigger: str | None = None analyser.command.name, lambda: f"{lang.require('completion', 'node')}\n* " + "\n* ".join([i.text for i in res]), ) - return analyser.export(argv, True, SpecialOptionTriggered('completion')) # type: ignore + return analyser.export(argv, True, SpecialOptionTriggered("completion")) # type: ignore diff --git a/src/arclet/alconna/_internal/_header.py b/src/arclet/alconna/_internal/_header.py index 2140fa63..671cc393 100644 --- a/src/arclet/alconna/_internal/_header.py +++ b/src/arclet/alconna/_internal/_header.py @@ -45,6 +45,7 @@ def handle_bracket(name: str, mapping: dict): class Pair: """用于匹配前缀和命令的配对""" + __slots__ = ("prefix", "pattern", "is_prefix_pat", "gd_supplier", "_match") def __init__(self, prefix: Any, pattern: TPattern | str): @@ -58,7 +59,7 @@ def _match(command: str, pbfn: Callable[..., ...], comp: bool): if command == self.pattern: return command, None if comp and command.startswith(self.pattern): - pbfn(command[len(self.pattern):], replace=True) + pbfn(command[len(self.pattern) :], replace=True) return self.pattern, None return None, None @@ -69,9 +70,10 @@ def _match(command: str, pbfn: Callable[..., ...], comp: bool): if mat := self.pattern.fullmatch(command): return command, mat if comp and (mat := self.pattern.match(command)): - pbfn(command[len(mat[0]):], replace=True) + pbfn(command[len(mat[0]) :], replace=True) return mat[0], mat return None, None + self._match = _match def match(self, _pf: Any, command: str, pbfn: Callable[..., ...], comp: bool): @@ -86,6 +88,7 @@ def match(self, _pf: Any, command: str, pbfn: Callable[..., ...], comp: bool): class Double: """用于匹配前缀和命令的组合""" + command: TPattern | BasePattern | str def __init__( @@ -126,16 +129,16 @@ def match0(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[... return (pf, cmd), (pf, val.value), True, None if comp and (val := self.comp_pattern.exec(cmd, Empty)).success: if c_str: - pbfn(cmd[len(str(val.value)):], replace=True) - return (pf, cmd), (pf, cmd[:len(str(val.value))]), True, None + pbfn(cmd[len(str(val.value)) :], replace=True) + return (pf, cmd), (pf, cmd[: len(str(val.value))]), True, None return if (val := self.patterns.exec(pf, Empty)).success: if (val2 := self.command.exec(cmd, Empty)).success: return (pf, cmd), (val.value, val2.value), True, None if comp and (val2 := self.comp_pattern.exec(cmd, Empty)).success: if c_str: - pbfn(cmd[len(str(val2.value)):], replace=True) - return (pf, cmd), (val.value, cmd[:len(str(val2.value))]), True, None + pbfn(cmd[len(str(val2.value)) :], replace=True) + return (pf, cmd), (val.value, cmd[: len(str(val2.value))]), True, None return def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool): @@ -144,7 +147,7 @@ def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[... if (val := self.patterns.exec(pf, Empty)).success and (mat := self.command.fullmatch(cmd)): return (pf, cmd), (val.value, cmd), True, mat.groupdict() if comp and (mat := self.comp_pattern.match(cmd)): - pbfn(cmd[len(mat[0]):], replace=True) + pbfn(cmd[len(mat[0]) :], replace=True) return (pf, cmd), (pf, mat[0]), True, mat.groupdict() def match(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool): @@ -160,26 +163,27 @@ def match(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., return pf, pf, True, mat.groupdict() if comp and (mat := self.comp_pattern.match(pf)): pbfn(cmd) - pbfn(pf[len(mat[0]):], replace=True) + pbfn(pf[len(mat[0]) :], replace=True) return mat[0], mat[0], True, mat.groupdict() if not c_str: return if mat := self.prefix.fullmatch((name := pf + cmd)): return name, name, True, mat.groupdict() if comp and (mat := self.comp_pattern.match(name)): - pbfn(name[len(mat[0]):], replace=True) + pbfn(name[len(mat[0]) :], replace=True) return mat[0], mat[0], True, mat.groupdict() return if (val := self.patterns.exec(pf, Empty)).success: if mat := self.command.fullmatch(cmd): return (pf, cmd), (val.value, cmd), True, mat.groupdict() if comp and (mat := self.command.match(cmd)): - pbfn(cmd[len(mat[0]):], replace=True) + pbfn(cmd[len(mat[0]) :], replace=True) return (pf, cmd), (val.value, mat[0]), True, mat.groupdict() class Header: """命令头部的匹配表达式""" + __slots__ = ("origin", "content", "mapping", "compact", "compact_pattern") def __init__( @@ -215,10 +219,10 @@ def generate( return cls((command, prefixes), cmd, mapping, compact, re.compile(f"^{_cmd}")) if isinstance(prefixes[0], tuple): return cls( - (command, prefixes), [ - Pair(h[0], re.compile(re.escape(h[1]) + _cmd) if to_regex else h[1] + _cmd) - for h in prefixes - ], mapping, compact + (command, prefixes), + [Pair(h[0], re.compile(re.escape(h[1]) + _cmd) if to_regex else h[1] + _cmd) for h in prefixes], + mapping, + compact, ) if all(isinstance(h, str) for h in prefixes): prf = "|".join(re.escape(h) for h in prefixes) diff --git a/src/arclet/alconna/_internal/_util.py b/src/arclet/alconna/_internal/_util.py index 296133e5..ffe1b86d 100644 --- a/src/arclet/alconna/_internal/_util.py +++ b/src/arclet/alconna/_internal/_util.py @@ -1,5 +1,5 @@ def levenshtein(source: str, target: str) -> float: - """ `编辑距离算法`_, 计算源字符串与目标字符串的相似度, 取值范围[0, 1], 值越大越相似 + """`编辑距离算法`_, 计算源字符串与目标字符串的相似度, 取值范围[0, 1], 值越大越相似 Args: source (str): 源字符串 @@ -33,7 +33,7 @@ def escape(string: str) -> str: def unescape(string: str) -> str: - """逆转义字符串, 自动去除空白符 """ + """逆转义字符串, 自动去除空白符""" for k, v in R_ESCAPE.items(): string = string.replace(k, v) return string.strip() diff --git a/src/arclet/alconna/action.py b/src/arclet/alconna/action.py index f2555582..c3b56f91 100644 --- a/src/arclet/alconna/action.py +++ b/src/arclet/alconna/action.py @@ -5,6 +5,7 @@ class ActType(IntEnum): """节点触发的动作类型""" + STORE = 0 """无 Args 时, 仅存储一个值, 默认为 Ellipsis; 有 Args 时, 后续的解析结果会覆盖之前的值""" APPEND = 1 @@ -22,6 +23,7 @@ class ActType(IntEnum): @dataclass(eq=True, frozen=True) class Action: """节点触发的动作""" + type: ActType value: Any diff --git a/src/arclet/alconna/args.py b/src/arclet/alconna/args.py index 67132be9..a6c7013e 100644 --- a/src/arclet/alconna/args.py +++ b/src/arclet/alconna/args.py @@ -19,7 +19,7 @@ def safe_dcls_kw(**kwargs): if sys.version_info < (3, 10): # pragma: no cover - kwargs.pop('slots') + kwargs.pop("slots") return kwargs @@ -30,7 +30,8 @@ def safe_dcls_kw(**kwargs): class ArgFlag(str, Enum): """标识参数单元的特殊属性""" - OPTIONAL = '?' + + OPTIONAL = "?" HIDDEN = "/" ANTI = "!" @@ -57,7 +58,7 @@ def display(self): def get_completion(self): """返回参数单元的补全""" - return None if not self.completion else self.completion() + return self.completion() if self.completion else None def get_unmatch_tips(self, value: Any, fallback: str): """返回参数单元的错误提示""" @@ -112,7 +113,7 @@ def __init__( notice (str, optional): 参数单元的注释. Defaults to None. flags (list[ArgFlag], optional): 参数单元的标识. Defaults to None. """ - if not isinstance(name, str) or name.startswith('$'): + if not isinstance(name, str) or name.startswith("$"): raise InvalidParam(lang.require("args", "name_error")) if not name.strip(): raise InvalidParam(lang.require("args", "name_empty")) @@ -148,7 +149,6 @@ def __add__(self, other) -> "Args": if isinstance(other, Arg): return Args(self, other) raise TypeError(f"unsupported operand type(s) for +: 'Arg' and '{other.__class__.__name__}'") - class ArgsMeta(type): @@ -172,8 +172,10 @@ def __getitem__(self, item: Union[Arg, tuple[Arg, ...], str, tuple[Any, ...]], k return self(*data) return self(Arg(key, *data)) if key else self(Arg(*data)) # type: ignore + NULL = {Empty: None, None: Empty} + class _argument(List[Arg[Any]]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -183,6 +185,7 @@ def __init__(self, *args, **kwargs): self.keyword_only: dict[str, Arg[Any]] = {} self.unpack: tuple[Arg, Args] | None = None + def gen_unpack(var: UnpackVar): unpack = Args() for field in var.fields: @@ -200,6 +203,7 @@ def gen_unpack(var: UnpackVar): var.alias = f"{var.alias}{'()' if unpack.empty else f'{unpack}'[4:]}" return unpack + class Args(metaclass=ArgsMeta): """参数集合 @@ -215,6 +219,7 @@ class Args(metaclass=ArgsMeta): >>> Args.name[str] Args('name': str) """ + argument: _argument @classmethod @@ -242,7 +247,7 @@ def from_callable(cls, target: Callable, kw_sep: str = "=") -> tuple[Args, bool] de = NULL.get(de, de) if param.kind == param.KEYWORD_ONLY: if anno == bool: - anno = KWBool(f"(?:-*no)?-*{name}", MatchMode.REGEX_CONVERT, bool, lambda _, x: not x[0].lstrip("-").startswith('no')) + anno = KWBool(f"(?:-*no)?-*{name}", MatchMode.REGEX_CONVERT, bool, lambda _, x: not x[0].lstrip("-").startswith('no')) # noqa: E501 anno = KeyWordVar(anno, sep=kw_sep) if param.kind == param.VAR_POSITIONAL: anno = MultiVar(anno, "*") @@ -346,7 +351,7 @@ def __check_vars__(self): if isinstance(arg.value.base, KeyWordVar): if self.argument.var_keyword: raise InvalidParam(lang.require("args", "duplicate_kwargs")) - if self.argument.var_positional and arg.value.base.sep in self.argument.var_positional[1].separators: + if self.argument.var_positional and arg.value.base.sep in self.argument.var_positional[1].separators: # noqa: E501 raise InvalidParam("varkey cannot use the same sep as varpos's Arg") self.argument.var_keyword = (arg.value, arg) elif self.argument.var_positional: diff --git a/src/arclet/alconna/argv.py b/src/arclet/alconna/argv.py index 32695bdd..eba69598 100644 --- a/src/arclet/alconna/argv.py +++ b/src/arclet/alconna/argv.py @@ -20,7 +20,7 @@ def argv_config( to_text: Callable[[Any], str | None] | None = None, filter_out: list[type] | None = None, checker: Callable[[Any], bool] | None = None, - converter: Callable[[str | list], TDC] | None = None + converter: Callable[[str | list], TDC] | None = None, ): """配置命令行参数 @@ -32,6 +32,6 @@ def argv_config( checker (Callable[[Any], bool] | None, optional): 检查传入命令. converter (Callable[[str | list], TDC] | None, optional): 将字符串或列表转为目标命令类型. """ - Argv._cache.setdefault( - target or __argv_type__.get(), {} - ).update({k: v for k, v in locals().items() if v is not None}) + Argv._cache.setdefault(target or __argv_type__.get(), {}).update( + {k: v for k, v in locals().items() if v is not None} + ) diff --git a/src/arclet/alconna/arparma.py b/src/arclet/alconna/arparma.py index f9e2c327..478fa124 100644 --- a/src/arclet/alconna/arparma.py +++ b/src/arclet/alconna/arparma.py @@ -14,9 +14,9 @@ from .model import HeadResult, OptionResult, SubcommandResult from .typing import TDC -T = TypeVar('T') -T1 = TypeVar('T1') -D = TypeVar('D') +T = TypeVar("T") +T1 = TypeVar("T1") +D = TypeVar("D") def _handle_opt(_pf: str, _parts: list[str], _opts: dict[str, OptionResult]): @@ -29,7 +29,7 @@ def _handle_opt(_pf: str, _parts: list[str], _opts: dict[str, OptionResult]): return _opts, _pf if (_end := _parts.pop(0)) == "value": return __src, _end - if _end == 'args': + if _end == "args": return (__src.args, _parts.pop(0)) if _parts else (__src, _end) return __src.args, _end @@ -44,7 +44,7 @@ def _handle_sub(_pf: str, _parts: list[str], _subs: dict[str, SubcommandResult]) return _subs, _pf if (_end := _parts.pop(0)) == "value": return __src, _end - if _end == 'args': + if _end == "args": return (__src.args, _parts.pop(0)) if _parts else (__src, _end) if _end == "options" and (_end in __src.options or not _parts): raise RuntimeError(lang.require("arparma", "ambiguous_name").format(target=f"{_pf}.{_end}")) @@ -59,6 +59,7 @@ def _handle_sub(_pf: str, _parts: list[str], _subs: dict[str, SubcommandResult]) class _Query(Generic[T]): source: Arparma + def __get__(self, instance: Arparma, owner: type) -> _Query[T]: self.source = instance return self @@ -77,20 +78,21 @@ def __call__(self, path: str) -> T | None: def __call__(self, path: str, default: D) -> T | D: ... - def __call__(self, path: str, default: D | None = None) -> T | D | None: + def __call__(self, path: str, default: D | None = None) -> T | D | None: """查询 `Arparma` 中的数据 Args: path (str): 要查询的路径 default (T | None, optional): 如果查询失败, 则返回该值 """ - source, endpoint = self.source.__require__(path.split('.')) + source, endpoint = self.source.__require__(path.split(".")) if source is None: return default if isinstance(source, (OptionResult, SubcommandResult)): return getattr(source, endpoint, default) if endpoint else source # type: ignore return source.get(endpoint, default) if endpoint else MappingProxyType(source) # type: ignore + class Arparma(Generic[TDC]): """承载解析结果与操作数据的接口类 @@ -105,6 +107,7 @@ class Arparma(Generic[TDC]): options (dict[str, OptionResult]): 选项匹配结果 subcommands (dict[str, SubcommandResult]): 子命令匹配结果 """ + header_match: HeadResult options: dict[str, OptionResult] subcommands: dict[str, SubcommandResult] @@ -186,6 +189,7 @@ def all_matched_args(self) -> dict[str, Any]: def token(self) -> int: """返回命令的 Token""" from .manager import command_manager + return command_manager.get_token(self) def _unpack_opts(self, _data): @@ -255,7 +259,7 @@ def call(self, target: Callable[..., T]) -> T: **self.all_matched_args, "all_args": self.all_matched_args, "options": self.options, - "subcommands": self.subcommands + "subcommands": self.subcommands, } sig = inspect.signature(target) @@ -286,8 +290,8 @@ def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult | if part in src: return src, part if part in {"options", "subcommands", "main_args", "other_args"}: - return getattr(self, part, {}), '' - return (self.all_matched_args, '') if part == "args" else (None, part) + return getattr(self, part, {}), "" + return (self.all_matched_args, "") if part == "args" else (None, part) prefix = parts.pop(0) # parts[0] if prefix in {"options", "subcommands"} and prefix in self.components: raise RuntimeError(lang.require("arparma", "ambiguous_name").format(target=prefix)) @@ -300,7 +304,8 @@ def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult | return getattr(self, prefix, {}), parts.pop(0) return None, prefix - def query_with(self, arg_type: type[T], *args): return self.query[arg_type](*args) + def query_with(self, arg_type: type[T], *args): + return self.query[arg_type](*args) def find(self, path: str) -> bool: """查询路径是否存在 @@ -345,7 +350,7 @@ def __getitem__(self, item: str | type[T] | tuple[type[T], int]) -> T | Any | No return next(i for i in self.all_matched_args.values() if generic_isinstance(i, item)) def __getattr__(self, item: str): - return self.all_matched_args.get(item, self.query(item.replace('_', '.'))) + return self.all_matched_args.get(item, self.query(item.replace("_", "."))) def __repr__(self): if not self.matched: @@ -353,9 +358,12 @@ def __repr__(self): return ", ".join([f"{a}={v}" for a, v in attrs]) else: attrs = { - "matched": self.matched, "header_match": self.header_match, - "options": self.options, "subcommands": self.subcommands, - "main_args": self.main_args, "other_args": self.other_args + "matched": self.matched, + "header_match": self.header_match, + "options": self.options, + "subcommands": self.subcommands, + "main_args": self.main_args, + "other_args": self.other_args, } return ", ".join([f"{a}={v}" for a, v in attrs.items() if v]) @@ -367,6 +375,7 @@ class ArparmaBehavior(metaclass=ABCMeta): Attributes: requires (list[ArparmaBehavior]): 该行为器所依赖的行为器 """ + record: dict[int, dict[str, tuple[Any, Any]]] = field(default_factory=dict, init=False, repr=False, hash=False) requires: list[ArparmaBehavior] = field(init=False, hash=False, repr=False) @@ -404,6 +413,7 @@ def update(self, interface: Arparma, path: str, value: Any): path (str): 要更新的路径 value (Any): 要更新的值 """ + def _update(tkn, src, pth, ep, val): _record = self.record.setdefault(tkn, {}) if isinstance(src, dict): @@ -426,7 +436,7 @@ def _update(tkn, src, pth, ep, val): @lru_cache(4096) def requirement_handler(behavior: ArparmaBehavior) -> list[ArparmaBehavior]: res = [] - for b in getattr(behavior, 'requires', []): + for b in getattr(behavior, "requires", []): res.extend(requirement_handler(b)) res.append(behavior) return res diff --git a/src/arclet/alconna/base.py b/src/arclet/alconna/base.py index 008fac98..3be33380 100644 --- a/src/arclet/alconna/base.py +++ b/src/arclet/alconna/base.py @@ -65,8 +65,12 @@ class CommandNode: """命令节点需求前缀""" def __init__( - self, name: str, args: Arg | Args | None = None, - dest: str | None = None, default: Any = None, action: Action | None = None, + self, + name: str, + args: Arg | Args | None = None, + dest: str | None = None, + default: Any = None, + action: Action | None = None, separators: str | Sequence[str] | set[str] | None = None, help_text: str | None = None, requires: str | list[str] | tuple[str, ...] | set[str] | None = None, @@ -94,13 +98,9 @@ def __init__( self.default = default self.action = action or store _handle_default(self) - self.separators = (' ',) if separators is None else ( - (separators,) if isinstance(separators, str) else tuple(separators) - ) + self.separators = (" ",) if separators is None else ((separators,) if isinstance(separators, str) else tuple(separators)) # noqa: E501 self.nargs = len(self.args.argument) - self.dest = ( - dest or (("_".join(self.requires) + "_") if self.requires else "") + self.name.lstrip('-') - ).lstrip('-') + self.dest = (dest or (("_".join(self.requires) + "_") if self.requires else "") + self.name.lstrip("-")).lstrip("-") # noqa: E501 self.help_text = help_text or self.dest self._hash = self._calc_hash() @@ -157,12 +157,17 @@ class Option(CommandNode): def __init__( self, - name: str, args: Arg | Args | None = None, alias: Iterable[str] | None = None, - dest: str | None = None, default: Any = None, action: Action | None = None, + name: str, + args: Arg | Args | None = None, + alias: Iterable[str] | None = None, + dest: str | None = None, + default: Any = None, + action: Action | None = None, separators: str | Sequence[str] | set[str] | None = None, help_text: str | None = None, requires: str | list[str] | tuple[str, ...] | set[str] | None = None, - compact: bool = False, priority: int = 0, + compact: bool = False, + priority: int = 0, ): """初始化命令选项 @@ -191,10 +196,7 @@ def __init__( self.aliases = frozenset(aliases) self.priority = priority self.compact = compact - default = ( - None if default is None else - default if isinstance(default, OptionResult) else OptionResult(default) - ) + default = None if default is None else default if isinstance(default, OptionResult) else OptionResult(default) super().__init__(name, args, dest, default, action, separators, help_text, requires) if self.separators == ("",): self.compact = True @@ -221,10 +223,7 @@ def __add__(self, other: Option | Args | Arg) -> Self | Subcommand: TypeError: 如果other不是命令选项或命令节点, 则抛出此异常 """ if isinstance(other, Option): - return Subcommand( - self.name, other, self.args, dest=self.dest, - separators=self.separators, help_text=self.help_text, requires=self.requires - ) + return Subcommand(self.name, other, self.args, dest=self.dest, separators=self.separators, help_text=self.help_text, requires=self.requires) # noqa: E501 if isinstance(other, (Arg, Args)): self.args += other self.nargs = len(self.args) @@ -246,6 +245,7 @@ def __radd__(self, other: str): """ if isinstance(other, str): from .core import Alconna + return Alconna(other, self) raise TypeError(f"unsupported operand type(s) for +: '{other.__class__.__name__}' and 'Option'") @@ -255,6 +255,7 @@ class Subcommand(CommandNode): 与命令节点不同, 子命令可以包含多个命令选项与相对于自己的子命令 """ + default: SubcommandResult | None """子命令默认值""" options: list[Option | Subcommand] @@ -264,7 +265,8 @@ def __init__( self, name: str, *args: Args | Arg | Option | Subcommand | list[Option | Subcommand], - dest: str | None = None, default: Any = None, + dest: str | None = None, + default: Any = None, separators: str | Sequence[str] | set[str] | None = None, help_text: str | None = None, requires: str | list[str] | tuple[str, ...] | set[str] | None = None, @@ -285,15 +287,8 @@ def __init__( for li in args: if isinstance(li, list): self.options.extend(li) - default = ( - None if default is None else - default if isinstance(default, SubcommandResult) else SubcommandResult(default) - ) - super().__init__( - name, - reduce(lambda x, y: x + y, [Args()] + [i for i in args if isinstance(i, (Arg, Args))]), # type: ignore - dest, default, None, separators, help_text, requires - ) + default = None if default is None else (default if isinstance(default, SubcommandResult) else SubcommandResult(default)) # noqa: E501 + super().__init__(name, reduce(lambda x, y: x + y, [Args()] + [i for i in args if isinstance(i, (Arg, Args))]), dest, default, None, separators, help_text, requires) # type: ignore # noqa: E501 def __add__(self, other: Option | Args | Arg | str) -> Self: """连接子命令与命令选项或命令节点 @@ -332,6 +327,7 @@ def __radd__(self, other: str): """ if isinstance(other, str): from .core import Alconna + return Alconna(other, self) raise TypeError(f"unsupported operand type(s) for +: '{other.__class__.__name__}' and 'Subcommand'") diff --git a/src/arclet/alconna/builtin.py b/src/arclet/alconna/builtin.py index 8a9de2a1..6308692c 100644 --- a/src/arclet/alconna/builtin.py +++ b/src/arclet/alconna/builtin.py @@ -15,21 +15,29 @@ def generate_duplication(alc: Alconna) -> type[Duplication]: """依据给定的命令生成一个解析结果的检查类。""" from .base import Option, Subcommand + options = filter(lambda x: isinstance(x, Option), alc.options) subcommands = filter(lambda x: isinstance(x, Subcommand), alc.options) - return cast(type[Duplication], type( - f"{alc.name.strip('/.-:')}Interface", - (Duplication,), { - "__annotations__": { - "args": ArgsStub, - **{opt.dest: OptionStub for opt in options}, - **{sub.dest: SubcommandStub for sub in subcommands}, - } - } - )) - - -class _MISSING_TYPE: pass + return cast( + type[Duplication], + type( + f"{alc.name.strip('/.-:')}Interface", + (Duplication,), + { + "__annotations__": { + "args": ArgsStub, + **{opt.dest: OptionStub for opt in options}, + **{sub.dest: SubcommandStub for sub in subcommands}, + } + }, + ), + ) + + +class _MISSING_TYPE: + pass + + MISSING = _MISSING_TYPE() @@ -56,17 +64,28 @@ def operate(self, interface: Arparma): @overload -def set_default(*, value: Any, path: str,) -> _SetDefault: +def set_default( + *, + value: Any, + path: str, +) -> _SetDefault: ... @overload -def set_default(*, factory: Callable[..., Any], path: str,) -> _SetDefault: +def set_default( + *, + factory: Callable[..., Any], + path: str, +) -> _SetDefault: ... def set_default( - *, value: Any = MISSING, factory: Callable[..., Any] = MISSING, path: str | None = None, + *, + value: Any = MISSING, + factory: Callable[..., Any] = MISSING, + path: str | None = None, ) -> _SetDefault: """ 设置一个选项的默认值, 在无该选项时会被设置 diff --git a/src/arclet/alconna/completion.py b/src/arclet/alconna/completion.py index 03697079..c3d57e04 100644 --- a/src/arclet/alconna/completion.py +++ b/src/arclet/alconna/completion.py @@ -20,6 +20,7 @@ class Prompt: can_use: bool = field(default=True, hash=False) removal_prefix: str | None = field(default=None, hash=False) + @dataclass class EnterResult: result: Arparma | None = None @@ -117,13 +118,13 @@ def enter(self, content: list | None = None) -> EnterResult: if not prompt.can_use: return EnterResult(exception=ValueError(lang.require("completion", "prompt_unavailable"))) if prompt.removal_prefix: - argv.bak_data[-1] = argv.bak_data[-1][:-len(prompt.removal_prefix)] + argv.bak_data[-1] = argv.bak_data[-1][: -len(prompt.removal_prefix)] argv.next(move=True) input_ = [prompt.text] if isinstance(self.trigger, ParamsUnmatched): - argv.raw_data = argv.bak_data[:max(_i, 1)] + argv.raw_data = argv.bak_data[: max(_i, 1)] argv.addon(input_) - argv.raw_data.extend(_r[max(_i, 1):]) + argv.raw_data.extend(_r[max(_i, 1) :]) else: argv.raw_data = argv.bak_data.copy() argv.addon(input_) @@ -182,10 +183,7 @@ def exit(self): def lines(self): """获取补全选项的文本列表。""" - return [ - f"{'>>' if self.index == index else '*'} {sug.text}" - for index, sug in enumerate(self.prompts) - ] + return [f"{'>>' if self.index == index else '*'} {sug.text}" for index, sug in enumerate(self.prompts)] def __repr__(self): return f"{lang.require('completion', 'node')}\n" + "\n".join(self.lines()) @@ -214,4 +212,5 @@ def fresh(self, exc: PauseTriggered): self.trigger = exc.args[1] return True + comp_ctx: ContextModel[CompSession] = ContextModel("comp_ctx") diff --git a/src/arclet/alconna/config.py b/src/arclet/alconna/config.py index c7d8aede..a4fdca3b 100644 --- a/src/arclet/alconna/config.py +++ b/src/arclet/alconna/config.py @@ -24,6 +24,7 @@ class OptionNames(TypedDict): @dataclass(init=True, repr=True) class Namespace: """命名空间配置, 用于规定同一命名空间下的选项的默认配置""" + name: str """命名空间名称""" prefixes: TPrefixes = field(default_factory=list) @@ -102,6 +103,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _AlconnaConfig: """全局配置类""" + command_max_count: int = 200 """最大命令数量""" fuzzy_threshold: float = 0.6 diff --git a/src/arclet/alconna/core.py b/src/arclet/alconna/core.py index e3d03a55..4f57d8e4 100644 --- a/src/arclet/alconna/core.py +++ b/src/arclet/alconna/core.py @@ -21,7 +21,7 @@ from .manager import ShortcutArgs, command_manager from .typing import TDC, CommandMeta, DataCollection, TPrefixes -T_Duplication = TypeVar('T_Duplication', bound=Duplication) +T_Duplication = TypeVar("T_Duplication", bound=Duplication) T = TypeVar("T") TDC1 = TypeVar("TDC1", bound=DataCollection[Any]) @@ -37,22 +37,15 @@ def handle_argv(): def add_builtin_options(options: list[Option | Subcommand], ns: Namespace) -> None: - options.append( - Option("|".join(ns.builtin_option_name['help']), help_text=lang.require("builtin", "option_help")), - ) + options.append(Option("|".join(ns.builtin_option_name["help"]), help_text=lang.require("builtin", "option_help"))) # noqa: E501 options.append( Option( - "|".join(ns.builtin_option_name['shortcut']), + "|".join(ns.builtin_option_name["shortcut"]), Args["action?", "delete|list"]["name?", str]["command", str, "$"], - help_text=lang.require("builtin", "option_shortcut") - ) - ) - options.append( - Option( - "|".join(ns.builtin_option_name['completion']), - help_text=lang.require("builtin", "option_completion") + help_text=lang.require("builtin", "option_shortcut"), ) ) + options.append(Option("|".join(ns.builtin_option_name["completion"]), help_text=lang.require("builtin", "option_completion"))) # noqa: E501 @dataclass(init=True, unsafe_hash=True) @@ -62,6 +55,7 @@ class ArparmaExecutor(Generic[T]): Attributes: target(Callable[..., T]): 目标函数 """ + target: Callable[..., T] binding: Callable[..., list[Arparma]] = field(default=lambda: [], repr=False) @@ -101,6 +95,7 @@ class Alconna(Subcommand, Generic[TDC]): ... ) >>> alc.parse("name opt opt_arg") """ + prefixes: TPrefixes """命令前缀""" command: str | Any @@ -126,7 +121,7 @@ def __init__( namespace: str | Namespace | None = None, separators: str | set[str] | Sequence[str] | None = None, behaviors: list[ArparmaBehavior] | None = None, - formatter_type: type[TextFormatter] | None = None + formatter_type: type[TextFormatter] | None = None, ): """ 以标准形式构造 `Alconna` @@ -164,10 +159,7 @@ def __init__( name = f"{self.command or self.prefixes[0]}" # type: ignore self.path = f"{self.namespace}::{name}" _args = sum((i for i in args if isinstance(i, (Args, Arg))), Args()) - super().__init__( - "ALCONNA::", - _args, *options, dest=name, separators=separators or ns_config.separators, help_text=self.meta.description - ) + super().__init__("ALCONNA::", _args, *options, dest=name, separators=separators or ns_config.separators, help_text=self.meta.description) # noqa: E501 self.name = name self.behaviors = [] for behavior in behaviors or []: @@ -235,8 +227,9 @@ def shortcut(self, key: str, args: ShortcutArgs | None = None, delete: bool = Fa if alc and alc == self: return command_manager.add_shortcut(self, key, {"command": cmd}) # type: ignore raise ValueError( - lang.require("shortcut", "recent_command_error") - .format(target=self.path, source=getattr(alc, "path", "Unknown")) + lang.require("shortcut", "recent_command_error").format( + target=self.path, source=getattr(alc, "path", "Unknown") + ) ) else: raise ValueError(lang.require("shortcut", "no_recent_command")) @@ -293,7 +286,7 @@ def parse(self, message, *, duplication: type[T_Duplication]) -> T_Duplication: def parse(self, message: TDC, *, duplication: type[T_Duplication] | None = None) -> Arparma[TDC] | T_Duplication: """命令分析功能, 传入字符串或消息链, 返回一个特定的数据集合类 - + Args: message (TDC): 命令消息 duplication (type[T_Duplication], optional): 指定的`副本`类型 @@ -321,18 +314,19 @@ def bind(self, active: bool = True): Args: active (bool, optional): 该执行器是否由 `Alconna` 主动调用, 默认为 `True` """ + def wrapper(target: Callable[..., T]) -> ArparmaExecutor[T]: ext = ArparmaExecutor(target, lambda: command_manager.get_result(self)) if active: self._executors[ext] = None return ext + return wrapper @property def exec_result(self) -> dict[str, Any]: return {ext.target.__name__: res for ext, res in self._executors.items() if res is not None} - def __truediv__(self, other) -> Self: return self.reset_namespace(other) diff --git a/src/arclet/alconna/duplication.py b/src/arclet/alconna/duplication.py index 0ceb56cb..777d0e1e 100644 --- a/src/arclet/alconna/duplication.py +++ b/src/arclet/alconna/duplication.py @@ -11,11 +11,13 @@ class Duplication: """`副本`, 用以更方便的检查、调用解析结果的类。""" + header: dict[str, str] def __init__(self, target: Arparma): from .base import Option, Subcommand from .manager import command_manager + source = command_manager.get_command(target.source) self.header = target.header.copy() for key, value in self.__annotations__.items(): @@ -30,11 +32,11 @@ def __init__(self, target: Arparma): for option in source.options: if isinstance(option, Option) and option.dest == key: setattr(self, key, OptionStub(option).set_result(target.options.get(key, None))) - elif key != 'header': + elif key != "header": setattr(self, key, target.all_matched_args.get(key, Empty)) def __repr__(self): - return f'{self.__class__.__name__}({self.__annotations__})' + return f"{self.__class__.__name__}({self.__annotations__})" def option(self, name: str) -> OptionStub | None: """获取指定名称的选项存根。""" diff --git a/src/arclet/alconna/formatter.py b/src/arclet/alconna/formatter.py index 5a3ccf7d..b7ec760e 100644 --- a/src/arclet/alconna/formatter.py +++ b/src/arclet/alconna/formatter.py @@ -63,11 +63,13 @@ class Trace: 该结构用于存放命令节点的数据,包括命令节点的头部、参数、分隔符和主体。 """ + head: dict[str, Any] args: Args separators: tuple[str, ...] body: list[Option | Subcommand] + class TextFormatter: """帮助文档格式化器 @@ -80,18 +82,23 @@ def __init__(self): def add(self, base: Alconna): """添加目标命令""" - self.ignore_names.update(base.namespace_config.builtin_option_name['help']) - self.ignore_names.update(base.namespace_config.builtin_option_name['shortcut']) - self.ignore_names.update(base.namespace_config.builtin_option_name['completion']) + self.ignore_names.update(base.namespace_config.builtin_option_name["help"]) + self.ignore_names.update(base.namespace_config.builtin_option_name["shortcut"]) + self.ignore_names.update(base.namespace_config.builtin_option_name["completion"]) pfs = base.prefixes.copy() if base.name in pfs: pfs.remove(base.name) # type: ignore res = Trace( { - 'name': base.name, 'prefix': pfs or [], 'description': base.meta.description, - 'usage': base.meta.usage, 'example': base.meta.example + "name": base.name, + "prefix": pfs or [], + "description": base.meta.description, + "usage": base.meta.usage, + "example": base.meta.example, }, - base.args, base.separators, base.options.copy() + base.args, + base.separators, + base.options.copy(), ) self.data[base] = res return self @@ -106,8 +113,9 @@ def format_node(self, parts: list | None = None): Args: parts (list | None, optional): 可能的节点路径. """ + def _handle(trace: Trace): - if not parts or parts == ['']: + if not parts or parts == [""]: return self.format(trace) _cache = resolve_requires(trace.body) _parts = [] @@ -128,20 +136,17 @@ def _handle(trace: Trace): elif i not in _visited: _opts.append(i) _visited.add(i) - return self.format(Trace( - {"name": _parts[-1], 'prefix': [], 'description': _parts[-1]}, Args(), trace.separators, - _opts - )) + return self.format( + Trace({"name": _parts[-1], 'prefix': [], 'description': _parts[-1]}, Args(), trace.separators, _opts) # noqa: E501 + ) if isinstance(_cache, Option): - return self.format(Trace( - {"name": "", "prefix": list(_cache.aliases), "description": _cache.help_text}, _cache.args, - _cache.separators, [] - )) + return self.format( + Trace({"name": "", "prefix": list(_cache.aliases), "description": _cache.help_text}, _cache.args, _cache.separators, []) # noqa: E501 + ) if isinstance(_cache, Subcommand): - return self.format(Trace( - {"name": _cache.name, "prefix": [], "description": _cache.help_text}, _cache.args, - _cache.separators, _cache.options # type: ignore - )) + return self.format( + Trace({"name": _cache.name, "prefix": [], "description": _cache.help_text}, _cache.args, _cache.separators, _cache.options) # noqa: E501 + ) return self.format(trace) return "\n".join([_handle(v) for v in self.data.values()]) @@ -187,19 +192,19 @@ def parameters(self, args: Args) -> str: """ res = "" for arg in args.argument: - if arg.name.startswith('_key_'): + if arg.name.startswith("_key_"): continue if len(arg.separators) == 1: - sep = ' ' if arg.separators[0] == ' ' else f' {arg.separators[0]!r} ' + sep = " " if arg.separators[0] == " " else f" {arg.separators[0]!r} " else: sep = f"[{'|'.join(arg.separators)!r}]" res += self.param(arg) + sep notice = [(arg.name, arg.notice) for arg in args.argument if arg.notice] return ( - f"{res}\n## {lang.require('format', 'notice')}\n " + - "\n ".join([f"{v[0]}: {v[1]}" for v in notice]) - ) if notice else res - + (f"{res}\n## {lang.require('format', 'notice')}\n " + "\n ".join([f"{v[0]}: {v[1]}" for v in notice])) + if notice + else res + ) def header(self, root: dict[str, Any], separators: tuple[str, ...]) -> str: """头部节点的描述 @@ -208,30 +213,24 @@ def header(self, root: dict[str, Any], separators: tuple[str, ...]) -> str: root (dict[str, Any]): 头部节点数据 separators (tuple[str, ...]): 分隔符 """ - help_string = f"\n{desc}" if (desc := root.get('description')) else "" - usage = f"\n{lang.require('format', 'usage')}:\n{usage}" if (usage := root.get('usage')) else "" - example = f"\n{lang.require('format', 'example')}:\n{example}" if (example := root.get('example')) else "" - prefixs = f"[{''.join(map(str, prefixs))}]" if (prefixs := root.get('prefix', [])) != [] else "" + help_string = f"\n{desc}" if (desc := root.get("description")) else "" + usage = f"\n{lang.require('format', 'usage')}:\n{usage}" if (usage := root.get("usage")) else "" + example = f"\n{lang.require('format', 'example')}:\n{example}" if (example := root.get("example")) else "" + prefixs = f"[{''.join(map(str, prefixs))}]" if (prefixs := root.get("prefix", [])) != [] else "" cmd = f"{prefixs}{root.get('name', '')}" - command_string = cmd or (root['name'] + separators[0]) + command_string = cmd or (root["name"] + separators[0]) return f"{command_string} %s{help_string}{usage}\n\n%s{example}" def opt(self, node: Option) -> str: """对单个选项的描述""" - alias_text = " ".join(node.requires) + (' ' if node.requires else '') + "|".join(node.aliases) - return ( - f"* {node.help_text}\n" - f" {alias_text}{node.separators[0]}{self.parameters(node.args)}\n" - ) + alias_text = " ".join(node.requires) + (" " if node.requires else "") + "|".join(node.aliases) + return f"* {node.help_text}\n" f" {alias_text}{node.separators[0]}{self.parameters(node.args)}\n" def sub(self, node: Subcommand) -> str: """对单个子命令的描述""" - name = " ".join(node.requires) + (' ' if node.requires else '') + node.name + name = " ".join(node.requires) + (" " if node.requires else "") + node.name opt_string = "".join( - [ - self.opt(opt).replace("\n", "\n ").replace("# ", "* ") - for opt in node.options if isinstance(opt, Option) - ] + [self.opt(opt).replace("\n", "\n ").replace("# ", "* ") for opt in node.options if isinstance(opt, Option)] ) sub_string = "".join( [ @@ -246,19 +245,14 @@ def sub(self, node: Subcommand) -> str: f" {name}{tuple(node.separators)[0]}{self.parameters(node.args)}\n" f"{sub_help}{sub_string}" f"{opt_help}{opt_string}" - ).rstrip(' ') + ).rstrip(" ") def body(self, parts: list[Option | Subcommand]) -> str: """子节点列表的描述""" option_string = "".join( - [ - self.opt(opt) for opt in parts - if isinstance(opt, Option) and opt.name not in self.ignore_names - ] - ) - subcommand_string = "".join( - [self.sub(sub) for sub in parts if isinstance(sub, Subcommand)] + [self.opt(opt) for opt in parts if isinstance(opt, Option) and opt.name not in self.ignore_names] ) + subcommand_string = "".join([self.sub(sub) for sub in parts if isinstance(sub, Subcommand)]) option_help = f"{lang.require('format', 'options')}:\n" if option_string else "" subcommand_help = f"{lang.require('format', 'subcommands')}:\n" if subcommand_string else "" return f"{subcommand_help}{subcommand_string}{option_help}{option_string}" diff --git a/src/arclet/alconna/manager.py b/src/arclet/alconna/manager.py index b0bc00ca..767eb5ae 100644 --- a/src/arclet/alconna/manager.py +++ b/src/arclet/alconna/manager.py @@ -203,22 +203,22 @@ def add_shortcut(self, target: Alconna, key: str, source: Arparma | ShortcutArgs namespace, name = self._command_part(target.path) argv = self.resolve(target) if isinstance(source, dict): - source.setdefault('fuzzy', True) - source.setdefault('prefix', False) + source.setdefault("fuzzy", True) + source.setdefault("prefix", False) if source.get("prefix") and target.prefixes: out = [] for prefix in target.prefixes: if not isinstance(prefix, str): continue _src = source.copy() - _src['command'] = argv.converter(prefix + source.get('command', str(target.command))) + _src["command"] = argv.converter(prefix + source.get("command", str(target.command))) prefix = re.escape(prefix) self.__shortcuts[f"{namespace}.{name}::{prefix}{key}"] = _src - out.append(lang.require("shortcut", "add_success").format( - shortcut=f"{prefix}{key}", target=target.path) + out.append( + lang.require("shortcut", "add_success").format(shortcut=f"{prefix}{key}", target=target.path) ) return "\n".join(out) - source['command'] = argv.converter(source.get('command', target.command or target.name)) + source["command"] = argv.converter(source.get("command", target.command or target.name)) self.__shortcuts[f"{namespace}.{name}::{key}"] = source return lang.require("shortcut", "add_success").format(shortcut=key, target=target.path) elif source.matched: @@ -243,21 +243,17 @@ def list_shortcut(self, target: Alconna) -> list[str]: continue short = self.__shortcuts[i] if isinstance(short, dict): - result.append(i.split('::')[1] + (" ..." if short.get('fuzzy') else "")) + result.append(i.split("::")[1] + (" ..." if short.get("fuzzy") else "")) else: - result.append(i.split('::')[1]) + result.append(i.split("::")[1]) return result @overload - def find_shortcut( - self, target: Alconna[TDC] - ) -> list[Union[Arparma[TDC], ShortcutArgs]]: + def find_shortcut(self, target: Alconna[TDC]) -> list[Union[Arparma[TDC], ShortcutArgs]]: ... @overload - def find_shortcut( - self, target: Alconna[TDC], query: str - ) -> tuple[Arparma[TDC] | ShortcutArgs, Match[str] | None]: + def find_shortcut(self, target: Alconna[TDC], query: str) -> tuple[Arparma[TDC] | ShortcutArgs, Match[str] | None]: ... def find_shortcut(self, target: Alconna[TDC], query: str | None = None): @@ -302,7 +298,7 @@ def get_command(self, command: str) -> Alconna: raise ValueError(command) return self.__commands[namespace][name] - def get_commands(self, namespace: str | Namespace = '') -> list[Alconna]: + def get_commands(self, namespace: str | Namespace = "") -> list[Alconna]: """获取命令列表""" if not namespace: return list(self.__analysers.keys()) @@ -312,13 +308,13 @@ def get_commands(self, namespace: str | Namespace = '') -> list[Alconna]: return [] return list(self.__commands[namespace].values()) - def test(self, message: TDC, namespace: str | Namespace = '') -> Arparma[TDC] | None: + def test(self, message: TDC, namespace: str | Namespace = "") -> Arparma[TDC] | None: """将一段命令给当前空间内的所有命令测试匹配""" for cmd in self.get_commands(namespace): if (res := cmd.parse(message)) and res.matched: return res - def broadcast(self, message: TDC, namespace: str | Namespace = '') -> WeakValueDictionary[str, Arparma[TDC]]: + def broadcast(self, message: TDC, namespace: str | Namespace = "") -> WeakValueDictionary[str, Arparma[TDC]]: """将一段命令给当前空间内的所有命令测试匹配""" data = WeakValueDictionary() for cmd in self.get_commands(namespace): @@ -334,7 +330,7 @@ def all_command_help( pages: str | None = None, footer: str | None = None, max_length: int = -1, - page: int = 1 + page: int = 1, ) -> str: """ 获取所有命令的帮助信息 @@ -349,39 +345,44 @@ def all_command_help( page (int, optional): 当前页码. Defaults to 1. """ pages = pages or lang.require("manager", "help_pages") - cmds = list(filter(lambda x: not x.meta.hide, self.get_commands(namespace or ''))) + cmds = list(filter(lambda x: not x.meta.hide, self.get_commands(namespace or ""))) header = header or lang.require("manager", "help_header") if max_length < 1: - command_string = "\n".join( - f" {str(index).rjust(len(str(len(cmds))), '0')} {slot.name} : {slot.meta.description}" - for index, slot in enumerate(cmds) - ) if show_index else "\n".join( - f" - {cmd.name} : {cmd.meta.description}" - for cmd in cmds + command_string = ( + "\n".join( + f" {str(index).rjust(len(str(len(cmds))), '0')} {slot.name} : {slot.meta.description}" + for index, slot in enumerate(cmds) + ) + if show_index + else "\n".join(f" - {cmd.name} : {cmd.meta.description}" for cmd in cmds) ) else: max_page = len(cmds) // max_length + 1 if page < 1 or page > max_page: page = 1 header += "\t" + pages.format(current=page, total=max_page) - command_string = "\n".join( - f" {str(index).rjust(len(str(page * max_length)), '0')} {cmd.name} : {cmd.meta.description}" - for index, cmd in enumerate( - cmds[(page - 1) * max_length: page * max_length], start=(page - 1) * max_length + command_string = ( + "\n".join( + f" {str(index).rjust(len(str(page * max_length)), '0')} {cmd.name} : {cmd.meta.description}" + for index, cmd in enumerate( + cmds[(page - 1) * max_length : page * max_length], start=(page - 1) * max_length + ) + ) + if show_index + else "\n".join( + f" - {cmd.name} : {cmd.meta.description}" + for cmd in cmds[(page - 1) * max_length : page * max_length] ) - ) if show_index else "\n".join( - f" - {cmd.name} : {cmd.meta.description}" - for cmd in cmds[(page - 1) * max_length: page * max_length] ) help_names = set() for i in cmds: - help_names.update(i.namespace_config.builtin_option_name['help']) + help_names.update(i.namespace_config.builtin_option_name["help"]) footer = footer or lang.require("manager", "help_footer").format(help="|".join(help_names)) return f"{header}\n{command_string}\n{footer}" def all_command_raw_help(self, namespace: str | Namespace | None = None) -> dict[str, CommandMeta]: """获取所有命令的原始帮助信息""" - cmds = list(filter(lambda x: not x.meta.hide, self.get_commands(namespace or ''))) + cmds = list(filter(lambda x: not x.meta.hide, self.get_commands(namespace or ""))) return {cmd.path: copy(cmd.meta) for cmd in cmds} def command_help(self, command: str) -> str | None: @@ -434,15 +435,15 @@ def set_record_size(self, size: int): def __repr__(self): return ( - f"Current: {hex(id(self))} in {datetime.now().strftime('%Y/%m/%d %H:%M:%S')}\n" + - "Commands:\n" + - f"[{', '.join([cmd.path for cmd in self.get_commands()])}]" + - "\nShortcuts:\n" + - "\n".join([f" {k} => {v}" for k, v in self.__shortcuts.items()]) + - "\nRecords:\n" + - "\n".join([f" [{k}]: {v[1].origin}" for k, v in enumerate(self.__record.items()[:20])]) + - "\nDisabled Commands:\n" + - f"[{', '.join(map(lambda x: x.path, self.__abandons))}]" + f"Current: {hex(id(self))} in {datetime.now().strftime('%Y/%m/%d %H:%M:%S')}\n" + + "Commands:\n" + + f"[{', '.join([cmd.path for cmd in self.get_commands()])}]" + + "\nShortcuts:\n" + + "\n".join([f" {k} => {v}" for k, v in self.__shortcuts.items()]) + + "\nRecords:\n" + + "\n".join([f" [{k}]: {v[1].origin}" for k, v in enumerate(self.__record.items()[:20])]) + + "\nDisabled Commands:\n" + + f"[{', '.join(map(lambda x: x.path, self.__abandons))}]" ) diff --git a/src/arclet/alconna/model.py b/src/arclet/alconna/model.py index 34010a25..d1e5bb8c 100644 --- a/src/arclet/alconna/model.py +++ b/src/arclet/alconna/model.py @@ -8,6 +8,7 @@ class Sentence: __slots__ = ("name",) __str__ = lambda self: self.name __repr__ = lambda self: self.name + def __init__(self, name): self.name = name @@ -16,6 +17,7 @@ def __init__(self, name): class OptionResult: __slots__ = ("value", "args") __repr__ = _repr_ + def __init__(self, value=Ellipsis, args=None): self.value = value self.args = args or {} @@ -25,6 +27,7 @@ def __init__(self, value=Ellipsis, args=None): class SubcommandResult: __slots__ = ("value", "args", "options", "subcommands") __repr__ = _repr_ + def __init__(self, value=Ellipsis, args=None, options=None, subcommands=None): self.value = value self.args = args or {} @@ -36,12 +39,11 @@ def __init__(self, value=Ellipsis, args=None, options=None, subcommands=None): class HeadResult: __slots__ = ("origin", "result", "matched", "groups") __repr__ = _repr_ + def __init__(self, origin=None, result=None, matched=False, groups=None, fixes=None): self.origin = origin self.result = result self.matched = matched self.groups = groups or {} if fixes: - self.groups.update( - {k: v.exec(self.groups[k]).value for k, v in fixes.items() if k in self.groups} # noqa - ) + self.groups.update({k: v.exec(self.groups[k]).value for k, v in fixes.items() if k in self.groups}) # noqa diff --git a/src/arclet/alconna/model.pyi b/src/arclet/alconna/model.pyi index 73953846..7638e593 100644 --- a/src/arclet/alconna/model.pyi +++ b/src/arclet/alconna/model.pyi @@ -12,6 +12,7 @@ class Sentence: Attributes: name (str): 句段名称 """ + name: str def __init__(self, name: str) -> None: ... @@ -22,6 +23,7 @@ class OptionResult: value (Any): 选项值 args (dict[str, Any]): 选项参数解析结果 """ + value: Any args: dict[str, Any] def __init__(self, value: Any = ..., args: dict[str, Any] | None = ...) -> None: ... @@ -35,6 +37,7 @@ class SubcommandResult: options (dict[str, OptionResult]): 子命令的子选项解析结果 subcommands (dict[str, SubcommandResult]): 子命令的子子命令解析结果 """ + value: Any args: dict[str, Any] options: dict[str, OptionResult] @@ -44,7 +47,7 @@ class SubcommandResult: value: Any = ..., args: dict[str, Any] | None = ..., options: dict[str, OptionResult] | None = ..., - subcommands: dict[str, SubcommandResult] | None = ... + subcommands: dict[str, SubcommandResult] | None = ..., ) -> None: ... class HeadResult: @@ -56,6 +59,7 @@ class HeadResult: matched (bool): 命令头是否匹配 groups (dict[str, Any]): 命令头匹配组 """ + origin: Any result: Any matched: bool @@ -66,5 +70,5 @@ class HeadResult: result: Any = ..., matched: bool = ..., groups: dict[str, str] | None = ..., - fixes: dict[str, BasePattern] | None = ... + fixes: dict[str, BasePattern] | None = ..., ) -> None: ... diff --git a/src/arclet/alconna/output.py b/src/arclet/alconna/output.py index d3feb52b..aefa4c8e 100644 --- a/src/arclet/alconna/output.py +++ b/src/arclet/alconna/output.py @@ -9,6 +9,7 @@ @dataclass(init=True, unsafe_hash=True) class Sender: """发送器""" + action: Callable[..., Any] """发送行为函数""" generator: Callable[[], str] @@ -23,6 +24,7 @@ def __call__(self, *args, **kwargs): @dataclass class OutputManager: """命令输出管理器""" + cache: dict[str, Callable] = field(default_factory=dict) """缓存的输出行为""" outputs: dict[str, Sender] = field(default_factory=dict) diff --git a/src/arclet/alconna/stub.py b/src/arclet/alconna/stub.py index f36228b2..098f628a 100644 --- a/src/arclet/alconna/stub.py +++ b/src/arclet/alconna/stub.py @@ -12,8 +12,8 @@ from .base import Option, Subcommand from .model import OptionResult, SubcommandResult -T = TypeVar('T') -T_Origin = TypeVar('T_Origin') +T = TypeVar("T") +T_Origin = TypeVar("T_Origin") @dataclass(init=True, eq=True) @@ -43,6 +43,7 @@ def __repr__(self): @dataclass(init=True) class ArgsStub(BaseStub[Args]): """参数存根""" + _value: dict[str, Any] = field(default_factory=dict) """解析结果""" @@ -92,7 +93,7 @@ def __len__(self): return len(self._value) def __getattribute__(self, item): - if item not in (_cache := super().__getattribute__('_value')): + if item not in (_cache := super().__getattribute__("_value")): return super().__getattribute__(item) return _cache.get(item, None) @@ -105,6 +106,7 @@ def __getitem__(self, item: int | str) -> Any: @dataclass(init=True) class OptionStub(BaseStub[Option]): """选项存根""" + args: ArgsStub = field(init=False) """选项的参数存根""" dest: str = field(init=False) @@ -116,8 +118,8 @@ class OptionStub(BaseStub[Option]): def __post_init__(self): self.dest = self._origin.dest - self.aliases = [alias.lstrip('-') for alias in self._origin.aliases] - self.name = self._origin.name.lstrip('-') + self.aliases = [alias.lstrip("-") for alias in self._origin.aliases] + self.name = self._origin.name.lstrip("-") self.args = ArgsStub(self._origin.args) def set_result(self, result: OptionResult | None): @@ -131,6 +133,7 @@ def set_result(self, result: OptionResult | None): @dataclass(init=True) class SubcommandStub(BaseStub[Subcommand]): """子命令存根""" + args: ArgsStub = field(init=False) """子命令的参数存根""" dest: str = field(init=False) @@ -144,7 +147,7 @@ class SubcommandStub(BaseStub[Subcommand]): def __post_init__(self): self.dest = self._origin.dest - self.name = self._origin.name.lstrip('-') + self.name = self._origin.name.lstrip("-") self.args = ArgsStub(self._origin.args) self.options = [OptionStub(opt) for opt in self._origin.options if isinstance(opt, Option)] self.subcommands = [SubcommandStub(sub) for sub in self._origin.options if isinstance(sub, Subcommand)] diff --git a/src/arclet/alconna/typing.py b/src/arclet/alconna/typing.py index 1ddc762c..5792beea 100644 --- a/src/arclet/alconna/typing.py +++ b/src/arclet/alconna/typing.py @@ -47,9 +47,10 @@ class CommandMeta: class KeyWordVar(BasePattern): """对具名参数的包装""" + base: BasePattern - def __init__(self, value: BasePattern | Any, sep: str = '='): + def __init__(self, value: BasePattern | Any, sep: str = "="): """构建一个具名参数 Args: @@ -67,14 +68,17 @@ def __repr__(self): class _Kw: __slots__ = () + def __getitem__(self, item): return KeyWordVar(item) + __matmul__ = __getitem__ __rmatmul__ = __getitem__ class MultiVar(BasePattern): """对可变参数的包装""" + base: BasePattern flag: Literal["+", "*"] length: int @@ -113,13 +117,14 @@ def __repr__(self): class KWBool(BasePattern): """对布尔参数的包装""" + ... class UnpackVar(BasePattern): """特殊参数,利用dataclass 的 field 生成 arg 信息,并返回dcls""" - def __init__(self, dcls: Any, kw_only: bool = False, kw_sep: str = '='): + def __init__(self, dcls: Any, kw_only: bool = False, kw_sep: str = "="): """构建一个可变参数 Args: @@ -135,7 +140,9 @@ def __init__(self, dcls: Any, kw_only: bool = False, kw_sep: str = '='): class _Up: __slots__ = () + def __mul__(self, other): return UnpackVar(other) + Up = _Up()