diff --git a/scenario/state.py b/scenario/state.py index 8b82a60c..36bda7bd 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -865,7 +865,7 @@ def handle_path(self): class Port: """Represents a port on the charm host.""" - protocol: _RawPortProtocolLiteral + protocol: _RawPortProtocolLiteral = "tcp" port: Optional[int] = None """The port to open. Required for TCP and UDP; not allowed for ICMP.""" @@ -873,14 +873,12 @@ class Port: def __init__( self, *, - protocol: _RawPortProtocolLiteral, + protocol: _RawPortProtocolLiteral = "tcp", port: Optional[int] = None, ): object.__setattr__(self, "protocol", protocol) object.__setattr__(self, "port", port) - def __post_init__(self): - port = self.port is_icmp = self.protocol == "icmp" if port: if is_icmp: diff --git a/tests/test_e2e/test_ports.py b/tests/test_e2e/test_ports.py index deead8f4..a0f32798 100644 --- a/tests/test_e2e/test_ports.py +++ b/tests/test_e2e/test_ports.py @@ -2,7 +2,7 @@ from ops import CharmBase, Framework, StartEvent, StopEvent from scenario import Context, State -from scenario.state import Port +from scenario.state import Port, StateValidationError class MyCharm(CharmBase): @@ -37,3 +37,25 @@ def test_open_port(ctx): def test_close_port(ctx): out = ctx.run(ctx.on.stop(), State(opened_ports=[Port(protocol="tcp", port=42)])) assert not out.opened_ports + + +def test_port_no_arguments(): + with pytest.raises(StateValidationError): + Port() + + +def test_port_default_protocol(): + port = Port(port=42) + assert port.protocol == "tcp" + assert port.port == 42 + + +def test_port_port(): + with pytest.raises(StateValidationError): + Port(protocol="icmp", port=42) + with pytest.raises(StateValidationError): + Port(protocol="tcp", port=0) + with pytest.raises(StateValidationError): + Port(protocol="udp", port=65536) + with pytest.raises(StateValidationError): + Port(protocol="tcp")