Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] Remap #129

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scenario/_consistency_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _check_relation_event(
f"relation event should start with relation endpoint name. {event.name} does "
f"not start with {event.relation.endpoint}.",
)
if event.relation not in state.relations:
if event.relation.id not in {relation.id for relation in state.relations}:
errors.append(
f"cannot emit {event.name} because relation {event.relation.id} is not in the state.",
)
Expand Down
4 changes: 4 additions & 0 deletions scenario/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class InconsistentScenarioError(ScenarioRuntimeError):
"""Error raised when the combination of state and event is inconsistent."""


class RemapFailedError(ScenarioRuntimeError):
"""Error raised when scenario fails to remap some object in State."""


class StateValidationError(RuntimeError):
"""Raised when individual parts of the State are inconsistent."""

Expand Down
146 changes: 145 additions & 1 deletion scenario/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import re
import string
from enum import Enum
from functools import singledispatch
from itertools import chain
from pathlib import Path, PurePosixPath
from typing import (
Expand Down Expand Up @@ -46,7 +47,11 @@
from ops.model import CloudSpec as CloudSpec_Ops
from ops.model import SecretRotate, StatusBase

from scenario.errors import MetadataNotFoundError, StateValidationError
from scenario.errors import (
MetadataNotFoundError,
RemapFailedError,
StateValidationError,
)
from scenario.logger import logger as scenario_logger

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -57,6 +62,11 @@
UnitID = int

CharmType = TypeVar("CharmType", bound=CharmBase)
_Remappable = Union["Container", "Relation", "Secret", "StoredState"]
_R = TypeVar(
"_R",
bound=_Remappable,
)

logger = scenario_logger.getChild("state")

Expand Down Expand Up @@ -1580,6 +1590,140 @@ def get_relations(self, endpoint: str) -> Tuple["RelationBase", ...]:
if _normalise_name(r.endpoint) == normalized_endpoint
)

def _remap(self, *obj: _R) -> Iterable[Tuple[str, Optional[_R]]]:
return map(self._remap_one, obj)

def _remap_one(self, obj: _R) -> Tuple[str, Optional[_R]]:
"""Return the attribute in which the object can be found and the object itself."""

@singledispatch
def _filter(x: Any):
raise NotImplementedError(type(x))

@_filter.register
def _(x: Relation):
return x.id

@_filter.register
def _(x: Container):
return x.name

@_filter.register
def _(x: Secret):
return x.id

@singledispatch
def _getter(x: Any):
raise NotImplementedError(type(x))

@_getter.register
def _(x: Relation):
return "relations"

@_getter.register
def _(x: Container):
return "containers"

@_getter.register
def _(x: Secret):
return "secrets"

attr = _getter(obj)
objects = getattr(self, attr)
try:
matches = [o for o in objects if _filter(o) == _filter(obj)]
except NotImplementedError:
raise TypeError(f"cannot remap {type(obj)}")

if not matches:
return attr, None

if len(matches) > 1:
raise RuntimeError(
f"too many matches for {obj} (filtered by {_filter(obj)}).",
)
return attr, matches[0]

def remap(self, obj: _R) -> Optional[_R]:
"""Get the corresponding object from this State.
>>> from scenario import Relation, State, Context
>>> rel1, rel2 = Relation("foo"), Relation("bar")
>>> state_in = State(leader=True, relations=[rel1, rel2])
>>> ctx = Context(...)
>>> state_out = ctx.run(ctx.on.update_status(), state=state_in)
>>> rel1_out = state_out.remap(rel1)
>>> assert rel1.endpoint == "foo"
"""
return self._remap_one(obj)[1]

def remap_multiple(self, *obj: _Remappable) -> Tuple[Optional[_Remappable], ...]:
"""Get the corresponding objects from this State.
>>> from scenario import Relation, State, Context
>>> rel1, rel2 = Relation("foo"), Relation("bar")
>>> state_in = State(leader=True, relations=[rel1, rel2])
>>> ctx = Context(...)
>>> state_out = ctx.run(ctx.on.update_status(), state=state_in)
>>> rel1_out, rel2_out = state_out.remap_multiple(rel1, rel2)
>>> assert rel1.endpoint == "foo"
"""
return tuple(rmp[1] for rmp in self._remap(obj)) # type: ignore

