Skip to content

Commit

Permalink
Merge pull request #21 from angrybayblade/feat/context-propogation
Browse files Browse the repository at this point in the history
Fix context propagation
  • Loading branch information
angrybayblade authored Dec 4, 2023
2 parents 0710f2b + 6304361 commit 1cce897
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 61 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# v0.1.0.rc4

* Fix context propagation in the child wrappers

# v0.1.0.rc3

* Adds support for isolated runs
Expand Down
4 changes: 4 additions & 0 deletions clea/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ def parse(self, value: t.Any) -> Path:
class ContextParameter(Parameter[Context]):
"""Context parameter."""

def set(self, context: Context) -> None:
"""Set context."""
self._default = context


class VersionParameter(Parameter[str]):
"""Version parameter."""
8 changes: 7 additions & 1 deletion clea/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import typing as t
from collections import deque

from clea.context import Context
from clea.exceptions import ArgumentsMissing, ExtraArgumentProvided
from clea.params import ChoiceByFlag, Parameter
from clea.params import ChoiceByFlag, ContextParameter, Parameter


Argv = t.List[str]
Expand All @@ -29,6 +30,11 @@ def __init__(self) -> None:
self._kwargs = {}
self._args = deque()

def set_context(self, context: Context) -> None:
"""Set context."""
if "--context" in self._kwargs:
t.cast(ContextParameter, self._kwargs["--context"]).set(context=context)

