From fbec5adc04a5598a98fee2a5650d5145f2c65c20 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Fri, 5 Jul 2024 13:04:52 +1200 Subject: [PATCH] Fix minor upstream issues. --- README.md | 32 ++--- scenario/__init__.py | 8 +- scenario/context.py | 22 +-- scenario/mocking.py | 24 ++-- scenario/state.py | 208 +++++++++++++++++++++------- tests/test_consistency_checker.py | 26 ++-- tests/test_context.py | 22 ++- tests/test_context_on.py | 14 +- tests/test_e2e/test_actions.py | 16 ++- tests/test_e2e/test_pebble.py | 17 ++- tests/test_e2e/test_ports.py | 17 ++- tests/test_e2e/test_relations.py | 52 +++++++ tests/test_e2e/test_secrets.py | 20 +++ tests/test_e2e/test_state.py | 69 ++++++++- tests/test_e2e/test_stored_state.py | 17 ++- 15 files changed, 445 insertions(+), 119 deletions(-) diff --git a/README.md b/README.md index f7a794e3..0b0fcf80 100644 --- a/README.md +++ b/README.md @@ -397,6 +397,8 @@ meta = { } ctx = scenario.Context(ops.CharmBase, meta=meta, unit_id=1) ctx.run(ctx.on.start(), state_in) # invalid: this unit's id cannot be the ID of a peer. + + ``` ### SubordinateRelation @@ -531,7 +533,7 @@ local_file = pathlib.Path('/path/to/local/real/file.txt') container = scenario.Container( name="foo", can_connect=True, - mounts={'local': scenario.Mount('/local/share/config.yaml', local_file)} + mounts={'local': scenario.Mount(location='/local/share/config.yaml', source=local_file)} ) state = scenario.State(containers=[container]) ``` @@ -568,7 +570,7 @@ def test_pebble_push(): container = scenario,Container( name='foo', can_connect=True, - mounts={'local': Mount('/local/share/config.yaml', local_file.name)} + mounts={'local': Mount(location='/local/share/config.yaml', source=local_file.name)} ) state_in = State(containers=[container]) ctx = Context( @@ -667,32 +669,29 @@ Pebble can generate notices, which Juju will detect, and wake up the charm to let it know that something has happened in the container. The most common use-case is Pebble custom notices, which is a mechanism for the workload application to trigger a charm event. - +- When the charm is notified, there might be a queue of existing notices, or just the one that has triggered the event: ```python -import ops -import scenario - class MyCharm(ops.CharmBase): def __init__(self, framework): super().__init__(framework) - framework.observe(self.on["cont"].pebble_custom_notice, self._on_notice) + framework.observe(self.on["my-container"].pebble_custom_notice, self._on_notice) def _on_notice(self, event): event.notice.key # == "example.com/c" - for notice in self.unit.get_container("cont").get_notices(): + for notice in self.unit.get_container("my-container").get_notices(): ... ctx = scenario.Context(MyCharm, meta={"name": "foo", "containers": {"my-container": {}}}) notices = [ - scenario.Notice(key="example.com/a", occurences=10), + scenario.Notice(key="example.com/a", occurrences=10), scenario.Notice(key="example.com/b", last_data={"bar": "baz"}), scenario.Notice(key="example.com/c"), ] -cont = scenario.Container(notices=notices) -ctx.run(container.get_notice("example.com/c").event, scenario.State(containers=[cont])) +container = scenario.Container("my-container", notices=notices) +ctx.run(container.get_notice("example.com/c").event, scenario.State(containers=[container])) ``` ## Storage @@ -766,15 +765,14 @@ ctx.run(ctx.on.storage_attached(foo_1), scenario.State(storage=[foo_0, foo_1])) Since `ops 2.6.0`, charms can invoke the `open-port`, `close-port`, and `opened-ports` hook tools to manage the ports opened on the host VM/container. Using the `State.opened_ports` API, you can: - simulate a charm run with a port opened by some previous execution -```python ctx = scenario.Context(MyCharm, meta=MyCharm.META) -ctx.run(ctx.on.start(), scenario.State(opened_ports=[scenario.Port("tcp", 42)])) +ctx.run(ctx.on.start(), scenario.State(opened_ports=[scenario.TCPPort(port=42)])) ``` - assert that a charm has called `open-port` or `close-port`: ```python ctx = scenario.Context(PortCharm, meta=MyCharm.META) state1 = ctx.run(ctx.on.start(), scenario.State()) -assert state1.opened_ports == [scenario.Port("tcp", 42)] +assert state1.opened_ports == [scenario.TCPPort(port=42)] state2 = ctx.run(ctx.on.stop(), state1) assert state2.opened_ports == [] @@ -788,8 +786,8 @@ Scenario has secrets. Here's how you use them. state = scenario.State( secrets=[ scenario.Secret( + {0: {'key': 'public'}}, id='foo', - contents={0: {'key': 'public'}} ) ] ) @@ -817,8 +815,8 @@ To specify a secret owned by this unit (or app): state = scenario.State( secrets=[ scenario.Secret( + {0: {'key': 'private'}}, id='foo', - contents={0: {'key': 'private'}}, owner='unit', # or 'app' remote_grants={0: {"remote"}} # the secret owner has granted access to the "remote" app over some relation with ID 0 @@ -833,8 +831,8 @@ To specify a secret owned by some other application and give this unit (or app) state = scenario.State( secrets=[ scenario.Secret( + {0: {'key': 'public'}}, id='foo', - contents={0: {'key': 'public'}}, # owner=None, which is the default revision=0, # the revision that this unit (or app) is currently tracking ) diff --git a/scenario/__init__.py b/scenario/__init__.py index 93059ebf..a73570a6 100644 --- a/scenario/__init__.py +++ b/scenario/__init__.py @@ -11,12 +11,12 @@ Container, DeferredEvent, ExecOutput, + ICMPPort, Model, Mount, Network, Notice, PeerRelation, - Port, Relation, Secret, State, @@ -24,6 +24,8 @@ Storage, StoredState, SubordinateRelation, + TCPPort, + UDPPort, deferred, ) @@ -47,7 +49,9 @@ "Address", "BindAddress", "Network", - "Port", + "ICMPPort", + "TCPPort", + "UDPPort", "Storage", "StoredState", "State", diff --git a/scenario/context.py b/scenario/context.py index 6d60541b..1e693d46 100644 --- a/scenario/context.py +++ b/scenario/context.py @@ -5,7 +5,7 @@ import tempfile from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Type, Union, cast from ops import CharmBase, EventBase @@ -19,6 +19,7 @@ Storage, _CharmSpec, _Event, + _MaxPositionalArgs, ) if TYPE_CHECKING: # pragma: no cover @@ -34,8 +35,8 @@ DEFAULT_JUJU_VERSION = "3.4" -@dataclasses.dataclass -class ActionOutput: +@dataclasses.dataclass(frozen=True) +class ActionOutput(_MaxPositionalArgs): """Wraps the results of running an action event with `run_action`.""" state: "State" @@ -43,12 +44,14 @@ class ActionOutput: In most cases, actions are not expected to be affecting it.""" logs: List[str] """Any logs associated with the action output, set by the charm.""" - results: Optional[Dict[str, Any]] + results: Optional[Dict[str, Any]] = None """Key-value mapping assigned by the charm as a result of the action. Will be None if the charm never calls action-set.""" failure: Optional[str] = None """If the action is not a success: the message the charm set when failing the action.""" + _max_positional_args: Final = 0 + @property def success(self) -> bool: """Return whether this action was a success.""" @@ -316,6 +319,7 @@ def __init__( self, charm_type: Type["CharmType"], meta: Optional[Dict[str, Any]] = None, + *, actions: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, charm_root: Optional["PathLike"] = None, @@ -382,7 +386,7 @@ def __init__( defined in metadata.yaml. :arg unit_id: Unit ID that this charm is deployed as. Defaults to 0. :arg app_trusted: whether the charm has Juju trust (deployed with ``--trust`` or added with - ``juju trust``). Defaults to False + ``juju trust``). Defaults to False. :arg charm_root: virtual charm root the charm will be executed with. If the charm, say, expects a `./src/foo/bar.yaml` file present relative to the execution cwd, you need to use this. E.g.: @@ -553,10 +557,10 @@ def run_action(self, action: "Action", state: "State") -> ActionOutput: def _finalize_action(self, state_out: "State"): ao = ActionOutput( - state_out, - self._action_logs, - self._action_results, - self._action_failure, + state=state_out, + logs=self._action_logs, + results=self._action_results, + failure=self._action_failure, ) # reset all action-related state diff --git a/scenario/mocking.py b/scenario/mocking.py index 8885e420..b17627d3 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -20,8 +20,11 @@ cast, ) -from ops import CloudSpec, JujuVersion, pebble -from ops.model import ModelError, RelationNotFoundError +from ops import JujuVersion, pebble +from ops.model import CloudSpec as CloudSpec_Ops +from ops.model import ModelError +from ops.model import Port as Port_Ops +from ops.model import RelationNotFoundError from ops.model import Secret as Secret_Ops # lol from ops.model import ( SecretInfo, @@ -39,8 +42,8 @@ Mount, Network, PeerRelation, - Port, Storage, + _port_cls_by_protocol, _RawPortProtocolLiteral, _RawStatusLiteral, ) @@ -112,8 +115,11 @@ def __init__( self._context = context self._charm_spec = charm_spec - def opened_ports(self) -> Set[Port]: - return set(self._state.opened_ports) + def opened_ports(self) -> Set[Port_Ops]: + return { + Port_Ops(protocol=port.protocol, port=port.port) + for port in self._state.opened_ports + } def open_port( self, @@ -122,7 +128,7 @@ def open_port( ): # fixme: the charm will get hit with a StateValidationError # here, not the expected ModelError... - port_ = Port(protocol, port) + port_ = _port_cls_by_protocol[protocol](port=port) ports = self._state.opened_ports if port_ not in ports: ports.append(port_) @@ -132,7 +138,7 @@ def close_port( protocol: "_RawPortProtocolLiteral", port: Optional[int] = None, ): - _port = Port(protocol, port) + _port = _port_cls_by_protocol[protocol](port=port) ports = self._state.opened_ports if _port in ports: ports.remove(_port) @@ -629,7 +635,7 @@ def resource_get(self, resource_name: str) -> str: f"resource {resource_name} not found in State. please pass it.", ) - def credential_get(self) -> CloudSpec: + def credential_get(self) -> CloudSpec_Ops: if not self._context.app_trusted: raise ModelError( "ERROR charm is not trusted, initialise Context with `app_trusted=True`", @@ -669,7 +675,7 @@ def __init__( path = Path(mount.location).parts mounting_dir = container_root.joinpath(*path[1:]) mounting_dir.parent.mkdir(parents=True, exist_ok=True) - mounting_dir.symlink_to(mount.src) + mounting_dir.symlink_to(mount.source) self._root = container_root diff --git a/scenario/state.py b/scenario/state.py index b2c6fe76..bb88e784 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -14,6 +14,7 @@ Any, Callable, Dict, + Final, Generic, List, Literal, @@ -27,10 +28,11 @@ ) from uuid import uuid4 -import ops import yaml from ops import pebble from ops.charm import CharmBase, CharmEvents +from ops.model import CloudCredential as CloudCredential_Ops +from ops.model import CloudSpec as CloudSpec_Ops from ops.model import SecretRotate, StatusBase from scenario.logger import logger as scenario_logger @@ -120,6 +122,26 @@ class MetadataNotFoundError(RuntimeError): """Raised when Scenario can't find a metadata.yaml file in the provided charm root.""" +# This can be replaced with the KW_ONLY dataclasses functionality in Python 3.10+. +class _MaxPositionalArgs: + """Raises TypeError when instantiating objects if arguments are not passed as keywords. + + Looks for a `_max_positional_args` class attribute, which should be an int + indicating the maximum number of positional arguments that can be passed to + `__init__` (excluding `self`). If not present, no limit is applied. + """ + + _max_positional_args = 0 + + def __new__(cls, *args, **_): + if len(args) > getattr(cls, "_max_positional_args", float("inf")): + raise TypeError( + f"{cls.__name__}.__init__() takes {cls._max_positional_args + 1} " + f"positional arguments but {len(args) + 1} were given", + ) + return super().__new__(cls) + + @dataclasses.dataclass(frozen=True) class CloudCredential: auth_type: str @@ -135,8 +157,8 @@ class CloudCredential: redacted: List[str] = dataclasses.field(default_factory=list) """A list of redacted generic cloud API secrets.""" - def _to_ops(self) -> ops.CloudCredential: - return ops.CloudCredential( + def _to_ops(self) -> CloudCredential_Ops: + return CloudCredential_Ops( auth_type=self.auth_type, attributes=self.attributes, redacted=self.redacted, @@ -175,8 +197,8 @@ class CloudSpec: is_controller_cloud: bool = False """If this is the cloud used by the controller.""" - def _to_ops(self) -> ops.CloudSpec: - return ops.CloudSpec( + def _to_ops(self) -> CloudSpec_Ops: + return CloudSpec_Ops( type=self.type, name=self.name, region=self.region, @@ -191,16 +213,16 @@ def _to_ops(self) -> ops.CloudSpec: @dataclasses.dataclass(frozen=True) -class Secret: +class Secret(_MaxPositionalArgs): + # mapping from revision IDs to each revision's contents + contents: Dict[int, "RawSecretRevisionContents"] + id: str # CAUTION: ops-created Secrets (via .add_secret()) will have a canonicalized # secret id (`secret:` prefix) # but user-created ones will not. Using post-init to patch it in feels bad, but requiring the user to # add the prefix manually every time seems painful as well. - # mapping from revision IDs to each revision's contents - contents: Dict[int, "RawSecretRevisionContents"] - # indicates if the secret is owned by THIS unit, THIS app or some other app/unit. # if None, the implication is that the secret has been granted to this unit. owner: Literal["unit", "app", None] = None @@ -217,6 +239,8 @@ class Secret: expire: Optional[datetime.datetime] = None rotate: Optional[SecretRotate] = None + _max_positional_args: Final = 1 + def _set_revision(self, revision: int): """Set a new tracked revision.""" # bypass frozen dataclass @@ -254,19 +278,23 @@ def normalize_name(s: str): @dataclasses.dataclass(frozen=True) -class Address: - hostname: str +class Address(_MaxPositionalArgs): value: str - cidr: str + hostname: str = "" + cidr: str = "" address: str = "" # legacy + _max_positional_args: Final = 1 + @dataclasses.dataclass(frozen=True) -class BindAddress: - interface_name: str +class BindAddress(_MaxPositionalArgs): addresses: List[Address] + interface_name: str = "" mac_address: Optional[str] = None + _max_positional_args: Final = 1 + def hook_tool_output_fmt(self): # dumps itself to dict in the same format the hook tool would # todo support for legacy (deprecated) `interfacename` and `macaddress` fields? @@ -280,11 +308,13 @@ def hook_tool_output_fmt(self): @dataclasses.dataclass(frozen=True) -class Network: +class Network(_MaxPositionalArgs): bind_addresses: List[BindAddress] ingress_addresses: List[str] egress_subnets: List[str] + _max_positional_args: Final = 0 + def hook_tool_output_fmt(self): # dumps itself to dict in the same format the hook tool would return { @@ -323,7 +353,7 @@ def default( _next_relation_id_counter = 1 -def next_relation_id(update=True): +def next_relation_id(*, update=True): global _next_relation_id_counter cur = _next_relation_id_counter if update: @@ -332,7 +362,7 @@ def next_relation_id(update=True): @dataclasses.dataclass(frozen=True) -class _RelationBase: +class _RelationBase(_MaxPositionalArgs): endpoint: str """Relation endpoint name. Must match some endpoint name defined in metadata.yaml.""" @@ -352,6 +382,8 @@ class _RelationBase: ) """This unit's databag for this relation.""" + _max_positional_args: Final = 2 + @property def _databags(self): """Yield all databags in this relation.""" @@ -508,8 +540,9 @@ def _random_model_name(): @dataclasses.dataclass(frozen=True) -class Model: +class Model(_MaxPositionalArgs): name: str = dataclasses.field(default_factory=_random_model_name) + uuid: str = dataclasses.field(default_factory=lambda: str(uuid4())) # whatever juju models --format=json | jq '.models[].type' gives back. @@ -519,6 +552,8 @@ class Model: cloud_spec: Optional[CloudSpec] = None """Cloud specification information (metadata) including credentials.""" + _max_positional_args: Final = 1 + # for now, proc mock allows you to map one command to one mocked output. # todo extend: one input -> multiple outputs, at different times @@ -538,7 +573,7 @@ def _generate_new_change_id(): @dataclasses.dataclass(frozen=True) -class ExecOutput: +class ExecOutput(_MaxPositionalArgs): return_code: int = 0 stdout: str = "" stderr: str = "" @@ -546,6 +581,8 @@ class ExecOutput: # change ID: used internally to keep track of mocked processes _change_id: int = dataclasses.field(default_factory=_generate_new_change_id) + _max_positional_args: Final = 0 + def _run(self) -> int: return self._change_id @@ -554,9 +591,11 @@ def _run(self) -> int: @dataclasses.dataclass(frozen=True) -class Mount: +class Mount(_MaxPositionalArgs): location: Union[str, PurePosixPath] - src: Union[str, Path] + source: Union[str, Path] + + _max_positional_args: Final = 0 def _now_utc(): @@ -566,7 +605,7 @@ def _now_utc(): _next_notice_id_counter = 1 -def next_notice_id(update=True): +def next_notice_id(*, update=True): global _next_notice_id_counter cur = _next_notice_id_counter if update: @@ -575,7 +614,7 @@ def next_notice_id(update=True): @dataclasses.dataclass(frozen=True) -class Notice: +class Notice(_MaxPositionalArgs): key: str """The notice key, a string that differentiates notices of this type. @@ -617,6 +656,8 @@ class Notice: expire_after: Optional[datetime.timedelta] = None """How long since one of these last occurred until Pebble will drop the notice.""" + _max_positional_args: Final = 1 + def _to_ops(self) -> pebble.Notice: return pebble.Notice( id=self.id, @@ -634,10 +675,12 @@ def _to_ops(self) -> pebble.Notice: @dataclasses.dataclass(frozen=True) -class _BoundNotice: +class _BoundNotice(_MaxPositionalArgs): notice: Notice container: "Container" + _max_positional_args: Final = 0 + @property def event(self): """Sugar to generate a -pebble-custom-notice event for this notice.""" @@ -650,8 +693,9 @@ def event(self): @dataclasses.dataclass(frozen=True) -class Container: +class Container(_MaxPositionalArgs): name: str + can_connect: bool = False # This is the base plan. On top of it, one can add layers. @@ -676,8 +720,8 @@ class Container: # # this becomes: # mounts = { - # 'foo': Mount('/home/foo/', Path('/path/to/local/dir/containing/bar/py/')) - # 'bin': Mount('/bin/', Path('/path/to/local/dir/containing/bash/and/baz/')) + # 'foo': Mount(location='/home/foo/', source=Path('/path/to/local/dir/containing/bar/py/')) + # 'bin': Mount(location='/bin/', source=Path('/path/to/local/dir/containing/bash/and/baz/')) # } # when the charm runs `pebble.pull`, it will return .open() from one of those paths. # when the charm pushes, it will either overwrite one of those paths (careful!) or it will @@ -688,6 +732,8 @@ class Container: notices: List[Notice] = dataclasses.field(default_factory=list) + _max_positional_args: Final = 1 + def _render_services(self): # copied over from ops.testing._TestingPebbleClient._render_services() services = {} # type: Dict[str, pebble.Service] @@ -757,7 +803,7 @@ def get_notice( """ for notice in self.notices: if notice.key == key and notice.type == notice_type: - return _BoundNotice(notice, self) + return _BoundNotice(notice=notice, container=self) raise KeyError( f"{self.name} does not have a notice with key {key} and type {notice_type}", ) @@ -811,12 +857,13 @@ class _MyClass(_EntityStatus, statusbase_subclass): @dataclasses.dataclass(frozen=True) -class StoredState: +class StoredState(_MaxPositionalArgs): + name: str = "_stored" + # /-separated Object names. E.g. MyCharm/MyCharmLib. # if None, this StoredState instance is owned by the Framework. - owner_path: Optional[str] + owner_path: Optional[str] = None - name: str = "_stored" # Ideally, the type here would be only marshallable types, rather than Any. # However, it's complex to describe those types, since it's a recursive # definition - even in TypeShed the _Marshallable type includes containers @@ -825,6 +872,8 @@ class StoredState: _data_type_name: str = "StoredStateData" + _max_positional_args: Final = 1 + @property def handle_path(self): return f"{self.owner_path or ''}/{self._data_type_name}[{self.name}]" @@ -834,35 +883,84 @@ def handle_path(self): @dataclasses.dataclass(frozen=True) -class Port: +class _Port(_MaxPositionalArgs): """Represents a port on the charm host.""" - protocol: _RawPortProtocolLiteral port: Optional[int] = None """The port to open. Required for TCP and UDP; not allowed for ICMP.""" + protocol: _RawPortProtocolLiteral = "tcp" + + _max_positional_args = 1 def __post_init__(self): - port = self.port - is_icmp = self.protocol == "icmp" - if port: - if is_icmp: - raise StateValidationError( - "`port` arg not supported with `icmp` protocol", - ) - if not (1 <= port <= 65535): - raise StateValidationError( - f"`port` outside bounds [1:65535], got {port}", - ) - elif not is_icmp: + if type(self) is _Port: + raise RuntimeError( + "_Port cannot be instantiated directly; " + "please use TCPPort, UDPPort, or ICMPPort", + ) + + +@dataclasses.dataclass(frozen=True) +class TCPPort(_Port): + """Represents a TCP port on the charm host.""" + + port: int + """The port to open.""" + protocol: _RawPortProtocolLiteral = "tcp" + + _max_positional_args: Final = 1 + + def __post_init__(self): + super().__post_init__() + if not (1 <= self.port <= 65535): raise StateValidationError( - f"`port` arg required with `{self.protocol}` protocol", + f"`port` outside bounds [1:65535], got {self.port}", ) +@dataclasses.dataclass(frozen=True) +class UDPPort(_Port): + """Represents a UDP port on the charm host.""" + + port: int + """The port to open.""" + protocol: _RawPortProtocolLiteral = "udp" + + _max_positional_args: Final = 1 + + def __post_init__(self): + super().__post_init__() + if not (1 <= self.port <= 65535): + raise StateValidationError( + f"`port` outside bounds [1:65535], got {self.port}", + ) + + +@dataclasses.dataclass(frozen=True) +class ICMPPort(_Port): + """Represents an ICMP port on the charm host.""" + + protocol: _RawPortProtocolLiteral = "icmp" + + _max_positional_args: Final = 0 + + def __post_init__(self): + super().__post_init__() + if self.port is not None: + raise StateValidationError("`port` cannot be set for `ICMPPort`") + + +_port_cls_by_protocol = { + "tcp": TCPPort, + "udp": UDPPort, + "icmp": ICMPPort, +} + + _next_storage_index_counter = 0 # storage indices start at 0 -def next_storage_index(update=True): +def next_storage_index(*, update=True): """Get the index (used to be called ID) the next Storage to be created will get. Pass update=False if you're only inspecting it. @@ -876,7 +974,7 @@ def next_storage_index(update=True): @dataclasses.dataclass(frozen=True) -class Storage: +class Storage(_MaxPositionalArgs): """Represents an (attached!) storage made available to the charm container.""" name: str @@ -884,13 +982,15 @@ class Storage: index: int = dataclasses.field(default_factory=next_storage_index) # Every new Storage instance gets a new one, if there's trouble, override. + _max_positional_args: Final = 1 + def get_filesystem(self, ctx: "Context") -> Path: """Simulated filesystem root in this context.""" return ctx._get_storage_root(self.name, self.index) @dataclasses.dataclass(frozen=True) -class State: +class State(_MaxPositionalArgs): """Represents the juju-owned portion of a unit's state. Roughly speaking, it wraps all hook-tool- and pebble-mediated data a charm can access in its @@ -919,7 +1019,7 @@ class State: If a storage is not attached, omit it from this listing.""" # we don't use sets to make json serialization easier - opened_ports: List[Port] = dataclasses.field(default_factory=list) + opened_ports: List[_Port] = dataclasses.field(default_factory=list) """Ports opened by juju on this charm.""" leader: bool = False """Whether this charm has leadership.""" @@ -952,6 +1052,8 @@ class State: workload_version: str = "" """Workload version.""" + _max_positional_args: Final = 0 + def __post_init__(self): for name in ["app_status", "unit_status"]: val = getattr(self, name) @@ -1384,7 +1486,7 @@ def deferred(self, handler: Callable, event_id: int = 1) -> DeferredEvent: _next_action_id_counter = 1 -def next_action_id(update=True): +def next_action_id(*, update=True): global _next_action_id_counter cur = _next_action_id_counter if update: @@ -1395,7 +1497,7 @@ def next_action_id(update=True): @dataclasses.dataclass(frozen=True) -class Action: +class Action(_MaxPositionalArgs): name: str params: Dict[str, "AnyJson"] = dataclasses.field(default_factory=dict) @@ -1406,6 +1508,8 @@ class Action: Every action invocation is automatically assigned a new one. Override in the rare cases where a specific ID is required.""" + _max_positional_args: Final = 1 + @property def event(self) -> _Event: """Helper to generate an action event from this action.""" diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 6a955be7..82321558 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -3,7 +3,6 @@ import pytest from ops.charm import CharmBase -from scenario import Model from scenario.consistency_checker import check_consistency from scenario.runtime import InconsistentScenarioError from scenario.state import ( @@ -12,6 +11,7 @@ CloudCredential, CloudSpec, Container, + Model, Network, Notice, PeerRelation, @@ -285,7 +285,7 @@ def test_secrets_jujuv_bad(bad_v): @pytest.mark.parametrize("good_v", ("3.0", "3.1", "3", "3.33", "4", "100")) def test_secrets_jujuv_bad(good_v): assert_consistent( - State(secrets=[Secret("secret:foo", {0: {"a": "b"}})]), + State(secrets=[Secret(id="secret:foo", contents={0: {"a": "b"}})]), _Event("bar"), _CharmSpec(MyCharm, {}), good_v, @@ -293,7 +293,7 @@ def test_secrets_jujuv_bad(good_v): def test_secret_not_in_state(): - secret = Secret("secret:foo", {"a": "b"}) + secret = Secret(id="secret:foo", contents={"a": "b"}) assert_inconsistent( State(), _Event("secret_changed", secret=secret), @@ -673,10 +673,10 @@ def test_storedstate_consistency(): assert_consistent( State( stored_state=[ - StoredState(None, content={"foo": "bar"}), - StoredState(None, "my_stored_state", content={"foo": 1}), - StoredState("MyCharmLib", content={"foo": None}), - StoredState("OtherCharmLib", content={"foo": (1, 2, 3)}), + StoredState(content={"foo": "bar"}), + StoredState(name="my_stored_state", content={"foo": 1}), + StoredState(owner_path="MyCharmLib", content={"foo": None}), + StoredState(owner_path="OtherCharmLib", content={"foo": (1, 2, 3)}), ] ), _Event("start"), @@ -690,8 +690,8 @@ def test_storedstate_consistency(): assert_inconsistent( State( stored_state=[ - StoredState(None, content={"foo": "bar"}), - StoredState(None, "_stored", content={"foo": "bar"}), + StoredState(owner_path=None, content={"foo": "bar"}), + StoredState(owner_path=None, name="_stored", content={"foo": "bar"}), ] ), _Event("start"), @@ -703,7 +703,13 @@ def test_storedstate_consistency(): ), ) assert_inconsistent( - State(stored_state=[StoredState(None, content={"secret": Secret("foo", {})})]), + State( + stored_state=[ + StoredState( + owner_path=None, content={"secret": Secret(id="foo", contents={})} + ) + ] + ), _Event("start"), _CharmSpec( MyCharm, diff --git a/tests/test_context.py b/tests/test_context.py index d6995efc..aed14159 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -3,7 +3,7 @@ import pytest from ops import CharmBase -from scenario import Action, Context, State +from scenario import Action, ActionOutput, Context, State from scenario.state import _Event, next_action_id @@ -59,3 +59,23 @@ def test_app_name(app_name, unit_id): with ctx.manager(ctx.on.start(), State()) as mgr: assert mgr.charm.app.name == app_name assert mgr.charm.unit.name == f"{app_name}/{unit_id}" + + +def test_action_output_no_positional_arguments(): + with pytest.raises(TypeError): + ActionOutput(None, None) + + +def test_action_output_no_results(): + class MyCharm(CharmBase): + def __init__(self, framework): + super().__init__(framework) + framework.observe(self.on.act_action, self._on_act_action) + + def _on_act_action(self, _): + pass + + ctx = Context(MyCharm, meta={"name": "foo"}, actions={"act": {}}) + out = ctx.run_action(Action("act"), State()) + assert out.results is None + assert out.failure is None diff --git a/tests/test_context_on.py b/tests/test_context_on.py index be8c70b5..d9609d2e 100644 --- a/tests/test_context_on.py +++ b/tests/test_context_on.py @@ -81,7 +81,9 @@ def test_simple_events(event_name, event_kind): ) def test_simple_secret_events(as_kwarg, event_name, event_kind, owner): ctx = scenario.Context(ContextCharm, meta=META, actions=ACTIONS) - secret = scenario.Secret("secret:123", {0: {"password": "xxxx"}}, owner=owner) + secret = scenario.Secret( + id="secret:123", contents={0: {"password": "xxxx"}}, owner=owner + ) state_in = scenario.State(secrets=[secret]) # These look like: # ctx.run(ctx.on.secret_changed(secret=secret), state) @@ -112,8 +114,8 @@ def test_simple_secret_events(as_kwarg, event_name, event_kind, owner): def test_revision_secret_events(event_name, event_kind): ctx = scenario.Context(ContextCharm, meta=META, actions=ACTIONS) secret = scenario.Secret( - "secret:123", - {42: {"password": "yyyy"}, 43: {"password": "xxxx"}}, + id="secret:123", + contents={42: {"password": "yyyy"}, 43: {"password": "xxxx"}}, owner="app", ) state_in = scenario.State(secrets=[secret]) @@ -135,7 +137,9 @@ def test_revision_secret_events(event_name, event_kind): def test_revision_secret_events_as_positional_arg(event_name): ctx = scenario.Context(ContextCharm, meta=META, actions=ACTIONS) secret = scenario.Secret( - "secret:123", {42: {"password": "yyyy"}, 43: {"password": "xxxx"}}, owner=None + id="secret:123", + contents={42: {"password": "yyyy"}, 43: {"password": "xxxx"}}, + owner=None, ) state_in = scenario.State(secrets=[secret]) with pytest.raises(TypeError): @@ -180,7 +184,7 @@ def test_action_event_no_params(): def test_action_event_with_params(): ctx = scenario.Context(ContextCharm, meta=META, actions=ACTIONS) - action = scenario.Action("act", {"param": "hello"}) + action = scenario.Action("act", params={"param": "hello"}) # These look like: # ctx.run_action(ctx.on.action(action=action), state) # So that any parameters can be included and the ID can be customised. diff --git a/tests/test_e2e/test_actions.py b/tests/test_e2e/test_actions.py index 6256885c..39a057e6 100644 --- a/tests/test_e2e/test_actions.py +++ b/tests/test_e2e/test_actions.py @@ -5,7 +5,7 @@ from scenario import Context from scenario.context import InvalidEventError -from scenario.state import Action, State, _Event +from scenario.state import Action, State, _Event, next_action_id @pytest.fixture(scope="function") @@ -154,3 +154,17 @@ def handle_evt(charm: CharmBase, evt: ActionEvent): action = Action("foo", id=uuid) ctx = Context(mycharm, meta={"name": "foo"}, actions={"foo": {}}) ctx.run_action(action, State()) + + +def test_positional_arguments(): + with pytest.raises(TypeError): + Action("foo", {}) + + +def test_default_arguments(): + expected_id = next_action_id(update=False) + name = "foo" + action = Action(name) + assert action.name == name + assert action.params == {} + assert action.id == expected_id diff --git a/tests/test_e2e/test_pebble.py b/tests/test_e2e/test_pebble.py index a9223120..7dfbba67 100644 --- a/tests/test_e2e/test_pebble.py +++ b/tests/test_e2e/test_pebble.py @@ -10,7 +10,7 @@ from ops.pebble import ExecError, ServiceStartup, ServiceStatus from scenario import Context -from scenario.state import Container, ExecOutput, Mount, Notice, Port, State +from scenario.state import Container, ExecOutput, Mount, Notice, State from tests.helpers import jsonpatch_delta, trigger @@ -86,7 +86,7 @@ def callback(self: CharmBase): Container( name="foo", can_connect=True, - mounts={"bar": Mount("/bar/baz.txt", pth)}, + mounts={"bar": Mount(location="/bar/baz.txt", source=pth)}, ) ] ), @@ -97,10 +97,6 @@ def callback(self: CharmBase): ) -def test_port_equality(): - assert Port("tcp", 42) == Port("tcp", 42) - - @pytest.mark.parametrize("make_dirs", (True, False)) def test_fs_pull(charm_cls, make_dirs): text = "lorem ipsum/n alles amat gloriae foo" @@ -122,7 +118,9 @@ def callback(self: CharmBase): td = tempfile.TemporaryDirectory() container = Container( - name="foo", can_connect=True, mounts={"foo": Mount("/foo", td.name)} + name="foo", + can_connect=True, + mounts={"foo": Mount(location="/foo", source=td.name)}, ) state = State(containers=[container]) @@ -135,14 +133,15 @@ def callback(self: CharmBase): callback(mgr.charm) if make_dirs: - # file = (out.get_container("foo").mounts["foo"].src + "bar/baz.txt").open("/foo/bar/baz.txt") + # file = (out.get_container("foo").mounts["foo"].source + "bar/baz.txt").open("/foo/bar/baz.txt") # this is one way to retrieve the file file = Path(td.name + "/bar/baz.txt") # another is: assert ( - file == Path(out.get_container("foo").mounts["foo"].src) / "bar" / "baz.txt" + file + == Path(out.get_container("foo").mounts["foo"].source) / "bar" / "baz.txt" ) # but that is actually a symlink to the context's root tmp folder: diff --git a/tests/test_e2e/test_ports.py b/tests/test_e2e/test_ports.py index 13502971..3a19148f 100644 --- a/tests/test_e2e/test_ports.py +++ b/tests/test_e2e/test_ports.py @@ -2,7 +2,7 @@ from ops import CharmBase, Framework, StartEvent, StopEvent from scenario import Context, State -from scenario.state import Port +from scenario.state import StateValidationError, TCPPort, UDPPort, _Port class MyCharm(CharmBase): @@ -35,5 +35,18 @@ def test_open_port(ctx): def test_close_port(ctx): - out = ctx.run(ctx.on.stop(), State(opened_ports=[Port("tcp", 42)])) + out = ctx.run(ctx.on.stop(), State(opened_ports=[TCPPort(42)])) assert not out.opened_ports + + +def test_port_no_arguments(): + with pytest.raises(RuntimeError): + _Port() + + +@pytest.mark.parametrize("klass", (TCPPort, UDPPort)) +def test_port_port(klass): + with pytest.raises(StateValidationError): + klass(port=0) + with pytest.raises(StateValidationError): + klass(port=65536) diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index e72f754c..853c7ba5 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -21,6 +21,7 @@ StateValidationError, SubordinateRelation, _RelationBase, + next_relation_id, ) from tests.helpers import trigger @@ -421,3 +422,54 @@ def test_broken_relation_not_in_model_relations(mycharm): assert charm.model.get_relation("foo") is None assert charm.model.relations["foo"] == [] + + +@pytest.mark.parametrize("klass", (Relation, PeerRelation, SubordinateRelation)) +def test_relation_positional_arguments(klass): + with pytest.raises(TypeError): + klass("foo", "bar", None) + + +def test_relation_default_values(): + expected_id = next_relation_id(update=False) + endpoint = "database" + interface = "postgresql" + relation = Relation(endpoint, interface) + assert relation.id == expected_id + assert relation.endpoint == endpoint + assert relation.interface == interface + assert relation.local_app_data == {} + assert relation.local_unit_data == DEFAULT_JUJU_DATABAG + assert relation.remote_app_name == "remote" + assert relation.limit == 1 + assert relation.remote_app_data == {} + assert relation.remote_units_data == {0: DEFAULT_JUJU_DATABAG} + + +def test_subordinate_relation_default_values(): + expected_id = next_relation_id(update=False) + endpoint = "database" + interface = "postgresql" + relation = SubordinateRelation(endpoint, interface) + assert relation.id == expected_id + assert relation.endpoint == endpoint + assert relation.interface == interface + assert relation.local_app_data == {} + assert relation.local_unit_data == DEFAULT_JUJU_DATABAG + assert relation.remote_app_name == "remote" + assert relation.remote_unit_id == 0 + assert relation.remote_app_data == {} + assert relation.remote_unit_data == DEFAULT_JUJU_DATABAG + + +def test_peer_relation_default_values(): + expected_id = next_relation_id(update=False) + endpoint = "peers" + interface = "shared" + relation = PeerRelation(endpoint, interface) + assert relation.id == expected_id + assert relation.endpoint == endpoint + assert relation.interface == interface + assert relation.local_app_data == {} + assert relation.local_unit_data == DEFAULT_JUJU_DATABAG + assert relation.peers_data == {0: DEFAULT_JUJU_DATABAG} diff --git a/tests/test_e2e/test_secrets.py b/tests/test_e2e/test_secrets.py index 97e1f3b2..7229bd9f 100644 --- a/tests/test_e2e/test_secrets.py +++ b/tests/test_e2e/test_secrets.py @@ -538,3 +538,23 @@ def __init__(self, *args): secret.remove_all_revisions() assert not mgr.output.secrets[0].contents # secret wiped + + +def test_no_additional_positional_arguments(): + with pytest.raises(TypeError): + Secret({}, None) + + +def test_default_values(): + contents = {"foo": "bar"} + id = "secret:1" + secret = Secret(contents, id=id) + assert secret.contents == contents + assert secret.id == id + assert secret.label is None + assert secret.revision == 0 + assert secret.description is None + assert secret.owner is None + assert secret.rotate is None + assert secret.expire is None + assert secret.remote_grants == {} diff --git a/tests/test_e2e/test_state.py b/tests/test_e2e/test_state.py index 0c79da86..ccac80b8 100644 --- a/tests/test_e2e/test_state.py +++ b/tests/test_e2e/test_state.py @@ -6,7 +6,16 @@ from ops.framework import EventBase, Framework from ops.model import ActiveStatus, UnknownStatus, WaitingStatus -from scenario.state import DEFAULT_JUJU_DATABAG, Container, Relation, State +from scenario.state import ( + DEFAULT_JUJU_DATABAG, + Address, + BindAddress, + Container, + Model, + Network, + Relation, + State, +) from tests.helpers import jsonpatch_delta, sort_patch, trigger CUSTOM_EVT_SUFFIXES = { @@ -231,3 +240,61 @@ def pre_event(charm: CharmBase): assert out.relations[0].local_app_data == {"a": "b"} assert out.relations[0].local_unit_data == {"c": "d", **DEFAULT_JUJU_DATABAG} + + +@pytest.mark.parametrize( + "klass,num_args", + [ + (State, (1,)), + (Address, (0, 2)), + (BindAddress, (0, 2)), + (Network, (0, 2)), + ], +) +def test_positional_arguments(klass, num_args): + for num in num_args: + args = (None,) * num + with pytest.raises(TypeError): + klass(*args) + + +def test_model_positional_arguments(): + with pytest.raises(TypeError): + Model("", "") + + +def test_container_positional_arguments(): + with pytest.raises(TypeError): + Container("", "") + + +def test_container_default_values(): + name = "foo" + container = Container(name) + assert container.name == name + assert container.can_connect is False + assert container.layers == {} + assert container.service_status == {} + assert container.mounts == {} + assert container.exec_mock == {} + assert container.layers == {} + assert container._base_plan == {} + + +def test_state_default_values(): + state = State() + assert state.config == {} + assert state.relations == [] + assert state.networks == {} + assert state.containers == [] + assert state.storage == [] + assert state.opened_ports == [] + assert state.secrets == [] + assert state.resources == {} + assert state.deferred == [] + assert isinstance(state.model, Model) + assert state.leader is False + assert state.planned_units == 1 + assert state.app_status == UnknownStatus() + assert state.unit_status == UnknownStatus() + assert state.workload_version == "" diff --git a/tests/test_e2e/test_stored_state.py b/tests/test_e2e/test_stored_state.py index 22a6235e..38c38efd 100644 --- a/tests/test_e2e/test_stored_state.py +++ b/tests/test_e2e/test_stored_state.py @@ -39,7 +39,9 @@ def test_stored_state_initialized(mycharm): out = trigger( State( stored_state=[ - StoredState("MyCharm", name="_stored", content={"foo": "FOOX"}), + StoredState( + owner_path="MyCharm", name="_stored", content={"foo": "FOOX"} + ), ] ), "start", @@ -49,3 +51,16 @@ def test_stored_state_initialized(mycharm): # todo: ordering is messy? assert out.stored_state[1].content == {"foo": "FOOX", "baz": {12: 142}} assert out.stored_state[0].content == {"foo": "bar", "baz": {12: 142}} + + +def test_positional_arguments(): + with pytest.raises(TypeError): + StoredState("_stored", "") + + +def test_default_arguments(): + s = StoredState() + assert s.name == "_stored" + assert s.owner_path == None + assert s.content == {} + assert s._data_type_name == "StoredStateData"