diff --git a/src/betterproto/plugin/typing_compiler.py b/src/betterproto/plugin/typing_compiler.py index eca3691f9..3267410de 100644 --- a/src/betterproto/plugin/typing_compiler.py +++ b/src/betterproto/plugin/typing_compiler.py @@ -92,6 +92,14 @@ def async_iterable(self, type: str) -> str: def async_iterator(self, type: str) -> str: self._imports["typing"].add("AsyncIterator") return f"AsyncIterator[{type}]" + + def sync_iterable(self, type: str) -> str: + self._imports["typing"].add("Iterable") + return f"Iterable[{type}]" + + def sync_iterator(self, type: str) -> str: + self._imports["typing"].add("Iterator") + return f"Iterator[{type}]" def imports(self) -> Dict[str, Optional[Set[str]]]: return {k: v if v else None for k, v in self._imports.items()} diff --git a/src/betterproto/templates/header.py.j2 b/src/betterproto/templates/header.py.j2 index b6d0a6c44..6bd334474 100644 --- a/src/betterproto/templates/header.py.j2 +++ b/src/betterproto/templates/header.py.j2 @@ -12,6 +12,7 @@ __all__ = ( {%- endfor -%} {%- for service in output_file.services -%} "{{ service.py_name }}Stub", + "{{ service.py_name }}SyncStub", "{{ service.py_name }}Base", {%- endfor -%} ) @@ -46,6 +47,7 @@ import betterproto {% if output_file.services %} from betterproto.grpc.grpclib_server import ServiceBase import grpclib +import grpc {% endif %} {% if output_file.imports_type_checking_only %} diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 4a252aec6..7973711b9 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -63,6 +63,73 @@ class {{ message.py_name }}(betterproto.Message): {% endif %} {% endfor %} + + +{% for service in output_file.services %} +class {{ service.py_name }}SyncStub(): + {% if service.comment %} +{{ service.comment }} + + {% elif not service.methods %} + pass + {% endif %} + + def __init__(self, channel: grpc.Channel): + self._channel = channel + + {% for method in service.methods %} + def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" + {%- else -%} + , {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.iterable(method.py_input_message_type) }}" + {%- endif -%} + ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.sync_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}": + {% if method.comment %} +{{ method.comment }} + + {% endif %} + {% if method.proto_obj.options.deprecated %} + warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning) + + {% endif %} + {% if method.server_streaming %} + {% if method.client_streaming %} + for response in self._channel.stream_stream( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}_iterator): + yield response + {% else %}{# i.e. not client streaming #} + for response in self._channel.unary_stream( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}): + yield response + + {% endif %}{# if client streaming #} + {% else %}{# i.e. not server streaming #} + {% if method.client_streaming %} + return self._channel.stream_unary( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}_iterator) + {% else %}{# i.e. not client streaming #} + return self._channel.unary_unary( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}) + {% endif %}{# client streaming #} + {% endif %} + + {% endfor %} +{% endfor %} + + {% for service in output_file.services %} class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} @@ -160,9 +227,9 @@ class {{ service.py_name }}Base(ServiceBase): , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }} + , {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}" {%- endif -%} - ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: + ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}{{ method.py_output_message_type }}{% endif %}": {% if method.comment %} {{ method.comment }} diff --git a/tests/README.md b/tests/README.md index f1ee609cf..9f0c5bc77 100644 --- a/tests/README.md +++ b/tests/README.md @@ -68,11 +68,11 @@ The following tests are automatically executed for all cases: ## Running the tests -- `pipenv run generate` +- `poe generate` This generates: - `betterproto/tests/output_betterproto` — *the plugin generated python classes* - `betterproto/tests/output_reference` — *reference implementation classes* -- `pipenv run test` +- `poe test` ## Intentionally Failing tests @@ -88,4 +88,4 @@ betterproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x...... - `x` — XFAIL: expected failure - `X` — XPASS: expected failure, but still passed -Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py) \ No newline at end of file +Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py) diff --git a/tests/grpc/test_grpc_client.py b/tests/grpc/test_grpc_client.py new file mode 100644 index 000000000..98846252a --- /dev/null +++ b/tests/grpc/test_grpc_client.py @@ -0,0 +1,84 @@ +""" Testing the sync version of the client stubs. + +This is not testing the lower level grpc calls, but rather the generated client stubs. +So instead of creating a real service and a real grpc channel, +we are going to mock the channel and simply test the client. + +If we wanted to test without mocking we would need to use all the machinery here: +https://github.com/grpc/grpc/blob/master/src/python/grpcio_tests/tests/testing/_client_test.py + +""" + +import re +from sys import version +from tests.output_betterproto.service import ( + DoThingRequest, + DoThingResponse, + GetThingRequest, + GetThingResponse, + TestSyncStub as ThingServiceClient, +) + + +class ChannelMock: + """channel.unary_unary( + "/service.Test/DoThing", + DoThingRequest.SerializeToString, + DoThingResponse.FromString, + )(do_thing_request) + the method calls the serialize, then use the deserialize and returns the response""" + + def unary_unary(self, route, request_serializer, response_deserializer): + """mock the unary_unary call""" + def _unary_unary(req): + return response_deserializer(request_serializer(req)) + return _unary_unary + + def stream_unary(self, route, request_serializer, response_deserializer): + """mock the stream_unary call""" + def _stream_unary(req): + return response_deserializer(request_serializer(next(req))) + return _stream_unary + + def stream_stream(self, route, request_serializer, response_deserializer): + """mock the stream_stream call""" + def _stream_stream(req): + return (response_deserializer(request_serializer(r)) for r in req) + return _stream_stream + + def unary_stream(self, route, request_serializer, response_deserializer): + """mock the unary_stream call""" + def _unary_stream(req): + return iter([response_deserializer(request_serializer(req))]*6) + return _unary_stream + + +def test_do_thing_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.do_thing(DoThingRequest(name="clean room")) + assert response.names == ["clean room"] + +def test_do_many_things_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.do_many_things(iter([ + DoThingRequest(name="only"), + DoThingRequest(name="room")])) + assert response == DoThingResponse(names=["only"]) #protobuf is stunning + +def test_get_thing_versions_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.get_thing_versions(GetThingRequest(name="extra")) + response = list(response) + assert response == [GetThingResponse(name="extra")]*6 + +def test_get_different_things_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.get_different_things([ + GetThingRequest(name="apple"), + GetThingRequest(name="orange")]) + response = list(response) + assert response == [GetThingResponse(name="apple"), GetThingResponse(name="orange")] diff --git a/tests/test_all_definition.py b/tests/test_all_definition.py index 61abb5f37..b00bb0843 100644 --- a/tests/test_all_definition.py +++ b/tests/test_all_definition.py @@ -14,6 +14,7 @@ def test_all_definition(): "GetThingRequest", "GetThingResponse", "TestStub", + "TestSyncStub", "TestBase", ) assert enum.__all__ == ("Choice", "ArithmeticOperator", "Test")