diff --git a/CHANGELOG b/CHANGELOG index f5f8df4..1fae6a2 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,7 @@ +# v0.1.0.rc4 + +* Fix context propagation in the child wrappers + # v0.1.0.rc3 * Adds support for isolated runs diff --git a/clea/params.py b/clea/params.py index b956347..befd38a 100644 --- a/clea/params.py +++ b/clea/params.py @@ -403,6 +403,10 @@ def parse(self, value: t.Any) -> Path: class ContextParameter(Parameter[Context]): """Context parameter.""" + def set(self, context: Context) -> None: + """Set context.""" + self._default = context + class VersionParameter(Parameter[str]): """Version parameter.""" diff --git a/clea/parser.py b/clea/parser.py index eb56c01..44634e3 100644 --- a/clea/parser.py +++ b/clea/parser.py @@ -3,8 +3,9 @@ import typing as t from collections import deque +from clea.context import Context from clea.exceptions import ArgumentsMissing, ExtraArgumentProvided -from clea.params import ChoiceByFlag, Parameter +from clea.params import ChoiceByFlag, ContextParameter, Parameter Argv = t.List[str] @@ -29,6 +30,11 @@ def __init__(self) -> None: self._kwargs = {} self._args = deque() + def set_context(self, context: Context) -> None: + """Set context.""" + if "--context" in self._kwargs: + t.cast(ContextParameter, self._kwargs["--context"]).set(context=context) + def get_arg_vars(self) -> t.List[str]: """Get a t.list of metavars.""" return list(map(lambda x: x.var, self._args)) diff --git a/clea/wrappers.py b/clea/wrappers.py index 0eccb71..7355f3f 100644 --- a/clea/wrappers.py +++ b/clea/wrappers.py @@ -48,8 +48,6 @@ def __init__( self.name = name or f.__name__ self.version = version self.parent = parent - if self.parent is not None: - self.parent.add_child(self) def __call__(self, *args: t.Any, **kwds: t.Any) -> t.Any: """Call the base function. @@ -81,6 +79,11 @@ def _invoke( return 1 raise + def set_context(self, context: Context) -> None: + """Set context.""" + self.context = context + self._parser.set_context(context=context) + def help(self) -> int: """ Print help string. @@ -142,6 +145,8 @@ def __init__( f=f, context=context, name=name, version=version, parent=parent ) self._parser = parser + if self.parent is not None: + self.parent.add_child(self) def invoke(self, argv: Argv, isolated: bool = False) -> int: """Run the command. @@ -282,20 +287,74 @@ def __init__( :return: None """ super().__init__( - f=f, context=context, name=name, version=version, parent=parent + f=f, + context=context, + name=name, + version=version, + parent=parent, ) - self._parser = parser self._children = {} + self._parser = parser self._allow_direct_exec = allow_direct_exec + if self.parent is not None: + self.parent.add_child(self) self.command = partial(Command.wrap, parent=self, context=self.context) self.group = partial(self.wrap, parent=self, context=self.context) - def add_child(self, child: t.Any) -> None: + def add_child(self, child: t.Union[Command, "Group"]) -> None: """Add child node.""" + if self.context is not None: + child.set_context(context=self.context) self._children[t.cast(BaseWrapper, child).name] = child + @classmethod + def _wrap( + cls, + f: t.Callable, + context: t.Optional[Context] = None, + version: t.Optional[str] = None, + **kwargs: t.Any, + ) -> "Group": + """ + Decorator function to wrap a function as a command. + + :param f: The function to be wrapped. + :type f: t.callable + :return: A `Command` object representing the wrapped function. + :rtype: Command + """ + parser = GroupParser() + context = context or Context() + if version: + version_param = p.VersionParameter( + long_flag="--version", + help="Program version", + ) + version_param.name = "version" + version_param.default = version + parser.add(version_param) + defaults_mapping, annotations = get_function_metadata(f=f) + for name, annotation in t.cast(t.Dict[str, Annotations], annotations).items(): + if name == "return": + continue + if name == "context": + context_param = p.ContextParameter() + context_param.name = "context" + context_param.default = context + parser.add(defintion=context_param) + continue + (parameter,) = t.cast( + t.Tuple[p.Parameter, ...], getattr(annotation, "__metadata__") + ) + default = defaults_mapping.get(name) + if default is not None: + parameter.default = default + parameter.name = name + parser.add(defintion=parameter) + return cls(f=f, parser=parser, context=context, version=version, **kwargs) + @t.overload @classmethod def wrap( @@ -360,52 +419,6 @@ def wrap( version=version, ) - @classmethod - def _wrap( - cls, - f: t.Callable, - context: t.Optional[Context] = None, - version: t.Optional[str] = None, - **kwargs: t.Any, - ) -> "Group": - """ - Decorator function to wrap a function as a command. - - :param f: The function to be wrapped. - :type f: t.callable - :return: A `Command` object representing the wrapped function. - :rtype: Command - """ - parser = GroupParser() - context = context or Context() - if version: - version_param = p.VersionParameter( - long_flag="--version", - help="Program version", - ) - version_param.name = "version" - version_param.default = version - parser.add(version_param) - defaults_mapping, annotations = get_function_metadata(f=f) - for name, annotation in t.cast(t.Dict[str, Annotations], annotations).items(): - if name == "return": - continue - if name == "context": - context_param = p.ContextParameter() - context_param.name = "context" - context_param.default = context - parser.add(defintion=context_param) - continue - (parameter,) = t.cast( - t.Tuple[p.Parameter, ...], getattr(annotation, "__metadata__") - ) - default = defaults_mapping.get(name) - if default is not None: - parameter.default = default - parameter.name = name - parser.add(defintion=parameter) - return cls(f=f, parser=parser, context=context, version=version, **kwargs) - def invoke(self, argv: Argv, isolated: bool = False) -> int: """Run the command.""" ( diff --git a/docs/context.md b/docs/context.md index 0786983..f8dd45d 100644 --- a/docs/context.md +++ b/docs/context.md @@ -10,7 +10,7 @@ Context object provides a simple key-value data store for storing and retrieving from clea.context import Context from clea.runner import run -from clea.wrappers import group +from clea.wrappers import command, group @group @@ -19,19 +19,21 @@ def admin(context: Context) -> None: context.set("foo", "bar") -@admin.group(name="manage") +@admin.group(name="manage", allow_direct_exec=True) def manage(context: Context) -> None: """Manage.""" context.set("hello", "world") -@manage.command +@command def student(context: Context) -> None: - """Manage.""" + """Student.""" print(context.get("foo")) print(context.get("hello")) +manage.add_child(student) + if __name__ == "__main__": run(cli=admin) ``` diff --git a/docs/upgrading.md b/docs/upgrading.md index fcfd076..78ce651 100644 --- a/docs/upgrading.md +++ b/docs/upgrading.md @@ -1,3 +1,7 @@ +## `0.1.0.rc3` to `0.1.0.rc4` + +- No backwards incompatible changes + ## `0.1.0.rc2` to `0.1.0.rc3` - No backwards incompatible changes diff --git a/examples/context.py b/examples/context.py index 6304f71..9ee174a 100644 --- a/examples/context.py +++ b/examples/context.py @@ -2,7 +2,7 @@ from clea.context import Context from clea.runner import run -from clea.wrappers import group +from clea.wrappers import command, group @group @@ -11,18 +11,20 @@ def admin(context: Context) -> None: context.set("foo", "bar") -@admin.group(name="manage") +@admin.group(name="manage", allow_direct_exec=True) def manage(context: Context) -> None: """Manage.""" context.set("hello", "world") -@manage.command +@command def student(context: Context) -> None: - """Manage.""" + """Student.""" print(context.get("foo")) print(context.get("hello")) +manage.add_child(student) + if __name__ == "__main__": # pragma: nocover run(cli=admin) diff --git a/pyproject.toml b/pyproject.toml index 73adf1d..8510a76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "clea" -version = "0.1.0.rc3" +version = "0.1.0.rc4" description = "Framework for writing CLI application quickly" readme = "README.md" authors = ["angrybayblade "]