Skip to content

Commit

Permalink
sistana: new receiver, endpoint determine
Browse files Browse the repository at this point in the history
  • Loading branch information
GreyElaina committed Sep 20, 2024
1 parent 4112acc commit 9cdfd3b
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 75 deletions.
6 changes: 6 additions & 0 deletions src/arclet/alconna/sistana/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def loopflow(self, snapshot: AnalyzeSnapshot[T], buffer: Buffer[T]) -> tuple[Loo
# 在 option context 里面,因为 satisfied 了,所以可以直接返回 completed。
# 并且还得确保 option 也被记录于 activated_options 里面。
if pointer_type == "option":
mix.tracks[pointer_val].complete()
traverse.activated_options.add(pointer_val)
traverse.ref = traverse.ref.parent

snapshot.determine(traverse.ref)
return LoopflowDescription.completed, snapshot

# 这里如果没有 satisfied,如果是 option 的 track,则需要 reset
Expand Down Expand Up @@ -182,6 +184,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot[T], buffer: Buffer[T]) -> tuple[Loo
elif pointer_type == "option":
if token.val in context.subcommands:
# 当且仅当 option 已经 satisfied 时才能让状态流转进 subcommand。
# subcommand.satisfy_previous 处理起来比较复杂,这里先 reject。
subcommand = context.subcommands[token.val]
option = context.options[pointer_val]
track = mix.tracks[option.keyword]
Expand All @@ -191,6 +194,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot[T], buffer: Buffer[T]) -> tuple[Loo
mix.reset(option.keyword)
return LoopflowDescription.switch_unsatisfied_option, snapshot
else:
mix.tracks[option.keyword].complete()
traverse.ref = traverse.ref.parent
traverse.activated_options.add(pointer_val)
# TODO: 重新考虑 traverse 记录 option 的方法
Expand Down Expand Up @@ -224,6 +228,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot[T], buffer: Buffer[T]) -> tuple[Loo
mix.reset(target_option.keyword)
return LoopflowDescription.option_switch_prohibited_direction, snapshot
else:
mix.tracks[target_option.keyword].complete()
traverse.ref = traverse.ref.parent
traverse.activated_options.add(pointer_val)

Expand Down Expand Up @@ -275,6 +280,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot[T], buffer: Buffer[T]) -> tuple[Loo
else:
if response is None:
# track 上没有 fragments 可供分配了。
# 这里没必要 complete。

traverse.ref = traverse.ref.parent
traverse.activated_options.add(origin_option.keyword)
Expand Down
6 changes: 3 additions & 3 deletions src/arclet/alconna/sistana/err.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ class CaptureRejected(Rejected): ...
class ValidateRejected(Rejected): ...


class ReceiveRejected(Rejected): ...


@dataclass
class UnexpectedType(CaptureRejected):
expected: type | tuple[type, ...]
Expand All @@ -47,3 +44,6 @@ class ParsePanic(Exception): ...


class TransformPanic(ParsePanic): ...


class ReceivePanic(ParsePanic): ...
19 changes: 9 additions & 10 deletions src/arclet/alconna/sistana/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,39 @@
from .model.fragment import _Fragment



@dataclass
class Fragment(_Fragment):
def apply_msgspec(self):
if self.type is None:
return

t = self.type.value

from msgspec import convert, ValidationError

def _validate(v: Segment):
if not isinstance(v, (str, Quoted, UnmatchedQuoted)):
return False

v = str(v)

try:
convert(v, t)
except ValidationError:
return False

return True

def _transform(v: Segment):
return convert(str(v), t)

self.validator = _validate
self.transformer = _transform

def apply_nepattern(self):
if self.type is None:
return

from nepattern import type_parser

pat = type_parser(self.type.value)
Expand All @@ -51,9 +50,9 @@ def _validate(v: Segment):
v = v.ref

return pat.validate(v).success

def _transform(v: Segment):
return pat.transform(str(v)).value()

self.validator = _validate
self.transformer = _transform
4 changes: 2 additions & 2 deletions src/arclet/alconna/sistana/model/capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from dataclasses import dataclass
from typing import Any, Generic, TypeVar, Union
from typing import Any, Generic, Tuple, TypeVar, Union

from elaina_segment import Quoted, UnmatchedQuoted
from typing_extensions import TypeAlias
Expand All @@ -14,7 +14,7 @@

T = TypeVar("T")

CaptureResult: TypeAlias = "tuple[T, Some[Any], Union[SegmentToken[T], AheadToken[T]]]"
CaptureResult = Tuple[T, Some[Any], Union[SegmentToken[T], AheadToken[T]]]


class Capture(Generic[T]):
Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/sistana/model/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .track import Preset

if TYPE_CHECKING:
from .fragment import _Fragment
pass


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/sistana/model/pointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal, Union, Tuple
from typing import Literal, Tuple, Union

PointerRole = Literal["subcommand", "option"]
PointerPair = Tuple[PointerRole, str]
Expand Down
48 changes: 21 additions & 27 deletions src/arclet/alconna/sistana/model/receiver.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,37 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar

if TYPE_CHECKING:
from .snapshot import AnalyzeSnapshot
from ..utils.misc import Some

T = TypeVar("T")

T = TypeVar("T")

@dataclass
class Rx(Generic[T]):
name: str
RxGet = Callable[[], Some[T]]
RxPut = Callable[[T], None]

def receive(self, snapshot: AnalyzeSnapshot, data: T) -> None:
snapshot.cache[self.name] = data

def load(self, snapshot: AnalyzeSnapshot):
return snapshot.cache[self.name]
class Rx(Generic[T]):
def receive(self, get: RxGet[Any], put: RxPut[Any], data: T) -> None:
if get() is None:
put(data)


@dataclass
class CountRx(Rx[Any]):
def receive(self, snapshot: AnalyzeSnapshot, data: Any) -> None:
snapshot.cache[self.name] = snapshot.cache.get(self.name, 0) + 1
def receive(self, get: RxGet[int], put: RxPut[int], data: Any) -> None:
v = get()

def load(self, snapshot: AnalyzeSnapshot):
return snapshot.cache.get(self.name, 0)
if v is None:
put(1)
else:
put(v.value + 1)


@dataclass
class AccumRx(Rx[T]):
def receive(self, snapshot: AnalyzeSnapshot, data: str) -> None:
if self.name in snapshot.cache:
target = snapshot.cache[self.name]
else:
target = snapshot.cache[self.name] = []
def receive(self, get: RxGet[list[T]], put: RxPut[list[T]], data: T) -> None:
v = get()

target.append(data)

def load(self, snapshot: AnalyzeSnapshot):
return snapshot.cache.get(self.name) or []
if v is None:
put([data])
else:
put([*v.value, data])
16 changes: 12 additions & 4 deletions src/arclet/alconna/sistana/model/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar

from ..utils.misc import Some, Value

if TYPE_CHECKING:
from .track import Mix
from .pattern import OptionPattern, SubcommandPattern
from .pattern import SubcommandPattern
from .pointer import Pointer
from .track import Mix

T = TypeVar("T")

Expand All @@ -29,4 +30,11 @@ def satisfied(self):
@dataclass
class AnalyzeSnapshot(Generic[T]):
traverses: list[SubcommandTraverse] = field(default_factory=list)
cache: dict[str, Any] = field(default_factory=dict)
endpoint: Some[Pointer] = None

@property
def determined(self):
return self.endpoint is not None

def determine(self, endpoint: Pointer):
self.endpoint = Value(endpoint)
59 changes: 32 additions & 27 deletions src/arclet/alconna/sistana/model/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from ..err import TransformPanic, ValidateRejected
from ..err import ReceivePanic, TransformPanic, ValidateRejected
from ..utils.misc import Value
from .fragment import _Fragment, assert_fragments_order

if TYPE_CHECKING:
Expand Down Expand Up @@ -41,13 +42,28 @@ def complete(self):
if first.variadic and first.name not in self.assignes:
self.assignes[first.name] = []

def fetch(
self,
snapshot: AnalyzeSnapshot,
frag: _Fragment,
buffer: Buffer,
separators: str,
):
def _assign_getter(self, name: str):
def getter():
if name in self.assignes:
return Value(self.assignes[name])

return getter

def _assign_setter(self, name: str, is_variadic: bool = False):
def setter(val):
if is_variadic:
if name not in self.assignes:
target = self.assignes[name] = []
else:
target = self.assignes[name]

target.append(val)
else:
self.assignes[name] = val

return setter

def fetch(self, snapshot: AnalyzeSnapshot, frag: _Fragment, buffer: Buffer, separators: str):
val, tail, token = frag.capture.capture(buffer, separators)

if frag.validator is not None and not frag.validator(val):
Expand All @@ -61,36 +77,25 @@ def fetch(

if frag.receiver is not None:
try:
frag.receiver.receive(snapshot, val)
frag.receiver.receive(self._assign_getter(frag.name), self._assign_setter(frag.name, frag.variadic), val)
except Exception as e:
raise ValidateRejected from e

val = frag.receiver.load(snapshot)
raise ReceivePanic from e

if tail is not None:
buffer.ahead.append(tail)

token.apply()
return val

def forward(self, snapshot: AnalyzeSnapshot, buffer: Buffer, separators: str):
if not self.fragments:
return

first = self.fragments[0]
val = self.fetch(snapshot, first, buffer, separators)
self.fetch(snapshot, first, buffer, separators)

if first.variadic:
if first.name not in self.assignes:
variadics = self.assignes[first.name] = []
else:
variadics = self.assignes[first.name]
if not first.variadic:
self.fragments.popleft()

variadics.append(val)
return first

self.assignes[first.name] = val
self.fragments.popleft()
return first

@property
Expand Down Expand Up @@ -162,6 +167,6 @@ def complete_all(self):
for track in self.tracks.values():
track.complete()

@property
def result(self):
return {name: track.assignes for name, track in self.tracks.items()}
# @property
# def result(self):
# return {name: track.assignes for name, track in self.tracks.items()}

0 comments on commit 9cdfd3b

Please sign in to comment.