Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a synchronous grpc client #647

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/betterproto/plugin/typing_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ def async_iterable(self, type: str) -> str:
def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator")
return f"AsyncIterator[{type}]"

def sync_iterable(self, type: str) -> str:
self._imports["typing"].add("Iterable")
return f"Iterable[{type}]"

def sync_iterator(self, type: str) -> str:
self._imports["typing"].add("Iterator")
return f"Iterator[{type}]"

def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}
Expand Down
2 changes: 2 additions & 0 deletions src/betterproto/templates/header.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ __all__ = (
{%- endfor -%}
{%- for service in output_file.services -%}
"{{ service.py_name }}Stub",
"{{ service.py_name }}SyncStub",
"{{ service.py_name }}Base",
{%- endfor -%}
)
Expand Down Expand Up @@ -46,6 +47,7 @@ import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
import grpc
{% endif %}

{% if output_file.imports_type_checking_only %}
Expand Down
71 changes: 69 additions & 2 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,73 @@ class {{ message.py_name }}(betterproto.Message):
{% endif %}

{% endfor %}


{% for service in output_file.services %}
class {{ service.py_name }}SyncStub():
{% if service.comment %}
{{ service.comment }}

{% elif not service.methods %}
pass
{% endif %}

def __init__(self, channel: grpc.Channel):
self._channel = channel

{% for method in service.methods %}
def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
, {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.iterable(method.py_input_message_type) }}"
{%- endif -%}
) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.sync_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
{% if method.comment %}
{{ method.comment }}

{% endif %}
{% if method.proto_obj.options.deprecated %}
warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning)

{% endif %}
{% if method.server_streaming %}
{% if method.client_streaming %}
for response in self._channel.stream_stream(
"{{ method.route }}",
{{ method.py_input_message_type }}.SerializeToString,
{{ method.py_output_message_type }}.FromString,
)({{ method.py_input_message_param }}_iterator):
yield response
{% else %}{# i.e. not client streaming #}
for response in self._channel.unary_stream(
"{{ method.route }}",
{{ method.py_input_message_type }}.SerializeToString,
{{ method.py_output_message_type }}.FromString,
)({{ method.py_input_message_param }}):
yield response

{% endif %}{# if client streaming #}
{% else %}{# i.e. not server streaming #}
{% if method.client_streaming %}
return self._channel.stream_unary(
"{{ method.route }}",
{{ method.py_input_message_type }}.SerializeToString,
{{ method.py_output_message_type }}.FromString,
)({{ method.py_input_message_param }}_iterator)
{% else %}{# i.e. not client streaming #}
return self._channel.unary_unary(
"{{ method.route }}",
{{ method.py_input_message_type }}.SerializeToString,
{{ method.py_output_message_type }}.FromString,
)({{ method.py_input_message_param }})
{% endif %}{# client streaming #}
{% endif %}

{% endfor %}
{% endfor %}


{% for service in output_file.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
Expand Down Expand Up @@ -160,9 +227,9 @@ class {{ service.py_name }}Base(ServiceBase):
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
, {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the next diff are in fact fixes for the existing stub generation. If wanted I could split it into a separate PR.

{%- endif -%}
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
{% if method.comment %}
{{ method.comment }}

Expand Down
6 changes: 3 additions & 3 deletions tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ The following tests are automatically executed for all cases:

## Running the tests

- `pipenv run generate`
- `poe generate`
This generates:
- `betterproto/tests/output_betterproto` — *the plugin generated python classes*
- `betterproto/tests/output_reference` — *reference implementation classes*
- `pipenv run test`
- `poe test`

## Intentionally Failing tests

Expand All @@ -88,4 +88,4 @@ betterproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x......
- `x` — XFAIL: expected failure
- `X` — XPASS: expected failure, but still passed

Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py)
Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py)
84 changes: 84 additions & 0 deletions tests/grpc/test_grpc_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
""" Testing the sync version of the client stubs.

This is not testing the lower level grpc calls, but rather the generated client stubs.
So instead of creating a real service and a real grpc channel,
we are going to mock the channel and simply test the client.

If we wanted to test without mocking we would need to use all the machinery here:
https://github.com/grpc/grpc/blob/master/src/python/grpcio_tests/tests/testing/_client_test.py

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or I could spawn an actual server. I would think that's the best testing, but it is not done so in the rest of the code (which uses the nice grpclib ChannelFor)


"""

import re
from sys import version
from tests.output_betterproto.service import (
DoThingRequest,
DoThingResponse,
GetThingRequest,
GetThingResponse,
TestSyncStub as ThingServiceClient,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the new synchronous stub.

)


class ChannelMock:
"""channel.unary_unary(
"/service.Test/DoThing",
DoThingRequest.SerializeToString,
DoThingResponse.FromString,
)(do_thing_request)
the method calls the serialize, then use the deserialize and returns the response"""

def unary_unary(self, route, request_serializer, response_deserializer):
"""mock the unary_unary call"""
def _unary_unary(req):
return response_deserializer(request_serializer(req))
return _unary_unary

def stream_unary(self, route, request_serializer, response_deserializer):
"""mock the stream_unary call"""
def _stream_unary(req):
return response_deserializer(request_serializer(next(req)))
return _stream_unary

def stream_stream(self, route, request_serializer, response_deserializer):
"""mock the stream_stream call"""
def _stream_stream(req):
return (response_deserializer(request_serializer(r)) for r in req)
return _stream_stream

def unary_stream(self, route, request_serializer, response_deserializer):
"""mock the unary_stream call"""
def _unary_stream(req):
return iter([response_deserializer(request_serializer(req))]*6)
return _unary_stream


def test_do_thing_call(mocker):
"""mock the channel and test the client stub"""
client = ThingServiceClient(channel=ChannelMock())
response = client.do_thing(DoThingRequest(name="clean room"))
assert response.names == ["clean room"]

def test_do_many_things_call(mocker):
"""mock the channel and test the client stub"""
client = ThingServiceClient(channel=ChannelMock())
response = client.do_many_things(iter([
DoThingRequest(name="only"),
DoThingRequest(name="room")]))
assert response == DoThingResponse(names=["only"]) #protobuf is stunning

def test_get_thing_versions_call(mocker):
"""mock the channel and test the client stub"""
client = ThingServiceClient(channel=ChannelMock())
response = client.get_thing_versions(GetThingRequest(name="extra"))
response = list(response)
assert response == [GetThingResponse(name="extra")]*6

def test_get_different_things_call(mocker):
"""mock the channel and test the client stub"""
client = ThingServiceClient(channel=ChannelMock())
response = client.get_different_things([
GetThingRequest(name="apple"),
GetThingRequest(name="orange")])
response = list(response)
assert response == [GetThingResponse(name="apple"), GetThingResponse(name="orange")]
1 change: 1 addition & 0 deletions tests/test_all_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_all_definition():
"GetThingRequest",
"GetThingResponse",
"TestStub",
"TestSyncStub",
"TestBase",
)
assert enum.__all__ == ("Choice", "ArithmeticOperator", "Test")