diff --git a/ait/core/server/stream.py b/ait/core/server/stream.py index 9aeac911..492fc892 100644 --- a/ait/core/server/stream.py +++ b/ait/core/server/stream.py @@ -113,39 +113,44 @@ def output_stream_factory(name, inputs, outputs, handlers, zmq_args=None): values in 'outputs' decides on the appropriate stream to instantiate and then returns it. """ - ostream = None - if type(outputs) is not list or (type(outputs) is list and len(outputs) == 0): - ostream = ZMQStream( - name, - inputs, - handlers, - zmq_args=zmq_args, - ) - return ostream - # backwards compatability with original UDP spec - if type(outputs[0]) is int and ait.MIN_PORT <= outputs[0] <= ait.MAX_PORT: - ostream = UDPOutputStream(name, inputs, outputs[0], handlers, zmq_args=zmq_args) - elif is_valid_address_spec(outputs[0]): - protocol, hostname, port = outputs[0].split(":") - if int(port) < ait.MIN_PORT or int(port) > ait.MAX_PORT: + + parsed_output = outputs + if type(parsed_output) is list and len(parsed_output) > 0: + if len(parsed_output) > 1: + ait.core.log.warn(f"Additional output args discarded {parsed_output[1:]}") + parsed_output = parsed_output[0] + if type(parsed_output) is int: + if ait.MIN_PORT <= parsed_output <= ait.MAX_PORT: + return UDPOutputStream( + name, inputs, parsed_output, handlers, zmq_args=zmq_args + ) + else: raise ValueError(f"Output stream specification invalid: {outputs}") - if protocol.lower() == "udp": - ostream = UDPOutputStream( - name, inputs, outputs[0], handlers, zmq_args=zmq_args + + elif type(parsed_output) is str and is_valid_address_spec(parsed_output): + protocol, hostname, port = parsed_output.split(":") + if protocol.lower() == "udp" and ait.MIN_PORT <= int(port) <= ait.MAX_PORT: + return UDPOutputStream( + name, inputs, parsed_output, handlers, zmq_args=zmq_args ) - elif protocol.lower() == "tcp": - ostream = TCPOutputStream( - name, inputs, outputs[0], handlers, zmq_args=zmq_args + elif protocol.lower() == "tcp" and ait.MIN_PORT <= int(port) <= ait.MAX_PORT: + return TCPOutputStream( + name, inputs, parsed_output, handlers, zmq_args=zmq_args ) else: raise ValueError(f"Output stream specification invalid: {outputs}") + elif parsed_output is None or ( + type(parsed_output) is list and len(parsed_output) == 0 + ): + return ZMQStream( + name, + inputs, + handlers, + zmq_args=zmq_args, + ) else: raise ValueError(f"Output stream specification invalid: {outputs}") - if ostream is None: - raise ValueError(f"Output stream specification invalid: {outputs}") - return ostream - def input_stream_factory(name, inputs, handlers, zmq_args=None): """ @@ -156,21 +161,30 @@ def input_stream_factory(name, inputs, handlers, zmq_args=None): """ stream = None - - if type(inputs) is not list or (type(inputs) is list and len(inputs) == 0): - raise ValueError(f"Input stream specification invalid: {inputs}") + parsed_inputs = inputs + if type(parsed_inputs) is int: + parsed_inputs = [parsed_inputs] + if type(parsed_inputs) is str: + parsed_inputs = [parsed_inputs] + + if type(parsed_inputs) is not list or ( + type(parsed_inputs) is list and len(parsed_inputs) == 0 + ): + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") # backwards compatability with original UDP server spec if ( - type(inputs) is list - and type(inputs[0]) is int - and ait.MIN_PORT <= inputs[0] <= ait.MAX_PORT + type(parsed_inputs) is list + and type(parsed_inputs[0]) is int + and ait.MIN_PORT <= parsed_inputs[0] <= ait.MAX_PORT ): - stream = UDPInputServerStream(name, inputs[0], handlers, zmq_args=zmq_args) - elif is_valid_address_spec(inputs[0]): - protocol, hostname, port = inputs[0].split(":") + stream = UDPInputServerStream( + name, parsed_inputs[0], handlers, zmq_args=zmq_args + ) + elif is_valid_address_spec(parsed_inputs[0]): + protocol, hostname, port = parsed_inputs[0].split(":") if int(port) < ait.MIN_PORT or int(port) > ait.MAX_PORT: - raise ValueError(f"Input stream specification invalid: {inputs}") + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") if protocol.lower() == "tcp": if hostname.lower() in [ "server", @@ -178,9 +192,13 @@ def input_stream_factory(name, inputs, handlers, zmq_args=None): "127.0.0.1", "0.0.0.0", ]: - stream = TCPInputServerStream(name, inputs[0], handlers, zmq_args) + stream = TCPInputServerStream( + name, parsed_inputs[0], handlers, zmq_args + ) else: - stream = TCPInputClientStream(name, inputs[0], handlers, zmq_args) + stream = TCPInputClientStream( + name, parsed_inputs[0], handlers, zmq_args + ) else: if hostname.lower() in [ "server", @@ -189,17 +207,17 @@ def input_stream_factory(name, inputs, handlers, zmq_args=None): "0.0.0.0", ]: stream = UDPInputServerStream( - name, inputs[0], handlers, zmq_args=zmq_args + name, parsed_inputs[0], handlers, zmq_args=zmq_args ) else: - raise ValueError(f"Input stream specification invalid: {inputs}") - elif all(isinstance(item, str) for item in inputs): - stream = ZMQStream(name, inputs, handlers, zmq_args=zmq_args) + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") + elif all(isinstance(item, str) for item in parsed_inputs): + stream = ZMQStream(name, parsed_inputs, handlers, zmq_args=zmq_args) else: - raise ValueError(f"Input stream specification invalid: {inputs}") + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") if stream is None: - raise ValueError(f"Input stream specification invalid: {inputs}") + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") return stream diff --git a/tests/ait/core/server/test_stream.py b/tests/ait/core/server/test_stream.py index 54f77dc5..aa2e1df2 100644 --- a/tests/ait/core/server/test_stream.py +++ b/tests/ait/core/server/test_stream.py @@ -176,12 +176,18 @@ def test_stream_creation_invalid_workflow(self, stream, args): (["TCP:localhost:1234"], TCPInputServerStream), (["TCP:foo:1234"], TCPInputClientStream), ([1234], UDPInputServerStream), + (1234, UDPInputServerStream), (["UDP:server:1234"], UDPInputServerStream), (["UDP:localhost:1234"], UDPInputServerStream), (["UDP:0.0.0.0:1234"], UDPInputServerStream), (["UDP:127.0.0.1:1234"], UDPInputServerStream), + ("UDP:127.0.0.1:1234", UDPInputServerStream), (["FOO"], ZMQStream), (["FOO", "BAR"], ZMQStream), + ( + [1234, "FOO", "BAR"], + UDPInputServerStream, + ), # Technically valid but not really correct ], ) def test_valid_input_stream_factory(self, args, expected): @@ -215,47 +221,52 @@ def test_invalid_input_stream_factory(self, args, expected): with pytest.raises(expected): _ = input_stream_factory(*full_args) - # @pytest.mark.parametrize( - # "args,expected", - # [ - # (["TCP", "127.0.0.1", 1234], PortOutputStream), - # (["TCP", "localhost", 1234], PortOutputStream), - # (["TCP", "foo", 1234], PortOutputStream), - # (["UDP", "127.0.0.1", 1234], PortOutputStream), - # (["UDP", "localhost", 1234], PortOutputStream), - # (["UDP", "foo", 1234], PortOutputStream), - # ([1234], PortOutputStream), - # (1234, PortOutputStream), - # ([], ZMQStream), - # (None, ZMQStream), - # ], - # ) - # def test_valid_output_stream_factory(self, args, expected): - # full_args = [ - # "foo", - # "bar", - # args, - # [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], - # {"zmq_context": broker.context}, - # ] - # stream = output_stream_factory(*full_args) - # assert isinstance(stream, expected) + @pytest.mark.parametrize( + "args,expected", + [ + (["TCP:127.0.0.1:1234"], TCPOutputStream), + (["TCP:localhost:1234"], TCPOutputStream), + (["TCP:foo:1234"], TCPOutputStream), + (["UDP:127.0.0.1:1234"], UDPOutputStream), + (["UDP:localhost:1234"], UDPOutputStream), + (["UDP:foo:1234"], UDPOutputStream), + ([1234], UDPOutputStream), + (1234, UDPOutputStream), + ("UDP:foo:1234", UDPOutputStream), + ([], ZMQStream), + (None, ZMQStream), + ( + [1234, "TCP:foo:1234"], + UDPOutputStream, + ), # Technically valid but not really correct + ], + ) + def test_valid_output_stream_factory(self, args, expected): + full_args = [ + "foo", + "bar", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + stream = output_stream_factory(*full_args) + assert isinstance(stream, expected) - # @pytest.mark.parametrize( - # "args,expected", - # [ - # (["FOO", "127.0.0.1", 1234], ValueError), - # (["UDP", "127.0.0.1", "1234"], ValueError), - # (["UDP", 1, "1234"], ValueError), - # ], - # ) - # def test_invalid_output_stream_factory(self, args, expected): - # full_args = [ - # "foo", - # "bar", - # args, - # [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], - # {"zmq_context": broker.context}, - # ] - # with pytest.raises(expected): - # _ = output_stream_factory(*full_args) + @pytest.mark.parametrize( + "args,expected", + [ + (["FOO:127.0.0.1:1234"], ValueError), + (["UDP", "127.0.0.1", "1234"], ValueError), + (["FOO"], ValueError), + ], + ) + def test_invalid_output_stream_factory(self, args, expected): + full_args = [ + "foo", + "bar", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + with pytest.raises(expected): + _ = output_stream_factory(*full_args)