Skip to content

Commit

Permalink
sistana: perf: hand-written init and slots
Browse files Browse the repository at this point in the history
  • Loading branch information
GreyElaina committed Sep 23, 2024
1 parent b6906d4 commit 3d93f47
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 58 deletions.
3 changes: 1 addition & 2 deletions src/arclet/alconna/sistana/model/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .pointer import Pointer
from .track import Preset
from .snapshot import AnalyzeSnapshot, SubcommandTraverse

if TYPE_CHECKING:
from .fragment import _Fragment
Expand Down Expand Up @@ -70,8 +71,6 @@ def root_ref(self):
return Pointer().subcommand(self.header)

def create_snapshot(self, ref: Pointer):
from .snapshot import AnalyzeSnapshot, SubcommandTraverse

return AnalyzeSnapshot(
traverses=[
SubcommandTraverse(
Expand Down
8 changes: 4 additions & 4 deletions src/arclet/alconna/sistana/model/pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ def __init__(self, data: tuple[PointerContent, ...] = ()) -> None:
self.data = data

def subcommand(self, name: str):
return Pointer((*self.data, (PointerRole.SUBCOMMAND, name)))
return Pointer(self.data + ((PointerRole.SUBCOMMAND, name), ))

def option(self, name: str):
return Pointer((*self.data, (PointerRole.OPTION, name)))
return Pointer(self.data + ((PointerRole.OPTION, name), ))

def header(self):
return Pointer((*self.data, (PointerRole.HEADER, HEADER_STR)))
return Pointer(self.data + ((PointerRole.HEADER, HEADER_STR), ))

def prefix(self):
return Pointer((*self.data, (PointerRole.PREFIX, PREFIX_STR)))
return Pointer(self.data + ((PointerRole.PREFIX, PREFIX_STR),))

@property
def parent(self):
Expand Down
55 changes: 40 additions & 15 deletions src/arclet/alconna/sistana/model/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, TypeVar

from arclet.alconna._dcls import safe_dcls_kw

from ..utils.misc import Some, Value

if TYPE_CHECKING:
Expand All @@ -16,22 +13,37 @@
T = TypeVar("T")


@dataclass(**safe_dcls_kw(slots=True))
class OptionTraverse:
__slots__ = ("trigger", "is_compact", "completed", "option", "track")

trigger: str
is_compact: bool
completed: bool
option: OptionPattern
track: Track

def __init__(self, trigger: str, is_compact: bool, completed: bool, option: OptionPattern, track: Track):
self.trigger = trigger
self.is_compact = is_compact
self.completed = completed
self.option = option
self.track = track


@dataclass(**safe_dcls_kw(slots=True))
class IndexedOptionTraversesRecord:
traverses: list[OptionTraverse] = field(default_factory=list)
__slots__ = ("traverses", "_by_trigger", "_by_keyword", "_count")

traverses: list[OptionTraverse]

_by_trigger: dict[str, list[OptionTraverse]]
_by_keyword: dict[str, list[OptionTraverse]]
_count: defaultdict[str, int]

_by_trigger: dict[str, list[OptionTraverse]] = field(default_factory=dict, repr=False)
_by_keyword: dict[str, list[OptionTraverse]] = field(default_factory=dict, repr=False)
_count: defaultdict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0), repr=False)
def __init__(self, traverses: list[OptionTraverse] | None = None):
self.traverses = traverses or []
self._by_trigger = {}
self._by_keyword = {}
self._count = defaultdict(lambda: 0)

def append(self, traverse: OptionTraverse):
self.traverses.append(traverse)
Expand All @@ -40,7 +52,7 @@ def append(self, traverse: OptionTraverse):
self._by_trigger[traverse.trigger].append(traverse)
else:
self._by_trigger[traverse.trigger] = [traverse]

if traverse.option.keyword in self._by_keyword:
self._by_keyword[traverse.option.keyword].append(traverse)
else:
Expand All @@ -64,19 +76,32 @@ def __contains__(self, keyword: str):
return keyword in self._by_keyword


@dataclass(**safe_dcls_kw(slots=True))
class SubcommandTraverse:
__slots__ = ("subcommand", "trigger", "ref", "mix", "option_traverses")

subcommand: SubcommandPattern
trigger: str
ref: Pointer
mix: Mix
option_traverses: IndexedOptionTraversesRecord = field(default_factory=IndexedOptionTraversesRecord)
option_traverses: IndexedOptionTraversesRecord

def __init__(self, subcommand: SubcommandPattern, trigger: str, ref: Pointer, mix: Mix):
self.subcommand = subcommand
self.trigger = trigger
self.ref = ref
self.mix = mix
self.option_traverses = IndexedOptionTraversesRecord()


@dataclass(**safe_dcls_kw(slots=True))
class AnalyzeSnapshot(Generic[T]):
traverses: list[SubcommandTraverse] = field(default_factory=list)
endpoint: Some[Pointer] = None
__slots__ = ("traverses", "endpoint")

traverses: list[SubcommandTraverse]
endpoint: Some[Pointer]

def __init__(self, traverses: list[SubcommandTraverse] | None = None, endpoint: Some[Pointer] = None):
self.traverses = traverses or []
self.endpoint = endpoint

@property
def determined(self):
Expand Down
55 changes: 18 additions & 37 deletions src/arclet/alconna/sistana/model/track.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from arclet.alconna._dcls import safe_dcls_kw

from ..err import ReceivePanic, TransformPanic, ValidateRejected
from ..utils.misc import Value
from .fragment import _Fragment, assert_fragments_order
Expand All @@ -15,10 +12,15 @@
from .receiver import RxGet, RxPut


@dataclass(**safe_dcls_kw(slots=True))
class Track:
__slots__ = ("fragments", "assignes")

fragments: deque[_Fragment]
assignes: dict[str, Any] = field(default_factory=dict)
assignes: dict[str, Any]

def __init__(self, fragments: deque[_Fragment], assignes: dict[str, Any] | None = None):
self.fragments = fragments
self.assignes = assignes or {}

@property
def satisfied(self):
Expand All @@ -44,32 +46,6 @@ def complete(self):
if first.variadic and first.name not in self.assignes:
self.assignes[first.name] = []


# def fetch(
# self,
# frag: _Fragment,
# buffer: Buffer,
# upper_separators: str,
# rxget: RxGet[Any] | None = None, # type: ignore
# rxput: RxPut[Any] | None = None, # type: ignore
# ):
# if frag.separators is not None:
# if frag.hybrid_separators:
# separators = frag.separators + upper_separators
# else:
# separators = frag.separators
# else:
# separators = upper_separators

# token = buffer.next(separators)

# if rxput is None:
# def rxput(val):
# self.assignes[frag.name] = val

# rxput(token.val)
# token.apply()

def fetch(
self,
frag: _Fragment,
Expand Down Expand Up @@ -153,11 +129,12 @@ def copy(self):
return Track(self.fragments.copy(), self.assignes.copy())


@dataclass(**safe_dcls_kw(slots=True))
class Preset:
tracks: dict[str, deque[_Fragment]] = field(default_factory=dict)
tracks: dict[str, deque[_Fragment]]

def __init__(self, tracks: dict[str, deque[_Fragment]] | None = None):
self.tracks = tracks or {}

def __post_init__(self):
for fragments in self.tracks.values():
assert_fragments_order(fragments)

Expand All @@ -168,12 +145,16 @@ def new_mix(self) -> Mix:
return Mix(self)


@dataclass(**safe_dcls_kw(slots=True))
class Mix:
__slots__ = ("preset", "tracks")

preset: Preset
tracks: dict[str, Track] = field(init=False, default_factory=dict)
tracks: dict[str, Track]

def __init__(self, preset: Preset):
self.preset = preset
self.tracks = {}

def __post_init__(self):
for name in self.preset.tracks:
self.init_track(name)

Expand Down

0 comments on commit 3d93f47

Please sign in to comment.