From 801be049a532f43ece71cd5f1c427cae1b47bcbe Mon Sep 17 00:00:00 2001 From: Leonard Gerard Date: Tue, 19 Nov 2024 01:28:54 +0100 Subject: [PATCH 1/3] sync wip --- src/betterproto/plugin/typing_compiler.py | 8 +++ src/betterproto/templates/header.py.j2 | 1 + src/betterproto/templates/template.py.j2 | 70 +++++++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/src/betterproto/plugin/typing_compiler.py b/src/betterproto/plugin/typing_compiler.py index eca3691f9..3267410de 100644 --- a/src/betterproto/plugin/typing_compiler.py +++ b/src/betterproto/plugin/typing_compiler.py @@ -92,6 +92,14 @@ def async_iterable(self, type: str) -> str: def async_iterator(self, type: str) -> str: self._imports["typing"].add("AsyncIterator") return f"AsyncIterator[{type}]" + + def sync_iterable(self, type: str) -> str: + self._imports["typing"].add("Iterable") + return f"Iterable[{type}]" + + def sync_iterator(self, type: str) -> str: + self._imports["typing"].add("Iterator") + return f"Iterator[{type}]" def imports(self) -> Dict[str, Optional[Set[str]]]: return {k: v if v else None for k, v in self._imports.items()} diff --git a/src/betterproto/templates/header.py.j2 b/src/betterproto/templates/header.py.j2 index b6d0a6c44..685a6f2ce 100644 --- a/src/betterproto/templates/header.py.j2 +++ b/src/betterproto/templates/header.py.j2 @@ -12,6 +12,7 @@ __all__ = ( {%- endfor -%} {%- for service in output_file.services -%} "{{ service.py_name }}Stub", + "{{ service.py_name }}SyncStub", "{{ service.py_name }}Base", {%- endfor -%} ) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 4a252aec6..28f4b69a9 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -63,6 +63,76 @@ class {{ message.py_name }}(betterproto.Message): {% endif %} {% endfor %} + + +{% for service in output_file.services %} +class {{ service.py_name }}SyncStub(): + {% if service.comment %} +{{ service.comment }} + + {% elif not service.methods %} + pass + {% endif %} + + import grpc + + def __init__(self, channel: grpc.Channel): + self._channel = channel + + {% for method in service.methods %} + def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" + {%- else -%} + {# Client streaming: need a request iterator instead #} + , {{ method.py_input_message_param }}: {{ output_file.typing_compiler.iterable(method.py_input_message_type) }} + {%- endif -%} + ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.sync_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: + {% if method.comment %} +{{ method.comment }} + + {% endif %} + {% if method.proto_obj.options.deprecated %} + warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning) + + {% endif %} + {% if method.server_streaming %} + {% if method.client_streaming %} + for response in self._channel.stream_stream( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}): + yield response + {% else %}{# i.e. not client streaming #} + for response in self._channel.unary_stream( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}): + yield response + + {% endif %}{# if client streaming #} + {% else %}{# i.e. not server streaming #} + {% if method.client_streaming %} + return self._channel.stream_unary( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}) + {% else %}{# i.e. not client streaming #} + return self._channel.unary_unary( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )({{ method.py_input_message_param }}) + {% endif %}{# client streaming #} + {% endif %} + + {% endfor %} +{% endfor %} + + {% for service in output_file.services %} class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} From a97b3f1d2f3029ebd1c4dff805c69e4a950281d0 Mon Sep 17 00:00:00 2001 From: Leonard Gerard Date: Tue, 19 Nov 2024 13:55:58 +0100 Subject: [PATCH 2/3] Existing tests passing --- src/betterproto/templates/header.py.j2 | 1 + src/betterproto/templates/template.py.j2 | 15 ++++++--------- tests/README.md | 6 +++--- tests/test_all_definition.py | 1 + 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/betterproto/templates/header.py.j2 b/src/betterproto/templates/header.py.j2 index 685a6f2ce..6bd334474 100644 --- a/src/betterproto/templates/header.py.j2 +++ b/src/betterproto/templates/header.py.j2 @@ -47,6 +47,7 @@ import betterproto {% if output_file.services %} from betterproto.grpc.grpclib_server import ServiceBase import grpclib +import grpc {% endif %} {% if output_file.imports_type_checking_only %} diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 28f4b69a9..7973711b9 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -74,8 +74,6 @@ class {{ service.py_name }}SyncStub(): pass {% endif %} - import grpc - def __init__(self, channel: grpc.Channel): self._channel = channel @@ -84,10 +82,9 @@ class {{ service.py_name }}SyncStub(): {%- if not method.client_streaming -%} , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" {%- else -%} - {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}: {{ output_file.typing_compiler.iterable(method.py_input_message_type) }} + , {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.iterable(method.py_input_message_type) }}" {%- endif -%} - ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.sync_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: + ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.sync_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}": {% if method.comment %} {{ method.comment }} @@ -102,7 +99,7 @@ class {{ service.py_name }}SyncStub(): "{{ method.route }}", {{ method.py_input_message_type }}.SerializeToString, {{ method.py_output_message_type }}.FromString, - )({{ method.py_input_message_param }}): + )({{ method.py_input_message_param }}_iterator): yield response {% else %}{# i.e. not client streaming #} for response in self._channel.unary_stream( @@ -119,7 +116,7 @@ class {{ service.py_name }}SyncStub(): "{{ method.route }}", {{ method.py_input_message_type }}.SerializeToString, {{ method.py_output_message_type }}.FromString, - )({{ method.py_input_message_param }}) + )({{ method.py_input_message_param }}_iterator) {% else %}{# i.e. not client streaming #} return self._channel.unary_unary( "{{ method.route }}", @@ -230,9 +227,9 @@ class {{ service.py_name }}Base(ServiceBase): , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }} + , {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}" {%- endif -%} - ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: + ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}{{ method.py_output_message_type }}{% endif %}": {% if method.comment %} {{ method.comment }} diff --git a/tests/README.md b/tests/README.md index f1ee609cf..9f0c5bc77 100644 --- a/tests/README.md +++ b/tests/README.md @@ -68,11 +68,11 @@ The following tests are automatically executed for all cases: ## Running the tests -- `pipenv run generate` +- `poe generate` This generates: - `betterproto/tests/output_betterproto` — *the plugin generated python classes* - `betterproto/tests/output_reference` — *reference implementation classes* -- `pipenv run test` +- `poe test` ## Intentionally Failing tests @@ -88,4 +88,4 @@ betterproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x...... - `x` — XFAIL: expected failure - `X` — XPASS: expected failure, but still passed -Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py) \ No newline at end of file +Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py) diff --git a/tests/test_all_definition.py b/tests/test_all_definition.py index 61abb5f37..b00bb0843 100644 --- a/tests/test_all_definition.py +++ b/tests/test_all_definition.py @@ -14,6 +14,7 @@ def test_all_definition(): "GetThingRequest", "GetThingResponse", "TestStub", + "TestSyncStub", "TestBase", ) assert enum.__all__ == ("Choice", "ArithmeticOperator", "Test") From a05ce0e36650810369ac754d601cf26b83c2f789 Mon Sep 17 00:00:00 2001 From: Leonard Gerard Date: Thu, 21 Nov 2024 12:07:25 +0100 Subject: [PATCH 3/3] Superficial testing --- tests/grpc/test_grpc_client.py | 84 ++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/grpc/test_grpc_client.py diff --git a/tests/grpc/test_grpc_client.py b/tests/grpc/test_grpc_client.py new file mode 100644 index 000000000..98846252a --- /dev/null +++ b/tests/grpc/test_grpc_client.py @@ -0,0 +1,84 @@ +""" Testing the sync version of the client stubs. + +This is not testing the lower level grpc calls, but rather the generated client stubs. +So instead of creating a real service and a real grpc channel, +we are going to mock the channel and simply test the client. + +If we wanted to test without mocking we would need to use all the machinery here: +https://github.com/grpc/grpc/blob/master/src/python/grpcio_tests/tests/testing/_client_test.py + +""" + +import re +from sys import version +from tests.output_betterproto.service import ( + DoThingRequest, + DoThingResponse, + GetThingRequest, + GetThingResponse, + TestSyncStub as ThingServiceClient, +) + + +class ChannelMock: + """channel.unary_unary( + "/service.Test/DoThing", + DoThingRequest.SerializeToString, + DoThingResponse.FromString, + )(do_thing_request) + the method calls the serialize, then use the deserialize and returns the response""" + + def unary_unary(self, route, request_serializer, response_deserializer): + """mock the unary_unary call""" + def _unary_unary(req): + return response_deserializer(request_serializer(req)) + return _unary_unary + + def stream_unary(self, route, request_serializer, response_deserializer): + """mock the stream_unary call""" + def _stream_unary(req): + return response_deserializer(request_serializer(next(req))) + return _stream_unary + + def stream_stream(self, route, request_serializer, response_deserializer): + """mock the stream_stream call""" + def _stream_stream(req): + return (response_deserializer(request_serializer(r)) for r in req) + return _stream_stream + + def unary_stream(self, route, request_serializer, response_deserializer): + """mock the unary_stream call""" + def _unary_stream(req): + return iter([response_deserializer(request_serializer(req))]*6) + return _unary_stream + + +def test_do_thing_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.do_thing(DoThingRequest(name="clean room")) + assert response.names == ["clean room"] + +def test_do_many_things_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.do_many_things(iter([ + DoThingRequest(name="only"), + DoThingRequest(name="room")])) + assert response == DoThingResponse(names=["only"]) #protobuf is stunning + +def test_get_thing_versions_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.get_thing_versions(GetThingRequest(name="extra")) + response = list(response) + assert response == [GetThingResponse(name="extra")]*6 + +def test_get_different_things_call(mocker): + """mock the channel and test the client stub""" + client = ThingServiceClient(channel=ChannelMock()) + response = client.get_different_things([ + GetThingRequest(name="apple"), + GetThingRequest(name="orange")]) + response = list(response) + assert response == [GetThingResponse(name="apple"), GetThingResponse(name="orange")]