diff --git a/benchmark.py b/benchmark.py index f432d03e..76551c26 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,44 +1,21 @@ import time -from arclet.alconna import Alconna, Args, ANY, command_manager, namespace +from arclet.alconna import Alconna, Option, Args, command_manager import cProfile import pstats - -class Plain: - type = "Plain" - text: str - - def __init__(self, t: str): - self.text = t - - def __repr__(self): - return self.text - - -class At: - type = "At" - target: int - - def __init__(self, t: int): - self.target = t - - def __repr__(self): - return f"At:{self.target}" - - -with namespace("test") as np: - np.enable_message_cache = False - np.to_text = lambda x: x.text if x.__class__ is Plain else None - alc = Alconna( - ["."], - "test", - Args["bar", ANY] - ) +alc = Alconna( + "test", + Option("--foo", Args["f", str]), + Option("--bar", Args["b", str]), + Option("--baz", Args["z", str]), + Option("--qux", Args["q", str]), +) argv = command_manager.resolve(alc) analyser = command_manager.require(alc) -print(alc) -msg = [Plain(".test"), At(124)] +msg = ["test --qux 123"] + +print(alc.parse(msg)) count = 20000 if __name__ == "__main__": @@ -62,8 +39,6 @@ def __repr__(self): print(f"Alconna: {li / count} ns per loop with {count} loops") - command_manager.records.clear() - prof = cProfile.Profile() prof.enable() for _ in range(count): diff --git a/src/arclet/alconna/__init__.py b/src/arclet/alconna/__init__.py index cc4d3600..7a6b46c9 100644 --- a/src/arclet/alconna/__init__.py +++ b/src/arclet/alconna/__init__.py @@ -46,7 +46,7 @@ from .typing import UnpackVar as UnpackVar from .typing import Up as Up -__version__ = "2.0.0a1" +__version__ = "2.0.0a2" # backward compatibility Arpamar = Arparma diff --git a/src/arclet/alconna/_internal/_analyser.py b/src/arclet/alconna/_internal/_analyser.py index f39bac0d..56e22b50 100644 --- a/src/arclet/alconna/_internal/_analyser.py +++ b/src/arclet/alconna/_internal/_analyser.py @@ -40,26 +40,6 @@ _SPECIAL = {"help": handle_help, "shortcut": handle_shortcut, "completion": handle_completion} -def _compile_opts(option: Option, data: dict[str, Option | list[Option] | SubAnalyser]): - """处理选项 - - Args: - option (Option): 选项 - data (dict[str, Sentence | Option | list[Option] | SubAnalyser]): 编译的节点 - """ - for alias in option.aliases: - if li := data.get(alias): - if isinstance(li, SubAnalyser): - continue - if isinstance(li, list): - li.append(option) - li.sort(key=lambda x: x.priority, reverse=True) - else: - data[alias] = sorted([li, option], key=lambda x: x.priority, reverse=True) - else: - data[alias] = option - - def default_compiler(analyser: SubAnalyser, pids: set[str]): """默认的编译方法 @@ -71,7 +51,10 @@ def default_compiler(analyser: SubAnalyser, pids: set[str]): if isinstance(opts, Option) and not isinstance(opts, (Help, Shortcut, Completion)): if opts.compact or opts.action.type == 2 or not set(analyser.command.separators).issuperset(opts.separators): # noqa: E501 analyser.compact_params.append(opts) - _compile_opts(opts, analyser.compile_params) # type: ignore + for alias in opts.aliases: + if alias in analyser.compile_params and isinstance(analyser.compile_params[alias], SubAnalyser): + continue + analyser.compile_params[alias] = opts if opts.default is not Empty: analyser.default_opt_result[opts.dest] = (opts.default, opts.action) pids.update(opts.aliases) @@ -82,7 +65,7 @@ def default_compiler(analyser: SubAnalyser, pids: set[str]): default_compiler(sub, pids) if not set(analyser.command.separators).issuperset(opts.separators): analyser.compact_params.append(sub) - if sub.command.default: + if sub.command.default is not Empty: analyser.default_sub_result[opts.dest] = sub.command.default @@ -96,7 +79,7 @@ class SubAnalyser(Generic[TDC]): """命令是否只有主参数""" need_main_args: bool = field(default=False) """是否需要主参数""" - compile_params: dict[str, Option | list[Option] | SubAnalyser[TDC]] = field(default_factory=dict) + compile_params: dict[str, Option | SubAnalyser[TDC]] = field(default_factory=dict) """编译的节点""" compact_params: list[Option | SubAnalyser[TDC]] = field(default_factory=list) """可能紧凑的需要逐个解析的节点""" @@ -157,12 +140,12 @@ def reset(self): self.value_result = None self.header_result = None - def process(self, argv: Argv[TDC]) -> Self: + def process(self, argv: Argv[TDC], trigger: str | None = None) -> Self: """处理传入的参数集合 Args: argv (Argv[TDC]): 命令行参数 - + trigger (str | None, optional): 触发词. Defaults to None. Returns: Self: 自身 @@ -171,11 +154,12 @@ def process(self, argv: Argv[TDC]) -> Self: FuzzyMatchSuccess: 模糊匹配成功 """ sub = argv.context = self.command - name, _ = argv.next(sub.separators) - if name != sub.name: # 先匹配节点名称 - if argv.fuzzy_match and levenshtein(name, sub.name) >= config.fuzzy_threshold: - raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=name, target=sub.name)) - raise ParamsUnmatched(lang.require("subcommand", "name_error").format(target=name, source=sub.name)) + if not trigger: + name, _ = argv.next(sub.separators) + if name != sub.name: # 先匹配节点名称 + if argv.fuzzy_match and levenshtein(name, sub.name) >= config.fuzzy_threshold: + raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=name, target=sub.name)) + raise ParamsUnmatched(lang.require("subcommand", "name_error").format(target=name, source=sub.name)) self.value_result = sub.action.value return self.analyse(argv) @@ -221,8 +205,6 @@ class Analyser(SubAnalyser[TDC], Generic[TDC]): command: Alconna """命令实例""" - used_tokens: set[int] - """已使用的token""" command_header: Header """命令头部""" @@ -235,15 +217,10 @@ def __init__(self, alconna: Alconna[TDC], compiler: TCompile | None = None): """ super().__init__(alconna) self.fuzzy_match = alconna.meta.fuzzy_match - self.used_tokens = set() self.command_header = Header.generate(alconna.command, alconna.prefixes, alconna.meta.compact) compiler = compiler or default_compiler compiler(self, command_manager.resolve(self.command).param_ids) - def _clr(self): - self.used_tokens.clear() - super()._clr() - def __repr__(self): return f"<{self.__class__.__name__} of {self.command.path}>" @@ -283,15 +260,14 @@ def shortcut( if reg: _handle_shortcut_reg(argv, reg.groups(), reg.groupdict()) argv.bak_data = argv.raw_data.copy() - if argv.message_cache: - argv.token = argv.generate_token(argv.raw_data) return self.process(argv) - def process(self, argv: Argv[TDC]) -> Arparma[TDC]: + def process(self, argv: Argv[TDC], trigger=None) -> Arparma[TDC]: """主体解析函数, 应针对各种情况进行解析 Args: argv (Argv[TDC]): 命令行参数 + trigger (str | None, optional): 触发词. Defaults to None. Returns: Arparma[TDC]: Arparma 解析结果 @@ -301,8 +277,6 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]: InvalidParam: 参数不匹配 ArgumentMissing: 参数缺失 """ - if argv.message_cache and argv.token in self.used_tokens and (res := command_manager.get_record(argv.token)): - return res try: self.header_result = analyse_header(self.command_header, argv) except InvalidParam as e: @@ -336,7 +310,7 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]: if fail := self.analyse(argv): return fail - if argv.done and (not self.need_main_args or self.args_result): + if argv.current_index == argv.ndata and (not self.need_main_args or self.args_result): return self.export(argv) rest = argv.release() @@ -407,9 +381,6 @@ def export( result.main_args = self.args_result result.options = self.options_result result.subcommands = self.subcommands_result - if argv.message_cache: - command_manager.record(argv.token, result) - self.used_tokens.add(argv.token) self.reset() return result # type: ignore diff --git a/src/arclet/alconna/_internal/_argv.py b/src/arclet/alconna/_internal/_argv.py index bbfd143c..43104252 100644 --- a/src/arclet/alconna/_internal/_argv.py +++ b/src/arclet/alconna/_internal/_argv.py @@ -43,12 +43,9 @@ class Argv(Generic[TDC]): """备份的原始数据""" raw_data: list[str | Any] = field(init=False) """原始数据""" - token: int = field(init=False) - """命令的token""" origin: TDC = field(init=False) """原始命令""" _sep: tuple[str, ...] | None = field(init=False) - _cache: ClassVar[dict[type, dict[str, Any]]] = {} def __post_init__(self): @@ -70,16 +67,11 @@ def reset(self): self.ndata = 0 self.bak_data = [] self.raw_data = [] - self.token = 0 self.origin = "None" self._sep = None + self._next = None self.context = None - @staticmethod - def generate_token(data: list) -> int: - """命令的`token`的生成函数""" - return hash(repr(data)) - @property def done(self) -> bool: """命令是否解析完毕""" @@ -112,8 +104,6 @@ def build(self, data: TDC) -> Self: raise NullMessage(lang.require("argv", "null_message").format(target=data)) self.ndata = i self.bak_data = raw_data.copy() - if self.message_cache: - self.token = self.generate_token(raw_data) return self def addon(self, data: Iterable[str | Any]) -> Self: @@ -138,8 +128,6 @@ def addon(self, data: Iterable[str | Any]) -> Self: self.raw_data.append(d) self.ndata += 1 self.bak_data = self.raw_data.copy() - if self.message_cache: - self.token = self.generate_token(self.raw_data) return self def next(self, separate: tuple[str, ...] | None = None, move: bool = True) -> tuple[str | Any, bool]: diff --git a/src/arclet/alconna/_internal/_handlers.py b/src/arclet/alconna/_internal/_handlers.py index 4592a88c..4e95938b 100644 --- a/src/arclet/alconna/_internal/_handlers.py +++ b/src/arclet/alconna/_internal/_handlers.py @@ -27,10 +27,10 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any], result: dict[str, Any], arg: Any, _str: bool): - if value == ANY or (value == STRING and _str): + if (value is STRING and _str) or value is ANY: result[target.name] = arg return - if value == AnyString: + if value is AnyString: result[target.name] = str(arg) return default_val = target.field.default @@ -68,12 +68,12 @@ def step_varpos(argv: Argv, args: Args, result: dict[str, Any]): if _str and args.argument.var_keyword and args.argument.var_keyword[0].base.sep in may_arg: # type: ignore argv.rollback(may_arg) break - if (res := value.base.exec(may_arg)).flag != "valid": + if (res := value.base.validate(may_arg)).flag != "valid": argv.rollback(may_arg) break _result.append(res._value) # noqa if not _result: - if default_val is not None: + if default_val is not Empty: _result = default_val if isinstance(default_val, Iterable) else () elif value.flag == "*": _result = () @@ -105,12 +105,12 @@ def step_varkey(argv: Argv, args: Args, result: dict[str, Any]): key = _kwarg[1] if not (_m_arg := _kwarg[2]): _m_arg, _ = argv.next(arg.separators) - if (res := value.base.base.exec(_m_arg)).flag != "valid": + if (res := value.base.base.validate(_m_arg)).flag != "valid": argv.rollback(may_arg) break _result[key] = res._value # noqa if not _result: - if default_val is not None: + if default_val is not Empty: _result = default_val if isinstance(default_val, dict) else {} elif value.flag == "*": _result = {} @@ -143,7 +143,7 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): if args.argument.var_keyword or (_str and may_arg in argv.param_ids): break for arg in args.argument.keyword_only.values(): - if arg.value.base.exec(may_arg).flag == "valid": # type: ignore + if arg.value.base.validate(may_arg).flag == "valid": # type: ignore raise InvalidParam(lang.require("args", "key_missing").format(target=may_arg, key=arg.name)) for name in args.argument.keyword_only: if levenshtein(_key, name) >= config.fuzzy_threshold: @@ -163,8 +163,8 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): for key, arg in args.argument.keyword_only.items(): if key in result: continue - if arg.field.default is not None: - result[key] = None if arg.field.default is Empty else arg.field.default + if arg.field.default is not Empty: + result[key] = arg.field.default elif not arg.optional: raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key))) @@ -180,7 +180,7 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: Returns: dict[str, Any]: 解析结果 """ - result: dict[str, Any] = {} + result = {} for arg in args.argument.normal: argv.context = arg may_arg, _str = argv.next(arg.separators) @@ -188,14 +188,14 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: if argv.special[may_arg] not in argv.namespace.disable_builtin_options: raise SpecialOptionTriggered(argv.special[may_arg]) if _str and may_arg in argv.param_ids and arg.optional: - if (de := arg.field.default) is not None: - result[arg.name] = None if de is Empty else de + if (de := arg.field.default) is not Empty: + result[arg.name] = de argv.rollback(may_arg) continue if not may_arg: argv.rollback(may_arg) - if (de := arg.field.default) is not None: - result[arg.name] = None if de is Empty else de + if (de := arg.field.default) is not Empty: + result[arg.name] = de elif not arg.optional: raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=arg.name))) continue @@ -212,8 +212,8 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: unpack.separate(*arg.separators) result[arg.name] = arg.value.origin(**analyse_args(argv, unpack)) except Exception as e: - if (de := arg.field.default) is not None: - result[arg.name] = None if de is Empty else de + if (de := arg.field.default) is not Empty: + result[arg.name] = de elif not arg.optional: raise e if args.argument.var_positional: @@ -226,41 +226,42 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: return result -def handle_option(argv: Argv, opt: Option) -> tuple[str, OptionResult]: +def handle_option(argv: Argv, opt: Option, trigger: str | None = None) -> tuple[str, OptionResult]: """ 处理 `Option` 部分 Args: argv (Argv): 命令行参数 opt (Option): 目标 `Option` + trigger (str | None, optional): 触发的选项名. """ argv.context = opt _cnt = 0 error = True - name, _ = argv.next(opt.separators) - if opt.compact: - for al in opt.aliases: - if mat := re.fullmatch(f"{al}(?P.*?)", name): - argv.rollback(mat["rest"], replace=True) - error = False - break - elif opt.action.type == 2: - for al in opt.aliases: - if name.startswith(al) and (cnt := (len(name.lstrip("-")) / len(al.lstrip("-")))).is_integer(): - _cnt = int(cnt) - error = False - break - elif name in opt.aliases: - error = False - if error: - if argv.fuzzy_match and levenshtein(name, opt.name) >= config.fuzzy_threshold: - raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=name, target=opt.name)) - raise InvalidParam(lang.require("option", "name_error").format(source=opt.name, target=name)) - name = opt.dest + if not trigger: + name, _ = argv.next(opt.separators) + if opt.compact: + for al in opt.aliases: + if mat := re.fullmatch(f"{al}(?P.*?)", name): + argv.rollback(mat["rest"], replace=True) + error = False + break + elif opt.action.type == 2: + for al in opt.aliases: + if name.startswith(al) and (cnt := (len(name.lstrip("-")) / len(al.lstrip("-")))).is_integer(): + _cnt = int(cnt) + error = False + break + elif name in opt.aliases: + error = False + if error: + if argv.fuzzy_match and levenshtein(name, opt.name) >= config.fuzzy_threshold: + raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=name, target=opt.name)) + raise InvalidParam(lang.require("option", "name_error").format(source=opt.name, target=name)) return ( - (name, OptionResult(None, analyse_args(argv, opt.args))) + (opt.dest, OptionResult(None, analyse_args(argv, opt.args))) if opt.nargs - else (name, OptionResult(_cnt or opt.action.value)) + else (opt.dest, OptionResult(_cnt or opt.action.value)) ) @@ -285,7 +286,7 @@ def handle_action(param: Option, source: OptionResult, target: OptionResult): return source -def analyse_option(analyser: SubAnalyser, argv: Argv, opt: Option): +def analyse_option(analyser: SubAnalyser, argv: Argv, opt: Option, trigger: str | None = None): """ 分析 `Option` 部分 @@ -293,8 +294,9 @@ def analyse_option(analyser: SubAnalyser, argv: Argv, opt: Option): analyser (SubAnalyser): 当前解析器 argv (Argv): 命令行参数 opt (Option): 目标 `Option` + trigger (str | None, optional): 触发的选项名. """ - opt_n, opt_v = handle_option(argv, opt) + opt_n, opt_v = handle_option(argv, opt, trigger) if opt_n not in analyser.options_result: analyser.options_result[opt_n] = opt_v if opt.action.type == 1 and opt_v.args: @@ -347,53 +349,33 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non argv (Argv): 命令行参数 seps (tuple[str, ...], optional): 指定的分隔符. """ - _text, _str = argv.next(seps, move=False) - if _str and _text in argv.special: - if argv.special[_text] not in argv.namespace.disable_builtin_options: - if _text in argv.completion_names: - argv.bak_data[argv.current_index] = argv.bak_data[argv.current_index].replace(_text, "") + _text, _str = argv.next(seps, move=True) + if _str: + if _text in argv.special and argv.special[_text] not in argv.namespace.disable_builtin_options: + # if _text in argv.completion_names: + # argv.bak_data[argv.current_index] = argv.bak_data[argv.current_index].replace(_text, "") raise SpecialOptionTriggered(argv.special[_text]) - if not _str or not _text: - _param = None - elif _text in analyser.compile_params: - _param = analyser.compile_params[_text] - elif analyser.compact_params and (res := analyse_compact_params(analyser, argv)): - if res.__class__ is str: - raise InvalidParam(res) + if _text in analyser.compile_params: + _param = analyser.compile_params[_text] + if _param.__class__ is Option: + analyse_option(analyser, argv, _param, _text) + else: + try: + _param.process(argv, _text) + finally: + analyser.subcommands_result[_param.command.dest] = _param.result() + argv.context = None + return True + argv.rollback(_text) + if analyser.compact_params and analyse_compact_params(analyser, argv): argv.context = None return True - else: - _param = None - if not _param and analyser.command.nargs and not analyser.args_result: + if analyser.command.nargs and not analyser.args_result: analyser.args_result = analyse_args(argv, analyser.self_args) if analyser.args_result: argv.context = None return True - if _param.__class__ is Option: - analyse_option(analyser, argv, _param) - elif _param.__class__ is list: - exc: Exception | None = None - for opt in _param: - _data, _index = argv.data_set() - try: - analyse_option(analyser, argv, opt) - _data.clear() - exc = None - break - except Exception as e: - exc = e - argv.data_reset(_data, _index) - if exc: - raise exc # type: ignore # noqa - elif _param is not None: - try: - _param.process(argv) - finally: - analyser.subcommands_result[_param.command.dest] = _param.result() - else: - return False - argv.context = None - return True + return False def analyse_header(header: Header, argv: Argv) -> HeadResult: @@ -480,7 +462,6 @@ def handle_help(analyser: Analyser, argv: Argv): def handle_shortcut(analyser: Analyser, argv: Argv): """处理快捷命令触发""" - argv.next() try: opt_v = analyse_args(argv, _args) except SpecialOptionTriggered: diff --git a/src/arclet/alconna/args.py b/src/arclet/alconna/args.py index e03cb0d8..51cca3dd 100644 --- a/src/arclet/alconna/args.py +++ b/src/arclet/alconna/args.py @@ -165,6 +165,8 @@ def __getitem__(self, item: Union[Arg, tuple[Arg, ...], str, tuple[str, Any], tu class _argument(List[Arg[Any]]): + __slots__ = ("unpack", "var_positional", "var_keyword", "keyword_only", "normal") + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.normal: list[Arg[Any]] = [] diff --git a/src/arclet/alconna/argv.py b/src/arclet/alconna/argv.py index eba69598..27898155 100644 --- a/src/arclet/alconna/argv.py +++ b/src/arclet/alconna/argv.py @@ -16,20 +16,14 @@ def set_default_argv_type(argv_type: type[Argv]): def argv_config( target: type[Argv] | None = None, - preprocessors: dict[type, Callable[..., Any]] | None = None, to_text: Callable[[Any], str | None] | None = None, - filter_out: list[type] | None = None, - checker: Callable[[Any], bool] | None = None, converter: Callable[[str | list], TDC] | None = None, ): """配置命令行参数 Args: target (type[Argv] | None, optional): 目标命令类型. - preprocessors (dict[type, Callable[..., Any]] | None, optional): 命令元素的预处理器. to_text (Callable[[Any], str | None] | None, optional): 将命令元素转换为文本, 或者返回None以跳过该元素. - filter_out (list[type] | None, optional): 需要过滤掉的命令元素. - checker (Callable[[Any], bool] | None, optional): 检查传入命令. converter (Callable[[str | list], TDC] | None, optional): 将字符串或列表转为目标命令类型. """ Argv._cache.setdefault(target or __argv_type__.get(), {}).update( diff --git a/src/arclet/alconna/arparma.py b/src/arclet/alconna/arparma.py index c276551f..ceed34df 100644 --- a/src/arclet/alconna/arparma.py +++ b/src/arclet/alconna/arparma.py @@ -280,14 +280,17 @@ def fail(self, exc: type[Exception] | Exception): def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult | SubcommandResult | None, str]: """如果能够返回, 除开基本信息, 一定返回该path所在的dict""" + all_args = self.all_matched_args if len(parts) == 1: part = parts[0] - for src in (self.main_args, self.other_args, self.options, self.subcommands): + for src in (self.main_args, all_args, self.options, self.subcommands): if part in src: return src, part - if part in {"options", "subcommands", "main_args", "other_args"}: + if part == "all_args": + return all_args, "" + if part in {"options", "subcommands", "main_args"}: return getattr(self, part, {}), "" - return (self.all_matched_args, "") if part == "args" else (None, part) + return (all_args, "") if part == "args" else (None, part) prefix = parts.pop(0) # parts[0] if prefix in {"options", "subcommands"} and prefix in self.components: raise RuntimeError(lang.require("arparma", "ambiguous_name").format(target=prefix)) @@ -295,8 +298,8 @@ def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult | return _handle_opt(prefix, parts, self.options) if prefix == "subcommands" or prefix in self.subcommands: return _handle_sub(prefix, parts, self.subcommands) - prefix = prefix.replace("$main", "main_args").replace("$other", "other_args") - if prefix in {"main_args", "other_args"}: + prefix = prefix.replace("$main", "main_args").replace("$all", "all_matched_args") + if prefix in {"main_args", "all_matched_args"}: return getattr(self, prefix, {}), parts.pop(0) return None, prefix diff --git a/src/arclet/alconna/completion.py b/src/arclet/alconna/completion.py index 2226362b..f86ec62d 100644 --- a/src/arclet/alconna/completion.py +++ b/src/arclet/alconna/completion.py @@ -132,8 +132,6 @@ def enter(self, content: list | None = None) -> EnterResult: argv.bak_data = argv.raw_data.copy() argv.ndata = len(argv.bak_data) argv.current_index = 0 - if argv.message_cache: - argv.token = argv.generate_token(argv.raw_data) argv.origin = argv.converter(argv.raw_data) exc = None try: diff --git a/src/arclet/alconna/config.py b/src/arclet/alconna/config.py index 56b3b824..3a360a60 100644 --- a/src/arclet/alconna/config.py +++ b/src/arclet/alconna/config.py @@ -37,8 +37,6 @@ class Namespace: """默认是否开启模糊匹配""" raise_exception: bool = field(default=False) """默认是否抛出异常""" - enable_message_cache: bool = field(default=True) - """默认是否启用消息缓存""" disable_builtin_options: set[Literal["help", "shortcut", "completion"]] = field(default_factory=set) builtin_option_name: OptionNames = field( default_factory=lambda: { @@ -48,7 +46,7 @@ class Namespace: } ) """默认的内置选项名称""" - to_text: Callable[[Any], str | None] = field(default=lambda x: x if isinstance(x, str) else None) + to_text: Callable[[Any], str | None] = field(default=lambda x: x if x.__class__ is str else None) """默认的选项转文本函数""" converter: Callable[[str | list], DataCollection[Any]] | None = field(default=lambda x: x) """默认的文本转选项函数""" diff --git a/src/arclet/alconna/core.py b/src/arclet/alconna/core.py index d37769cc..0a9790b5 100644 --- a/src/arclet/alconna/core.py +++ b/src/arclet/alconna/core.py @@ -2,7 +2,7 @@ from __future__ import annotations import sys -from dataclasses import dataclass, field, is_dataclass +from dataclasses import is_dataclass from functools import partial from pathlib import Path from typing import Any, Callable, Generic, Sequence, TypeVar, overload @@ -15,12 +15,13 @@ from .arparma import Arparma, ArparmaBehavior, requirement_handler from .base import Completion, Help, Option, Shortcut, Subcommand from .config import Namespace, config -from .exceptions import ExecuteFailed, NullMessage +from .exceptions import NullMessage from .formatter import TextFormatter from .manager import ShortcutArgs, command_manager from .typing import TDC, CommandMeta, DataCollection, TPrefixes T = TypeVar("T") +TCallable = TypeVar("TCallable", bound=Callable[..., Any]) TDC1 = TypeVar("TDC1", bound=DataCollection[Any]) @@ -49,33 +50,6 @@ def add_builtin_options(options: list[Option | Subcommand], ns: Namespace) -> No options.append(Completion("|".join(ns.builtin_option_name["completion"]), help_text=lang.require("builtin", "option_completion"))) # noqa: E501 -@dataclass(init=True, unsafe_hash=True) -class ArparmaExecutor(Generic[T]): - """Arparma 执行器 - - Attributes: - target(Callable[..., T]): 目标函数 - """ - - target: Callable[..., T] - binding: Callable[..., list[Arparma]] = field(default=lambda: [], repr=False) - - __call__ = lambda self, *args, **kwargs: self.target(*args, **kwargs) - - @property - def result(self) -> T: - """执行结果""" - if not self.binding: - raise ExecuteFailed(None) - arps = self.binding() - if not arps or not arps[0].matched: - raise ExecuteFailed("Unmatched") - try: - return arps[0].call(self.target) - except Exception as e: - raise ExecuteFailed(e) from e - - class Alconna(Subcommand, Generic[TDC]): """ 更加精确的命令解析 @@ -166,7 +140,7 @@ def __init__( for behavior in behaviors or []: self.behaviors.extend(requirement_handler(behavior)) command_manager.register(self) - self._executors: dict[ArparmaExecutor, Any] = {} + self._executors: dict[Callable[..., Any], Any] = {} self.union = set() @property @@ -306,31 +280,25 @@ def parse(self, message: TDC, *, _config: type[T] | None = None) -> Arparma[TDC] arp = arp.execute(self.behaviors) if self._executors: for ext in self._executors: - self._executors[ext] = arp.call(ext.target) + self._executors[ext] = arp.call(ext) if _config: if not is_dataclass(_config): raise TypeError("The type of _config must be a dataclass") return arp.call(_config) return arp - def bind(self, active: bool = True): - """绑定命令执行器 - - Args: - active (bool, optional): 该执行器是否由 `Alconna` 主动调用, 默认为 `True` - """ + def bind(self): + """绑定命令执行器""" - def wrapper(target: Callable[..., T]) -> ArparmaExecutor[T]: - ext = ArparmaExecutor(target, lambda: command_manager.get_result(self)) - if active: - self._executors[ext] = None - return ext + def wrapper(target: TCallable) -> TCallable: + self._executors[target] = None + return target return wrapper @property def exec_result(self) -> dict[str, Any]: - return {ext.target.__name__: res for ext, res in self._executors.items() if res is not None} + return {ext.__name__: res for ext, res in self._executors.items() if res is not None} def __truediv__(self, other) -> Self: return self.reset_namespace(other) @@ -379,4 +347,4 @@ def header_display(self): return str(ana.command_header) -__all__ = ["Alconna", "ArparmaExecutor"] +__all__ = ["Alconna"] diff --git a/src/arclet/alconna/formatter.py b/src/arclet/alconna/formatter.py index 361a321c..162fa309 100644 --- a/src/arclet/alconna/formatter.py +++ b/src/arclet/alconna/formatter.py @@ -26,6 +26,7 @@ def resolve(parts: list[str], options: list[Option | Subcommand]): if not parts: return opt return sub if (sub := resolve(parts, opt.options)) else opt + return resolve(parts, options) @dataclass(eq=True) @@ -94,7 +95,7 @@ def _handle(trace: Trace): if isinstance(end, Subcommand): return self.format(Trace( {"name": end.name, "prefix": [], "description": end.help_text}, end.args, - _cache.separators, _cache.options # type: ignore + end.separators, end.options # type: ignore )) return self.format(trace) diff --git a/src/arclet/alconna/manager.py b/src/arclet/alconna/manager.py index 578f5491..4abbdf13 100644 --- a/src/arclet/alconna/manager.py +++ b/src/arclet/alconna/manager.py @@ -12,7 +12,7 @@ from typing_extensions import NotRequired from weakref import WeakKeyDictionary, WeakValueDictionary -from tarina import LRU, lang +from tarina import lang from .argv import Argv, __argv_type__ from .arparma import Arparma @@ -53,7 +53,6 @@ class CommandManager: __analysers: WeakKeyDictionary[Alconna, Analyser] __argv: WeakKeyDictionary[Alconna, Argv] __abandons: list[Alconna] - __record: LRU[int, Arparma] __shortcuts: dict[str, Union[Arparma, ShortcutArgs]] def __init__(self): @@ -67,7 +66,6 @@ def __init__(self): self.__analysers = WeakKeyDictionary() self.__abandons = [] self.__shortcuts = {} - self.__record = LRU(128) def _del(): self.__commands.clear() @@ -75,9 +73,6 @@ def _del(): ana._clr() self.__analysers.clear() self.__abandons.clear() - for arp in self.__record.values(): - arp._clr() - self.__record.clear() self.__shortcuts.clear() weakref.finalize(self, _del) @@ -126,7 +121,6 @@ def register(self, command: Alconna) -> None: to_text=command.namespace_config.to_text, # type: ignore converter=command.namespace_config.converter, # type: ignore separators=command.separators, # type: ignore - message_cache=command.namespace_config.enable_message_cache, # type: ignore filter_crlf=not command.meta.keep_crlf, # type: ignore ) self.__analysers.pop(command, None) @@ -383,49 +377,6 @@ def command_help(self, command: str) -> str | None: if cmd := self.get_command(command): return cmd.get_help() - def record(self, token: int, result: Arparma): - """记录某个命令的 `token`""" - self.__record[token] = result - - def get_record(self, token: int) -> Arparma | None: - """获取某个 `token` 对应的 `Arparma` 对象""" - if token in self.__record: - return self.__record[token] - - def get_token(self, result: Arparma) -> int: - """获取某个命令的 `token`""" - return next((token for token, res in self.__record.items() if res == result), 0) - - def get_result(self, command: Alconna) -> list[Arparma]: - """获取某个命令的所有 `Arparma` 对象""" - return [v for v in self.__record.values() if v.source == command.path] - - @property - def recent_message(self) -> DataCollection[str | Any] | None: - """获取最近一次使用的命令""" - if rct := self.__record.peek_first_item(): - return rct[1].origin # type: ignore - - @property - def last_using(self): - """获取最近一次使用的 `Alconna` 对象""" - if rct := self.__record.peek_first_item(): - return rct[1].source # type: ignore - - @property - def records(self) -> LRU[int, Arparma]: - """获取当前记录""" - return self.__record - - def reuse(self, index: int = -1): - """获取当前记录中的某个值""" - key = self.__record.keys()[index] - return self.__record[key] - - def set_record_size(self, size: int): - """设置记录的最大长度""" - self.__record.set_size(size) - def __repr__(self): return ( f"Current: {hex(id(self))} in {datetime.now().strftime('%Y/%m/%d %H:%M:%S')}\n" @@ -433,8 +384,6 @@ def __repr__(self): + f"[{', '.join([cmd.path for cmd in self.get_commands()])}]" + "\nShortcuts:\n" + "\n".join([f" {k} => {v}" for k, v in self.__shortcuts.items()]) - + "\nRecords:\n" - + "\n".join([f" [{k}]: {v[1].origin}" for k, v in enumerate(self.__record.items()[:20])]) + "\nDisabled Commands:\n" + f"[{', '.join(map(lambda x: x.path, self.__abandons))}]" ) diff --git a/src/arclet/alconna/model.py b/src/arclet/alconna/model.py index 6ce7a836..02ab5511 100644 --- a/src/arclet/alconna/model.py +++ b/src/arclet/alconna/model.py @@ -36,4 +36,4 @@ def __init__(self, origin=None, result=None, matched=False, groups=None, fixes=N self.matched = matched self.groups = groups or {} if fixes: - self.groups.update({k: v.exec(self.groups[k]).value for k, v in fixes.items() if k in self.groups}) # noqa + self.groups.update({k: v.validate(self.groups[k])._value for k, v in fixes.items() if k in self.groups}) # noqa diff --git a/tests/analyser_test.py b/tests/analyser_test.py index 5d1e82c7..c2a7b941 100644 --- a/tests/analyser_test.py +++ b/tests/analyser_test.py @@ -3,10 +3,14 @@ from nepattern import BasePattern, MatchMode -from arclet.alconna import Alconna, Args, Option +from arclet.alconna import Alconna, Args, Option, Argv, set_default_argv_type from arclet.alconna.argv import argv_config +class DummyArgv(Argv): + ... + + @dataclass class Segment: type: str @@ -41,39 +45,22 @@ def gen_unit(type_: str): At = gen_unit("at") -def test_filter_out(): - argv_config(filter_out=[int]) - ana = Alconna("ana", Args["foo", str]) - assert ana.parse(["ana", 123, "bar"]).matched is True - assert ana.parse("ana bar").matched is True - argv_config(filter_out=[]) - ana_1 = Alconna("ana", Args["foo", str]) - assert ana_1.parse(["ana", 123, "bar"]).matched is False - - -def test_preprocessor(): - argv_config(preprocessors={float: int}) - ana1 = Alconna("ana1", Args["bar", int]) - assert ana1.parse(["ana1", 123.06]).matched is True - assert ana1.parse(["ana1", 123.06]).bar == 123 - argv_config(preprocessors={}) - ana1_1 = Alconna("ana1", Args["bar", int]) - assert ana1_1.parse(["ana1", 123.06]).matched is False - - def test_with_set_unit(): - argv_config(preprocessors={Segment: lambda x: str(x) if x.type == "text" else None}) + argv_config(DummyArgv, to_text=lambda x: x if x.__class__ is str else str(x) if x.type == "text" else None) + set_default_argv_type(DummyArgv) ana2 = Alconna("ana2", Args["foo", At]["bar", Face]) res = ana2.parse([Segment.text("ana2"), Segment.at(123456), Segment.face(103)]) assert res.matched is True assert res.foo.data["qq"] == "123456" assert not ana2.parse([Segment.text("ana2"), Segment.face(103), Segment.at(123456)]).matched - argv_config() + + set_default_argv_type(Argv) def test_unhashable_unit(): - argv_config(preprocessors={Segment: lambda x: str(x) if x.type == "text" else None}) + argv_config(DummyArgv, to_text=lambda x: x if x.__class__ is str else str(x) if x.type == "text" else None) + set_default_argv_type(DummyArgv) ana3 = Alconna("ana3", Args["foo", At]) print(ana3.parse(["ana3", Segment.at(123)])) @@ -85,15 +72,7 @@ def test_unhashable_unit(): print(ana3_1.parse(["ana3_1", "--foo", "--comp", Segment.at(123)])) print(ana3_1.parse(["ana3_1", "--comp", Segment.at(123)])) - -def test_checker(): - argv_config(checker=lambda x: isinstance(x, list)) - ana4 = Alconna("ana4", Args["foo", int]) - print(ana4.parse(["ana4", "123"])) - try: - print(ana4.parse("ana4 123")) - except TypeError as e: - print(e) + set_default_argv_type(Argv) if __name__ == "__main__": diff --git a/tests/args_test.py b/tests/args_test.py index f24c58c0..8f528b3f 100644 --- a/tests/args_test.py +++ b/tests/args_test.py @@ -7,17 +7,17 @@ def test_magic_create(): - arg1 = Args.round[float]["test", bool]["aaa", str] + arg1 = Args["round", float]["test", bool]["aaa", str] assert len(arg1) == 3 - arg1 = arg1 << Args.perm[str, ...] + ["month", int] + arg1 <<= Args["perm", str, ...] + ["month", int] assert len(arg1) == 5 - arg11: Args = Args.baz[int] + arg11: Args = Args["baz", int] arg11.add("foo", value=int, default=1) assert len(arg11) == 2 def test_type_convert(): - arg2 = Args.round[float]["test", bool] + arg2 = Args["round", float]["test", bool] assert analyse_args(arg2, ["1.2 False"]) != {"round": "1.2", "test": "False"} assert analyse_args(arg2, ["1.2 False"]) == {"round": 1.2, "test": False} assert analyse_args(arg2, ["a False"], raise_exception=False) != { @@ -27,7 +27,7 @@ def test_type_convert(): def test_regex(): - arg3 = Args.foo["re:abc[0-9]{3}"] + arg3 = Args["foo", "re:abc[0-9]{3}"] assert analyse_args(arg3, ["abc123"]) == {"foo": "abc123"} assert analyse_args(arg3, ["abc"], raise_exception=False) != {"foo": "abc"} @@ -38,18 +38,18 @@ def test_string(): def test_default(): - arg5 = Args.foo[int]["de", bool, True] + arg5 = Args["foo", int]["de", bool, True] assert analyse_args(arg5, ["123 False"]) == {"foo": 123, "de": False} assert analyse_args(arg5, ["123"]) == {"foo": 123, "de": True} def test_separate(): - arg6 = Args.foo[str]["bar", int] / ";" + arg6 = Args["foo", str]["bar", int] / ";" assert analyse_args(arg6, ["abc;123"]) == {"foo": "abc", "bar": 123} def test_object(): - arg7 = Args.foo[str]["bar", 123] + arg7 = Args["foo", str]["bar", 123] assert analyse_args(arg7, ["abc", 123]) == {"foo": "abc", "bar": 123} assert analyse_args(arg7, ["abc", 124], raise_exception=False) != { "foo": "abc", @@ -87,39 +87,39 @@ def test_anti(): def test_choice(): - arg10 = Args.choice[("a", "b", "c")] + arg10 = Args["choice", ("a", "b", "c")] assert analyse_args(arg10, ["a"]) == {"choice": "a"} assert analyse_args(arg10, ["d"], raise_exception=False) != {"choice": "d"} - arg10_1 = Args.mapping[{"a": 1, "b": 2, "c": 3}] + arg10_1 = Args["mapping", {"a": 1, "b": 2, "c": 3}] assert analyse_args(arg10_1, ["a"]) == {"mapping": 1} assert analyse_args(arg10_1, ["d"], raise_exception=False) != {"mapping": "d"} def test_union(): - arg11 = Args.bar[Union[int, float]] + arg11 = Args["bar", Union[int, float]] assert analyse_args(arg11, ["1.2"]) == {"bar": 1.2} assert analyse_args(arg11, ["1"]) == {"bar": 1} assert analyse_args(arg11, ["abc"], raise_exception=False) != {"bar": "abc"} - arg11_1 = Args.bar[[int, float, "abc"]] + arg11_1 = Args["bar", [int, float, "abc"]] assert analyse_args(arg11_1, ["1.2"]) == analyse_args(arg11, ["1.2"]) assert analyse_args(arg11_1, ["abc"]) == {"bar": "abc"} assert analyse_args(arg11_1, ["cba"], raise_exception=False) != {"bar": "cba"} def test_optional(): - arg13 = Args.foo[str].add("bar", value=int, flags="?") + arg13 = Args["foo", str].add("bar", value=int, flags=["?"]) assert analyse_args(arg13, ["abc 123"]) == {"foo": "abc", "bar": 123} assert analyse_args(arg13, ["abc"]) == {"foo": "abc"} - arg13_1 = Args.foo[str]["bar?", int] + arg13_1 = Args["foo", str]["bar?", int] assert analyse_args(arg13_1, ["abc 123"]) == {"foo": "abc", "bar": 123} assert analyse_args(arg13_1, ["abc"]) == {"foo": "abc"} - arg13_2 = Args.foo[str]["bar;?", int] + arg13_2 = Args["foo", str]["bar;?", int] assert analyse_args(arg13_2, ["abc 123"]) == {"foo": "abc", "bar": 123} assert analyse_args(arg13_2, ["abc"]) == {"foo": "abc"} def test_kwonly(): - arg14 = Args.foo[str].add("bar", value=Kw[int]) + arg14 = Args["foo", str].add("bar", value=Kw[int]) assert analyse_args(arg14, ["abc bar=123"]) == { "foo": "abc", "bar": 123, @@ -137,7 +137,7 @@ def test_kwonly(): "width": 960, "height": 480, } - arg14_2 = Args.foo[str]["bar", KeyWordVar(int, " ")]["baz", KeyWordVar(bool, ":")] + arg14_2 = Args["foo", str]["bar", KeyWordVar(int, " ")]["baz", KeyWordVar(bool, ":")] assert analyse_args(arg14_2, ["abc -bar 123 baz:false"]) == { "bar": 123, "baz": False, diff --git a/tests/base_test.py b/tests/base_test.py index 8ecce604..6dbad91f 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -5,7 +5,7 @@ def test_node_create(): - node = CommandNode("foo", Args.bar[int], dest="test") + node = CommandNode("foo", Args["bar", int], dest="test") assert node.name == "foo" assert node.dest != "foo" assert node.nargs == 1 @@ -13,7 +13,7 @@ def test_node_create(): def test_single_args(): node1 = CommandNode("foo", Arg("bar", int)) - assert node1.args == Args.bar[int] + assert node1.args == Args["bar", int] def test_option_aliases(): @@ -26,9 +26,9 @@ def test_option_aliases(): def test_separator(): - opt2 = Option("foo", Args.bar[int], separators="|") + opt2 = Option("foo", Args["bar", int], separators="|") assert analyse_option(opt2, "foo|123") == OptionResult(None, {"bar": 123}) - opt2_1 = Option("foo", Args.bar[int]).separate("|") + opt2_1 = Option("foo", Args["bar", int]).separate("|") assert opt2 == opt2_1 @@ -39,12 +39,12 @@ def test_subcommand(): def test_compact(): - opt3 = Option("-Foo", Args.bar[int], compact=True) + opt3 = Option("-Foo", Args["bar", int], compact=True) assert analyse_option(opt3, "-Foo123") == OptionResult(None, {"bar": 123}) def test_add(): - assert (Option("abcd") + Args.foo[int]).nargs == 1 + assert (Option("abcd") + Args["foo", int]).nargs == 1 assert len((Option("foo") + Option("bar") + "baz").options) == 2 diff --git a/tests/core_test.py b/tests/core_test.py index cd0d84e1..dfccf759 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -325,30 +325,11 @@ def test_alconna_synthesise(): def test_simple_override(): - alc11 = Alconna("core11") + Option("foo", Args["bar", str]) + Option("foo") + alc11 = Alconna("core11") + Option("foo", Args["bar", str]) + Option("bar", dest="foo") res = alc11.parse("core11 foo abc") - res1 = alc11.parse("core11 foo") - assert res.matched is True - assert res1.matched is True - - -def test_requires(): - alc12 = Alconna( - "core12", - Args["target", int], - Option("user perm set", Args["foo", str], help_text="set user permission"), - Option("user perm del", Args["foo", str], help_text="del user permission"), - Option("group perm set", Args["foo", str], help_text="set group permission"), - Option("group perm del", Args["foo", str], help_text="del group permission"), - Option("test"), - ) - - assert alc12.parse("core12 123 user perm set 123").find("user_perm_set") is True - assert alc12.parse("core12 123 user perm del 123").find("user_perm_del") is True - assert alc12.parse("core12 123 group perm set 123").find("group_perm_set") is True - assert alc12.parse("core12 123 group perm del 123 test").find("group_perm_del") is True - print("\n------------------------") - print(alc12.get_help()) + res1 = alc11.parse("core11 bar") + assert res.query("foo") is not None + assert res1.query("foo") is not None def test_wildcard(): @@ -504,10 +485,10 @@ def test_help(): ) with output_manager.capture("core17") as cap: alc17.parse("core17 --help foo") - assert cap["output"] == "foo \nFoo bar" + assert cap["output"] == "[foo] \nFoo bar" with output_manager.capture("core17") as cap: alc17.parse("core17 foo --help") - assert cap["output"] == "foo \nFoo bar" + assert cap["output"] == "[foo] \nFoo bar" with output_manager.capture("core17") as cap: alc17.parse("core17 add --help") assert cap["output"] == "add \nAdd bar" @@ -529,7 +510,7 @@ def test_help(): ) with output_manager.capture("core17_2") as cap: alc17_2.parse("core17_2 --help foo bar") - assert cap["output"] == "bar \nFoo bar" + assert cap["output"] == "[bar] \nFoo bar" with output_manager.capture("core17_2") as cap: alc17_2.parse("core17_2 --help foo") assert cap["output"] == "foo \nsub Foo\n\n可用的选项有:\n* Foo bar\n bar \n" @@ -635,18 +616,18 @@ def test_call(): from dataclasses import dataclass alc22 = Alconna("core22", Args["foo", int], Args["bar", str]) - alc22("core22 123 abc") - @alc22.bind(False) + @alc22.bind() def cb(foo: int, bar: str): print("") print("core22: ") print(foo, bar) return 2 * foo - assert cb.result == 246 + alc22("core22 123 abc") + assert alc22.exec_result["cb"] == 246 alc22.parse("core22 321 abc") - assert cb.result == 642 + assert alc22.exec_result["cb"] == 642 alc22_1 = Alconna("core22_1", Args["name", str]) @@ -848,7 +829,7 @@ def test_tips(): assert core27.parse("core27 1 1").matched assert str(core27.parse("core27 3 1").error_info) == "参数arg必须是1或2哦,不能是3" assert str(core27.parse("core27 1").error_info) == "缺少了arg参数哦" - assert str(core27.parse("core27 1 3").error_info) == "参数 3 不正确" + assert str(core27.parse("core27 1 3").error_info) in ("参数 '3' 不正确, 其应该符合 '1|2'", "参数 '3' 不正确, 其应该符合 '2|1'") assert str(core27.parse("core27").error_info) == "参数 arg1 丢失" diff --git a/tests/devtool.py b/tests/devtool.py index 12db18d6..3f300cae 100644 --- a/tests/devtool.py +++ b/tests/devtool.py @@ -32,6 +32,7 @@ def __new__(cls, *args, **kwargs): cls.command = cls._DummyALC() # type: ignore cls.compile_params = {} cls.compact_params = [] + cls.default_sub_result = {} return super().__new__(cls)