Skip to content

Commit

Permalink
✨ ContextVal
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Feb 25, 2024
1 parent e9d1c60 commit fc0c91b
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 49 deletions.
89 changes: 49 additions & 40 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/arclet/alconna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .typing import Nargs as Nargs
from .typing import UnpackVar as UnpackVar
from .typing import Up as Up
from .typing import ContextVal as ContextVal

__version__ = "1.7.44"

Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/_internal/_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,14 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]:
if self.command.meta.raise_exception:
raise e
return self.export(argv, True, e)
argv.context[SHORTCUT_TRIGGER] = _next
try:
rest, short, mat = command_manager.find_shortcut(self.command, [_next] + argv.release())
except ValueError as exc:
if self.command.meta.raise_exception:
raise e from exc
return self.export(argv, True, e)
else:
argv.context[SHORTCUT_TRIGGER] = _next
argv.context[SHORTCUT_ARGS] = short
argv.context[SHORTCUT_REST] = rest
argv.context[SHORTCUT_REGEX_MATCH] = mat
Expand Down
7 changes: 5 additions & 2 deletions src/arclet/alconna/_internal/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..exceptions import ArgumentMissing, FuzzyMatchSuccess, InvalidParam, PauseTriggered, SpecialOptionTriggered
from ..model import HeadResult, OptionResult, Sentence
from ..output import output_manager
from ..typing import KWBool, ShortcutRegWrapper, MultiKeyWordVar, MultiVar
from ..typing import KWBool, ShortcutRegWrapper, MultiKeyWordVar, MultiVar, ContextVal
from ._header import Double, Header
from ._util import escape, levenshtein, unescape

Expand All @@ -34,7 +34,10 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any, Any], result
result[target.name] = str(arg)
return
default_val = target.field.default
res = value.validate(arg, default_val)
_arg = arg
if value == ContextVal:
_arg = (arg, argv.context)
res = value.validate(_arg, default_val)
if res.flag != "valid":
argv.rollback(arg)
if res.flag == "error":
Expand Down
51 changes: 48 additions & 3 deletions src/arclet/alconna/typing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Alconna 参数相关"""
from __future__ import annotations

import re
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Dict, Iterator, List, Literal, Protocol, Tuple, TypeVar, TypedDict, Union, runtime_checkable
from typing import Any, Dict, Iterator, List, Literal, Protocol, Tuple, TypeVar, TypedDict, Union, runtime_checkable, final

from tarina import lang, safe_eval
from typing_extensions import NotRequired

from nepattern import BasePattern, MatchMode, parser
from nepattern import BasePattern, MatchMode, parser, MatchFailed

TPrefixes = Union[List[Union[str, object]], List[Tuple[object, str]]]
DataUnit = TypeVar("DataUnit", covariant=True)
Expand Down Expand Up @@ -102,6 +105,7 @@ class CommandMeta:
T = TypeVar("T")


@final
class _AllParamPattern(BasePattern[Any, Any]):
def __init__(self):
super().__init__(mode=MatchMode.KEEP, origin=Any, alias="*")
Expand All @@ -110,12 +114,53 @@ def match(self, input_: Any) -> Any: # pragma: no cover
return input_

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, _AllParamPattern)
return other.__class__ is _AllParamPattern


AllParam = _AllParamPattern()


@final
class _ContextPattern(BasePattern[Any, Tuple[str, Dict[str, Any]]]):
def __init__(self, *names: str, style: Literal["bracket", "parentheses"] = "parentheses"):
if not names:
pat = "$(VAR)" if style == "parentheses" else "${VAR}"
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=pat)
self.pattern = pat
self.regex_pattern = re.compile(r"\$\((.+)\)", re.DOTALL) if style == "parentheses" else re.compile(r"\{(.+)\}", re.DOTALL)
else:
pat = f"$({'|'.join(names)})" if style == "parentheses" else f"${{{'|'.join(names)}}}"
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=pat)
self.pattern = pat
if style == "parentheses":
self.regex_pattern = re.compile(rf"\$\(({'|'.join(map(re.escape, names))})\)")
else:
self.regex_pattern = re.compile(rf"\{{({'|'.join(map(re.escape, names))})\}}")

def match(self, input_: Tuple[str, Dict[str, Any]]) -> Any:
pat, ctx = input_
if not (mat := self.regex_pattern.fullmatch(pat)):
raise MatchFailed(lang.require("nepattern", "content_error").format(target=input_, expected=self.alias))
name = mat.group(1)
if name == "_":
return ctx
if name not in ctx:
try:
return safe_eval(name, ctx)
except Exception:
raise MatchFailed(lang.require("nepattern", "context_error").format(target=input_, expected=self.alias))
return ctx[name]

def __calc_eq__(self, other):
return other.__class__ is _ContextPattern

def __call__(self, *names: str, style: Literal["bracket", "parentheses"] = "parentheses"):
return _ContextPattern(*names, style=style)


ContextVal = _ContextPattern()


class KeyWordVar(BasePattern[T, Any]):
"""对具名参数的包装"""

Expand Down
16 changes: 16 additions & 0 deletions tests/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,22 @@ def test_multi_multi():
assert analyse_args(arg20_1, ["1 2 a b"]) == {"foo": (1, 2), "bar": ("a", "b")}


def test_contextval():
from arclet.alconna import ContextVal

arg21 = Args["foo", ContextVal]
assert analyse_args(arg21, ["$(bar)"], bar="baz") == {"foo": "baz"}
assert analyse_args(arg21, ["{bar}"], raise_exception=False, bar="baz") != {"foo": "baz"}

arg21_1 = Args["foo", ContextVal(style="bracket")]
assert analyse_args(arg21_1, ["{bar}"], bar="baz") == {"foo": "baz"}
assert analyse_args(arg21_1, ["$(bar)"], raise_exception=False, bar="baz") != {"foo": "baz"}

arg21_2 = Args["foo", ContextVal("bar")]
assert analyse_args(arg21_2, ["$(bar)"], bar="baz") == {"foo": "baz"}
assert analyse_args(arg21_2, ["$(baz)"], raise_exception=False, baz="baz") != {"foo": "baz"}


if __name__ == "__main__":
import pytest

Expand Down
Loading

0 comments on commit fc0c91b

Please sign in to comment.