def get_arg_vars(self) -> t.List[str]:
"""Get a t.list of metavars."""
return list(map(lambda x: x.var, self._args))
Expand Down
115 changes: 64 additions & 51 deletions clea/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def __init__(
self.name = name or f.__name__
self.version = version
self.parent = parent
if self.parent is not None:
self.parent.add_child(self)

def __call__(self, *args: t.Any, **kwds: t.Any) -> t.Any:
"""Call the base function.
Expand Down Expand Up @@ -81,6 +79,11 @@ def _invoke(
return 1
raise

def set_context(self, context: Context) -> None:
"""Set context."""
self.context = context
self._parser.set_context(context=context)

def help(self) -> int:
"""
Print help string.
Expand Down Expand Up @@ -142,6 +145,8 @@ def __init__(
f=f, context=context, name=name, version=version, parent=parent
)
self._parser = parser
if self.parent is not None:
self.parent.add_child(self)

def invoke(self, argv: Argv, isolated: bool = False) -> int:
"""Run the command.
Expand Down Expand Up @@ -282,20 +287,74 @@ def __init__(
:return: None
"""
super().__init__(
f=f, context=context, name=name, version=version, parent=parent
f=f,
context=context,
name=name,
version=version,
parent=parent,
)

self._parser = parser
self._children = {}
self._parser = parser
self._allow_direct_exec = allow_direct_exec
if self.parent is not None:
self.parent.add_child(self)

self.command = partial(Command.wrap, parent=self, context=self.context)
self.group = partial(self.wrap, parent=self, context=self.context)

def add_child(self, child: t.Any) -> None:
def add_child(self, child: t.Union[Command, "Group"]) -> None:
"""Add child node."""
if self.context is not None:
child.set_context(context=self.context)
self._children[t.cast(BaseWrapper, child).name] = child

@classmethod
def _wrap(
cls,
f: t.Callable,
context: t.Optional[Context] = None,
version: t.Optional[str] = None,
**kwargs: t.Any,
) -> "Group":
"""
Decorator function to wrap a function as a command.
:param f: The function to be wrapped.
:type f: t.callable
:return: A `Command` object representing the wrapped function.
:rtype: Command
"""
parser = GroupParser()
context = context or Context()
if version:
version_param = p.VersionParameter(
long_flag="--version",
help="Program version",
)
version_param.name = "version"
version_param.default = version
parser.add(version_param)
defaults_mapping, annotations = get_function_metadata(f=f)
for name, annotation in t.cast(t.Dict[str, Annotations], annotations).items():
if name == "return":
continue
if name == "context":
context_param = p.ContextParameter()
context_param.name = "context"
context_param.default = context
parser.add(defintion=context_param)
continue
(parameter,) = t.cast(
t.Tuple[p.Parameter, ...], getattr(annotation, "__metadata__")
)
default = defaults_mapping.get(name)
if default is not None:
parameter.default = default
parameter.name = name
parser.add(defintion=parameter)
return cls(f=f, parser=parser, context=context, version=version, **kwargs)

@t.overload
@classmethod
def wrap(
Expand Down Expand Up @@ -360,52 +419,6 @@ def wrap(
version=version,
)

@classmethod
def _wrap(
cls,
f: t.Callable,
context: t.Optional[Context] = None,
version: t.Optional[str] = None,
**kwargs: t.Any,
) -> "Group":
"""
Decorator function to wrap a function as a command.
:param f: The function to be wrapped.
:type f: t.callable
:return: A `Command` object representing the wrapped function.
:rtype: Command
"""
parser = GroupParser()
context = context or Context()
if version:
version_param = p.VersionParameter(
long_flag="--version",
help="Program version",
)
version_param.name = "version"
version_param.default = version
parser.add(version_param)
defaults_mapping, annotations = get_function_metadata(f=f)
for name, annotation in t.cast(t.Dict[str, Annotations], annotations).items():
if name == "return":
continue
if name == "context":
context_param = p.ContextParameter()
context_param.name = "context"
context_param.default = context
parser.add(defintion=context_param)
continue
(parameter,) = t.cast(
t.Tuple[p.Parameter, ...], getattr(annotation, "__metadata__")
)
default = defaults_mapping.get(name)
if default is not None:
parameter.default = default
parameter.name = name
parser.add(defintion=parameter)
return cls(f=f, parser=parser, context=context, version=version, **kwargs)

def invoke(self, argv: Argv, isolated: bool = False) -> int:
"""Run the command."""
(
Expand Down
10 changes: 6 additions & 4 deletions docs/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Context object provides a simple key-value data store for storing and retrieving

from clea.context import Context
from clea.runner import run
from clea.wrappers import group
from clea.wrappers import command, group


@group
Expand All @@ -19,19 +19,21 @@ def admin(context: Context) -> None:
context.set("foo", "bar")


@admin.group(name="manage")
@admin.group(name="manage", allow_direct_exec=True)
def manage(context: Context) -> None:
"""Manage."""
context.set("hello", "world")


@manage.command
@command
def student(context: Context) -> None:
"""Manage."""
"""Student."""
print(context.get("foo"))
print(context.get("hello"))


manage.add_child(student)

if __name__ == "__main__":
run(cli=admin)
```
Expand Down
4 changes: 4 additions & 0 deletions docs/upgrading.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## `0.1.0.rc3` to `0.1.0.rc4`

- No backwards incompatible changes

## `0.1.0.rc2` to `0.1.0.rc3`

- No backwards incompatible changes
Expand Down
10 changes: 6 additions & 4 deletions examples/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from clea.context import Context
from clea.runner import run
from clea.wrappers import group
from clea.wrappers import command, group


@group
Expand All @@ -11,18 +11,20 @@ def admin(context: Context) -> None:
context.set("foo", "bar")


@admin.group(name="manage")
@admin.group(name="manage", allow_direct_exec=True)
def manage(context: Context) -> None:
"""Manage."""
context.set("hello", "world")


@manage.command
@command
def student(context: Context) -> None:
"""Manage."""
"""Student."""
print(context.get("foo"))
print(context.get("hello"))


manage.add_child(student)

if __name__ == "__main__": # pragma: nocover
run(cli=admin)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "clea"
version = "0.1.0.rc3"
version = "0.1.0.rc4"
description = "Framework for writing CLI application quickly"
readme = "README.md"
authors = ["angrybayblade <[email protected]>"]
Expand Down

0 comments on commit 1cce897

Please sign in to comment.