def patch(self, obj_: _Remappable, /, **kwargs) -> "State":
"""Return a copy of this state with ``obj_`` modified by ``kwargs``.

For example:
>>> from scenario import Relation, State
>>> rel1, rel2 = Relation("foo"), Relation("bar")
>>> s = State(leader=True, relations=[rel1, rel2])
>>> s1 = s.patch(rel1, local_app_data = {"foo": "bar"})
... # is equivalent to:
>>> s1_ = State(leader=True, relations=[dataclasses.replace(rel1,local_app_data={"foo": "bar"}), rel2])
"""
obj = self.remap(obj_)
if not obj:
raise RemapFailedError(
f"cannot remap {obj_} to something in {self}: unable to patch it.",
)
modified_obj = dataclasses.replace(obj, **kwargs)
return self.insert(modified_obj)

def insert(self, *obj: _Remappable) -> "State":
"""Insert ``obj`` in the right place in this State.
>>> from scenario import Relation, State
>>> rel1, rel2 = Relation("foo"), Relation("bar")
>>> s = State(leader=True, relations=[rel1])
>>> s1 = s.insert(rel2)
... # is equivalent to:
>>> s1_ = State(leader=True, relations=[rel1, rel2])
... # and
>>> s1__ = s.insert(dataclasses.replace(rel2, endpoint="bar"))
... # is equivalent to:
>>> s1___ = State(leader=True, relations=[rel2])
"""
# if we can remap the object, we know we have to kick something out in order to insert it.
out = self
for attr, replace in self._remap(*obj):
current = getattr(out, attr)
new = [c for c in current if c != replace] + list(obj)
out = dataclasses.replace(out, **{attr: new})
return out

def without(self, *obj: _Remappable) -> "State":
"""Remove ``obj`` from this State.
>>> from scenario import Relation, State
>>> rel1, rel2 = Relation("foo"), Relation("bar")
>>> s = State(leader=True, relations=[rel1, rel2])
>>> s1 = s.without(rel2)
... # is equivalent to:
>>> s1_ = State(leader=True, relations=[rel1])
"""
out = self
for attr, replace in self._remap(*obj):
current = getattr(out, attr)
new = [c for c in current if c != replace]
out = dataclasses.replace(out, **{attr: new})
return out


def _is_valid_charmcraft_25_metadata(meta: Dict[str, Any]):
# Check whether this dict has the expected mandatory metadata fields according to the
Expand Down
60 changes: 60 additions & 0 deletions tests/test_remap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import scenario


def test_patch():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
state = scenario.State(relations=[relation])

patched = state.patch(relation, local_app_data={"baz": "qux"})
assert list(patched.relations)[0].local_app_data == {"baz": "qux"}


def test_remap():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
state = scenario.State(relations=[relation])
relation_out = state.remap(relation)
# in this case we didn't change it
assert relation_out is relation


def test_insert():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
state = scenario.State().insert(relation)
assert state.relations == {relation}


def test_insert_multiple():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"})

state = scenario.State().insert(relation, relation2)

assert state.relations == {relation2, relation}


def test_without():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"})

state = scenario.State(relations=[relation, relation2]).without(relation)
assert list(state.relations) == [relation2]


def test_without_multiple():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"})

state = scenario.State(relations=[relation, relation2]).without(relation, relation2)
assert list(state.relations) == []


def test_insert_replace():
relation1 = scenario.Relation("foo", local_app_data={"foo": "bar"}, id=1)
relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"}, id=2)

relation1_dupe = scenario.Relation("foo", local_app_data={"noz": "soz"}, id=1)

state = scenario.State(relations=[relation1, relation2]).insert(relation1_dupe)

# inserting a relation with identical ID will kick out the old one
assert set(state.relations) == {relation2, relation1_dupe}
Loading