Skip to content

Commit

Permalink
output streams and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cjjacks committed Oct 24, 2024
1 parent a289a49 commit 2cb87f5
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 86 deletions.
104 changes: 61 additions & 43 deletions ait/core/server/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,39 +113,44 @@ def output_stream_factory(name, inputs, outputs, handlers, zmq_args=None):
values in 'outputs' decides on the appropriate stream to instantiate and
then returns it.
"""
ostream = None
if type(outputs) is not list or (type(outputs) is list and len(outputs) == 0):
ostream = ZMQStream(
name,
inputs,
handlers,
zmq_args=zmq_args,
)
return ostream
# backwards compatability with original UDP spec
if type(outputs[0]) is int and ait.MIN_PORT <= outputs[0] <= ait.MAX_PORT:
ostream = UDPOutputStream(name, inputs, outputs[0], handlers, zmq_args=zmq_args)
elif is_valid_address_spec(outputs[0]):
protocol, hostname, port = outputs[0].split(":")
if int(port) < ait.MIN_PORT or int(port) > ait.MAX_PORT:

parsed_output = outputs
if type(parsed_output) is list and len(parsed_output) > 0:
if len(parsed_output) > 1:
ait.core.log.warn(f"Additional output args discarded {parsed_output[1:]}")
parsed_output = parsed_output[0]
if type(parsed_output) is int:
if ait.MIN_PORT <= parsed_output <= ait.MAX_PORT:
return UDPOutputStream(
name, inputs, parsed_output, handlers, zmq_args=zmq_args
)
else:
raise ValueError(f"Output stream specification invalid: {outputs}")
if protocol.lower() == "udp":
ostream = UDPOutputStream(
name, inputs, outputs[0], handlers, zmq_args=zmq_args

elif type(parsed_output) is str and is_valid_address_spec(parsed_output):
protocol, hostname, port = parsed_output.split(":")
if protocol.lower() == "udp" and ait.MIN_PORT <= int(port) <= ait.MAX_PORT:
return UDPOutputStream(
name, inputs, parsed_output, handlers, zmq_args=zmq_args
)
elif protocol.lower() == "tcp":
ostream = TCPOutputStream(
name, inputs, outputs[0], handlers, zmq_args=zmq_args
elif protocol.lower() == "tcp" and ait.MIN_PORT <= int(port) <= ait.MAX_PORT:
return TCPOutputStream(
name, inputs, parsed_output, handlers, zmq_args=zmq_args
)
else:
raise ValueError(f"Output stream specification invalid: {outputs}")
elif parsed_output is None or (
type(parsed_output) is list and len(parsed_output) == 0
):
return ZMQStream(
name,
inputs,
handlers,
zmq_args=zmq_args,
)
else:
raise ValueError(f"Output stream specification invalid: {outputs}")

if ostream is None:
raise ValueError(f"Output stream specification invalid: {outputs}")
return ostream


def input_stream_factory(name, inputs, handlers, zmq_args=None):
"""
Expand All @@ -156,31 +161,44 @@ def input_stream_factory(name, inputs, handlers, zmq_args=None):
"""

stream = None

if type(inputs) is not list or (type(inputs) is list and len(inputs) == 0):
raise ValueError(f"Input stream specification invalid: {inputs}")
parsed_inputs = inputs
if type(parsed_inputs) is int:
parsed_inputs = [parsed_inputs]
if type(parsed_inputs) is str:
parsed_inputs = [parsed_inputs]

if type(parsed_inputs) is not list or (
type(parsed_inputs) is list and len(parsed_inputs) == 0
):
raise ValueError(f"Input stream specification invalid: {parsed_inputs}")

# backwards compatability with original UDP server spec
if (
type(inputs) is list
and type(inputs[0]) is int
and ait.MIN_PORT <= inputs[0] <= ait.MAX_PORT
type(parsed_inputs) is list
and type(parsed_inputs[0]) is int
and ait.MIN_PORT <= parsed_inputs[0] <= ait.MAX_PORT
):
stream = UDPInputServerStream(name, inputs[0], handlers, zmq_args=zmq_args)
elif is_valid_address_spec(inputs[0]):
protocol, hostname, port = inputs[0].split(":")
stream = UDPInputServerStream(
name, parsed_inputs[0], handlers, zmq_args=zmq_args
)
elif is_valid_address_spec(parsed_inputs[0]):
protocol, hostname, port = parsed_inputs[0].split(":")
if int(port) < ait.MIN_PORT or int(port) > ait.MAX_PORT:
raise ValueError(f"Input stream specification invalid: {inputs}")
raise ValueError(f"Input stream specification invalid: {parsed_inputs}")
if protocol.lower() == "tcp":
if hostname.lower() in [
"server",
"localhost",
"127.0.0.1",
"0.0.0.0",
]:
stream = TCPInputServerStream(name, inputs[0], handlers, zmq_args)
stream = TCPInputServerStream(
name, parsed_inputs[0], handlers, zmq_args
)
else:
stream = TCPInputClientStream(name, inputs[0], handlers, zmq_args)
stream = TCPInputClientStream(
name, parsed_inputs[0], handlers, zmq_args
)
else:
if hostname.lower() in [
"server",
Expand All @@ -189,17 +207,17 @@ def input_stream_factory(name, inputs, handlers, zmq_args=None):
"0.0.0.0",
]:
stream = UDPInputServerStream(
name, inputs[0], handlers, zmq_args=zmq_args
name, parsed_inputs[0], handlers, zmq_args=zmq_args
)
else:
raise ValueError(f"Input stream specification invalid: {inputs}")
elif all(isinstance(item, str) for item in inputs):
stream = ZMQStream(name, inputs, handlers, zmq_args=zmq_args)
raise ValueError(f"Input stream specification invalid: {parsed_inputs}")
elif all(isinstance(item, str) for item in parsed_inputs):
stream = ZMQStream(name, parsed_inputs, handlers, zmq_args=zmq_args)
else:
raise ValueError(f"Input stream specification invalid: {inputs}")
raise ValueError(f"Input stream specification invalid: {parsed_inputs}")

if stream is None:
raise ValueError(f"Input stream specification invalid: {inputs}")
raise ValueError(f"Input stream specification invalid: {parsed_inputs}")
return stream


Expand Down
97 changes: 54 additions & 43 deletions tests/ait/core/server/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,18 @@ def test_stream_creation_invalid_workflow(self, stream, args):
(["TCP:localhost:1234"], TCPInputServerStream),
(["TCP:foo:1234"], TCPInputClientStream),
([1234], UDPInputServerStream),
(1234, UDPInputServerStream),
(["UDP:server:1234"], UDPInputServerStream),
(["UDP:localhost:1234"], UDPInputServerStream),
(["UDP:0.0.0.0:1234"], UDPInputServerStream),
(["UDP:127.0.0.1:1234"], UDPInputServerStream),
("UDP:127.0.0.1:1234", UDPInputServerStream),
(["FOO"], ZMQStream),
(["FOO", "BAR"], ZMQStream),
(
[1234, "FOO", "BAR"],
UDPInputServerStream,
), # Technically valid but not really correct
],
)
def test_valid_input_stream_factory(self, args, expected):
Expand Down Expand Up @@ -215,47 +221,52 @@ def test_invalid_input_stream_factory(self, args, expected):
with pytest.raises(expected):
_ = input_stream_factory(*full_args)

# @pytest.mark.parametrize(
# "args,expected",
# [
# (["TCP", "127.0.0.1", 1234], PortOutputStream),
# (["TCP", "localhost", 1234], PortOutputStream),
# (["TCP", "foo", 1234], PortOutputStream),
# (["UDP", "127.0.0.1", 1234], PortOutputStream),
# (["UDP", "localhost", 1234], PortOutputStream),
# (["UDP", "foo", 1234], PortOutputStream),
# ([1234], PortOutputStream),
# (1234, PortOutputStream),
# ([], ZMQStream),
# (None, ZMQStream),
# ],
# )
# def test_valid_output_stream_factory(self, args, expected):
# full_args = [
# "foo",
# "bar",
# args,
# [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")],
# {"zmq_context": broker.context},
# ]
# stream = output_stream_factory(*full_args)
# assert isinstance(stream, expected)
@pytest.mark.parametrize(
"args,expected",
[
(["TCP:127.0.0.1:1234"], TCPOutputStream),
(["TCP:localhost:1234"], TCPOutputStream),
(["TCP:foo:1234"], TCPOutputStream),
(["UDP:127.0.0.1:1234"], UDPOutputStream),
(["UDP:localhost:1234"], UDPOutputStream),
(["UDP:foo:1234"], UDPOutputStream),
([1234], UDPOutputStream),
(1234, UDPOutputStream),
("UDP:foo:1234", UDPOutputStream),
([], ZMQStream),
(None, ZMQStream),
(
[1234, "TCP:foo:1234"],
UDPOutputStream,
), # Technically valid but not really correct
],
)
def test_valid_output_stream_factory(self, args, expected):
full_args = [
"foo",
"bar",
args,
[PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")],
{"zmq_context": broker.context},
]
stream = output_stream_factory(*full_args)
assert isinstance(stream, expected)

# @pytest.mark.parametrize(
# "args,expected",
# [
# (["FOO", "127.0.0.1", 1234], ValueError),
# (["UDP", "127.0.0.1", "1234"], ValueError),
# (["UDP", 1, "1234"], ValueError),
# ],
# )
# def test_invalid_output_stream_factory(self, args, expected):
# full_args = [
# "foo",
# "bar",
# args,
# [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")],
# {"zmq_context": broker.context},
# ]
# with pytest.raises(expected):
# _ = output_stream_factory(*full_args)
@pytest.mark.parametrize(
"args,expected",
[
(["FOO:127.0.0.1:1234"], ValueError),
(["UDP", "127.0.0.1", "1234"], ValueError),
(["FOO"], ValueError),
],
)
def test_invalid_output_stream_factory(self, args, expected):
full_args = [
"foo",
"bar",
args,
[PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")],
{"zmq_context": broker.context},
]
with pytest.raises(expected):
_ = output_stream_factory(*full_args)

0 comments on commit 2cb87f5

Please sign in to comment.