diff --git a/scenario/context.py b/scenario/context.py index cb5331d5..26665c06 100644 --- a/scenario/context.py +++ b/scenario/context.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Type, Union, cast import ops -import ops.testing from scenario.errors import AlreadyEmittedError, ContextSetupError from scenario.logger import logger as scenario_logger @@ -28,8 +27,20 @@ ) if TYPE_CHECKING: # pragma: no cover + try: + from ops._private.harness import ExecArgs # type: ignore + except ImportError: + from ops.testing import ExecArgs # type: ignore + from scenario.ops_main_mock import Ops - from scenario.state import AnyJson, JujuLogLine, RelationBase, State, _EntityStatus + from scenario.state import ( + AnyJson, + CharmType, + JujuLogLine, + RelationBase, + State, + _EntityStatus, + ) logger = scenario_logger.getChild("runtime") @@ -426,7 +437,7 @@ def test_foo(): def __init__( self, - charm_type: Type[ops.testing.CharmType], + charm_type: Type["CharmType"], meta: Optional[Dict[str, Any]] = None, *, actions: Optional[Dict[str, Any]] = None, @@ -491,7 +502,7 @@ def __init__( self.charm_root = charm_root self.juju_version = juju_version if juju_version.split(".")[0] == "2": - logger.warn( + logger.warning( "Juju 2.x is closed and unsupported. You may encounter inconsistencies.", ) @@ -508,7 +519,7 @@ def __init__( self.juju_log: List["JujuLogLine"] = [] self.app_status_history: List["_EntityStatus"] = [] self.unit_status_history: List["_EntityStatus"] = [] - self.exec_history: Dict[str, List[ops.testing.ExecArgs]] = {} + self.exec_history: Dict[str, List["ExecArgs"]] = {} self.workload_version_history: List[str] = [] self.removed_secret_revisions: List[int] = [] self.emitted_events: List[ops.EventBase] = [] @@ -644,7 +655,10 @@ def run(self, event: "_Event", state: "State") -> "State": assert self._output_state is not None if event.action: if self._action_failure_message is not None: - raise ActionFailed(self._action_failure_message, self._output_state) + raise ActionFailed( + self._action_failure_message, + state=self._output_state, + ) return self._output_state @contextmanager diff --git a/scenario/mocking.py b/scenario/mocking.py index 9e004af4..def3c395 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. + import datetime +import io import shutil from pathlib import Path from typing import ( @@ -20,6 +22,12 @@ ) from ops import JujuVersion, pebble + +try: + from ops._private.harness import ExecArgs, _TestingPebbleClient # type: ignore +except ImportError: + from ops.testing import ExecArgs, _TestingPebbleClient # type: ignore + from ops.model import CloudSpec as CloudSpec_Ops from ops.model import ModelError from ops.model import Port as Port_Ops @@ -33,7 +41,6 @@ _ModelBackend, ) from ops.pebble import Client, ExecError -from ops.testing import ExecArgs, _TestingPebbleClient from scenario.errors import ActionMissingFromContextError from scenario.logger import logger as scenario_logger @@ -66,9 +73,9 @@ def __init__( change_id: int, args: ExecArgs, return_code: int, - stdin: Optional[TextIO], - stdout: Optional[TextIO], - stderr: Optional[TextIO], + stdin: Optional[Union[TextIO, io.BytesIO]], + stdout: Optional[Union[TextIO, io.BytesIO]], + stderr: Optional[Union[TextIO, io.BytesIO]], ): self._change_id = change_id self._args = args @@ -99,7 +106,12 @@ def wait_output(self): stdout = self.stdout.read() if self.stdout is not None else None stderr = self.stderr.read() if self.stderr is not None else None if self._return_code != 0: - raise ExecError(list(self._args.command), self._return_code, stdout, stderr) + raise ExecError( + list(self._args.command), + self._return_code, + stdout, # type: ignore + stderr, # type: ignore + ) return stdout, stderr def send_signal(self, sig: Union[int, str]): # noqa: U100 @@ -167,15 +179,18 @@ def get_pebble(self, socket_path: str) -> "Client": # container not defined in state. mounts = {} - return _MockPebbleClient( - socket_path=socket_path, - container_root=container_root, - mounts=mounts, - state=self._state, - event=self._event, - charm_spec=self._charm_spec, - context=self._context, - container_name=container_name, + return cast( + Client, + _MockPebbleClient( + socket_path=socket_path, + container_root=container_root, + mounts=mounts, + state=self._state, + event=self._event, + charm_spec=self._charm_spec, + context=self._context, + container_name=container_name, + ), ) def _get_relation_by_id(self, rel_id) -> "RelationBase": @@ -616,7 +631,7 @@ def storage_add(self, name: str, count: int = 1): ) if "/" in name: - # this error is raised by ops.testing but not by ops at runtime + # this error is raised by Harness but not by ops at runtime raise ModelError('storage name cannot contain "/"') self._context.requested_storages[name] = count @@ -752,6 +767,10 @@ def __init__( self._root = container_root + self._notices: Dict[Tuple[str, str], pebble.Notice] = {} + self._last_notice_id = 0 + self._changes: Dict[str, pebble.Change] = {} + # load any existing notices and check information from the state self._notices: Dict[Tuple[str, str], pebble.Notice] = {} self._check_infos: Dict[str, pebble.CheckInfo] = {} @@ -790,7 +809,7 @@ def _layers(self) -> Dict[str, pebble.Layer]: def _service_status(self) -> Dict[str, pebble.ServiceStatus]: return self._container.service_statuses - # Based on a method of the same name from ops.testing. + # Based on a method of the same name from Harness. def _find_exec_handler(self, command) -> Optional["Exec"]: handlers = {exec.command_prefix: exec for exec in self._container.execs} # Start with the full command and, each loop iteration, drop the last diff --git a/scenario/runtime.py b/scenario/runtime.py index f4df73db..92cfcae7 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -38,10 +38,8 @@ ) if TYPE_CHECKING: # pragma: no cover - from ops.testing import CharmType - from scenario.context import Context - from scenario.state import State, _CharmSpec, _Event + from scenario.state import CharmType, State, _CharmSpec, _Event logger = scenario_logger.getChild("runtime") STORED_STATE_REGEX = re.compile( diff --git a/scenario/state.py b/scenario/state.py index 9179735d..b089c582 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -117,7 +117,7 @@ class ActionFailed(Exception): """Raised at the end of the hook if the charm has called ``event.fail()``.""" - def __init__(self, message: str, state: "State"): + def __init__(self, message: str, *, state: "State"): self.message = message self.state = state diff --git a/tests/helpers.py b/tests/helpers.py index 5ceffa9d..49ffaeff 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -18,9 +18,7 @@ from scenario.context import _DEFAULT_JUJU_VERSION, Context if TYPE_CHECKING: # pragma: no cover - from ops.testing import CharmType - - from scenario.state import State, _Event + from scenario.state import CharmType, State, _Event _CT = TypeVar("_CT", bound=Type[CharmType]) diff --git a/tests/test_charm_spec_autoload.py b/tests/test_charm_spec_autoload.py index 57b93a31..b9da4d24 100644 --- a/tests/test_charm_spec_autoload.py +++ b/tests/test_charm_spec_autoload.py @@ -1,18 +1,15 @@ import importlib import sys -import tempfile from contextlib import contextmanager from pathlib import Path from typing import Type import pytest import yaml -from ops import CharmBase -from ops.testing import CharmType from scenario import Context, Relation, State from scenario.context import ContextSetupError -from scenario.state import MetadataNotFoundError, _CharmSpec +from scenario.state import CharmType, MetadataNotFoundError, _CharmSpec CHARM = """ from ops import CharmBase