Skip to content

Commit

Permalink
0.5.1
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt authored Jan 11, 2022
1 parent 540ec90 commit 6a60d30
Showing 1 changed file with 22 additions and 70 deletions.
92 changes: 22 additions & 70 deletions arclet/alconna/component.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,22 @@
"""Alconna 的组件相关"""

import re
import inspect
from typing import Union, Dict, List, Any, Optional, Callable, Type, Tuple
from dataclasses import dataclass
from .util import split_once
from .exceptions import InvalidParam
from .types import NonTextElement, Args


@dataclass
class CommandInterface:
"""命令体基类, 规定基础命令的参数"""
name: str
args: Args
separator: str = " "

def separate(self, sep: str):
"""设置命令头与命令参数的分隔符"""
self.separator = sep
return self

def help(self, help_string: str):
"""预处理 help 文档"""
setattr(
self, "help_doc",
f"# {help_string}\n {self.name}{self.separator}{self.args.params(self.separator)}\n"
)
return self

def __getitem__(self, item):
self.args.__merge__(Args.__class_getitem__(item))
return self
from typing import Union, Dict, List, Any, Optional, Callable, Type
from .util import split_once
from .types import NonTextElement
from .base import TemplateCommand, Args


class Option(CommandInterface):
class Option(TemplateCommand):
"""命令选项, 可以使用别名"""
alias: str
action: Callable[[Any], Union[List, Tuple]]

def __init__(self, name: str, args: Optional[Args] = None, alias: Optional[str] = None,
actions: Optional[Callable] = None, **kwargs):
if name == "":
raise InvalidParam("选项的名字不能为空")
if re.match(r"^[`~?/.,<>;\':\"|!@#$%^&*()_+=\[\]}{]+.*$", name):
raise InvalidParam("选项的名字含有非法字符")
if "|" in name:
name, alias = name.replace(' ', '').split('|')
self.name = name
super().__init__(name, args, actions, **kwargs)
self.alias = alias or name
self.args = args or Args(**kwargs)
self.__check_action__(actions)

def help(self, help_string: str):
"""预处理 help 文档"""
Expand All @@ -60,40 +26,16 @@ def help(self, help_string: str):
f"# {help_string}\n {alias}{self.name}{self.separator}{self.args.params(self.separator)}\n")
return self

def __check_action__(self, action):
if action:
argument = [
(name, param.annotation, param.default) for name, param in inspect.signature(action).parameters.items()
]
if len(argument) != len(self.args.argument):
raise InvalidParam("action 接受的参数个数必须与 Args 里的一致")

def _act_(*items: Any):
try:
result = action(*items)
if not isinstance(result, tuple):
result = [result]
return result
except Exception:
return items
self.action = _act_
else:
self.action = action


class Subcommand(CommandInterface):

class Subcommand(TemplateCommand):
"""子命令, 次于主命令, 可解析 SubOption"""
options: List[Option]
sub_params: Dict[str, Union[Args, Option]]

def __init__(self, name: str, *option: Option, args: Optional[Args] = None, **kwargs):
if name == "":
raise InvalidParam("子命令的名字不能为空")
if re.match(r"^[`~?/.,<>;\':\"|!@#$%^&*()_+=\[\]}{]+.*$", name):
raise InvalidParam("子命令的名字含有非法字符")
self.name = name
def __init__(self, name: str, *option: Option, args: Optional[Args] = None,
actions: Optional[Callable] = None, **kwargs):
super().__init__(name, args, actions, **kwargs)
self.options = list(option)
self.args = args or Args(**kwargs)
self.sub_params = {"sub_args": self.args}

def help(self, help_string: str):
Expand Down Expand Up @@ -183,7 +125,7 @@ def get(self, name: Union[str, Type[NonTextElement]]) -> Union[Dict, str, NonTex
if name in self._args:
return self._args[name]
for _, v in self.all_matched_args:
if type(v) is name:
if isinstance(v, name):
return v

def has(self, name: str) -> bool:
Expand Down Expand Up @@ -222,6 +164,16 @@ def next_data(self, separate: Optional[str] = None, pop: bool = True) -> Union[s
self.current_index += 1
return _current_data

def reduce_data(self, data: Union[str, NonTextElement]):
try:
if isinstance(data, str) and isinstance(self.raw_data[self.current_index], list):
self.raw_data[self.current_index].insert(0, data)
else:
self.current_index -= 1
self.raw_data[self.current_index] = [data]
except KeyError:
pass

def __repr__(self):
attrs = ((s, getattr(self, s)) for s in self.__slots__)
return " ".join([f"{a}={v}" for a, v in attrs if v is not None])

0 comments on commit 6a60d30

Please sign in to comment.