Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grpcio support #328

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ packages = [
python = ">=3.6.2,<4.0"
black = { version = ">=19.3b0", optional = true }
dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
grpcio = { version = "^1.43.0", optional = true }
grpclib = "^0.4.1"
jinja2 = { version = "^2.11.2", optional = true }
python-dateutil = "^2.8"


[tool.poetry.dev-dependencies]
asv = "^0.4.2"
black = "^21.11b0"
Expand All @@ -43,7 +45,7 @@ protoc-gen-python_betterproto = "betterproto.plugin:main"

[tool.poetry.extras]
compiler = ["black", "jinja2"]

grpcio = ["grpcio"]

# Dev workflow tasks

Expand Down
34 changes: 34 additions & 0 deletions src/betterproto/grpc/grpcio_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Dict, TYPE_CHECKING
from abc import ABC, abstractmethod

if TYPE_CHECKING:
import grpc


class ServicerBase(ABC):
"""
Base class for async grpcio servers.
"""

@property
@abstractmethod
def __rpc_methods__(self) -> Dict[str, "grpc.RpcMethodHandler"]:
...

@property
@abstractmethod
def __proto_path__(self) -> str:
...


def register_servicers(server: "grpc.aio.Server", *servicers: ServicerBase):
from grpc import method_handlers_generic_handler

server.add_generic_rpc_handlers(
tuple(
method_handlers_generic_handler(
servicer.__proto_path__, servicer.__rpc_handlers__
)
for servicer in servicers
)
)
12 changes: 11 additions & 1 deletion src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,15 @@ def comment(self) -> str:


@dataclass
class PluginRequestCompiler:
class Options:
grpc_kind: str = "grpclib"
include_google: bool = False


@dataclass
class PluginRequestCompiler:
plugin_request_obj: CodeGeneratorRequest
options: Options
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)

@property
Expand Down Expand Up @@ -657,6 +663,10 @@ def __post_init__(self) -> None:
def proto_name(self) -> str:
return self.proto_obj.name

@property
def proto_path(self) -> str:
return self.parent.package + "." + self.proto_name

@property
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)
Expand Down
21 changes: 16 additions & 5 deletions src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
FieldCompiler,
MapEntryCompiler,
MessageCompiler,
Options,
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
Expand Down Expand Up @@ -61,19 +62,29 @@ def _traverse(
)


def parse_options(plugin_options: List[str]) -> Options:
options = Options()
for option in plugin_options:
if option.startswith("grpc="):
options.grpc_kind = option.split("=", 1)[1]
if option == "INCLUDE_GOOGLE":
options.include_google = True
return options

def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
response = CodeGeneratorResponse()

plugin_options = request.parameter.split(",") if request.parameter else []
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL

request_data = PluginRequestCompiler(plugin_request_obj=request)
options = parse_options(plugin_options)

request_data = PluginRequestCompiler(
plugin_request_obj=request, options=options
)
# Gather output packages
for proto_file in request.proto_file:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
if proto_file.package == "google.protobuf" and options.include_google:
# If not INCLUDE_GOOGLE,
# skip re-compiling Google's well-known types
continue
Expand Down
66 changes: 65 additions & 1 deletion src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %}

import betterproto
{% if output_file.parent_request.options.grpc_kind == "grpclib" %}
from betterproto.grpc.grpclib_server import ServiceBase
{% if output_file.services %}
{% endif %}
{% if output_file.services and output_file.parent_request.options.grpc_kind == "grpclib" %}
import grpclib
{% endif %}
{% if output_file.services and output_file.parent_request.options.grpc_kind == "grpcio" %}
import grpc
from betterproto.grpc.grpcio_server import ServicerBase
{% endif %}


{% if output_file.enums %}{% for enum in output_file.enums %}
Expand Down Expand Up @@ -68,6 +74,9 @@ class {{ message.py_name }}(betterproto.Message):


{% endfor %}

{% if output_file.parent_request.options.grpc_kind == "grpclib" %}

{% for service in output_file.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
Expand Down Expand Up @@ -239,6 +248,61 @@ class {{ service.py_name }}Base(ServiceBase):
}

{% endfor %}
{% endif %}

{% if output_file.parent_request.options.grpc_kind == "grpcio" %}
{% for service in output_file.services %}
class {{ service.py_name }}Base(ServicerBase):
{% if service.comment %}
{{ service.comment }}

{% endif %}

{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
, request: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%}
, context: grpc.aio.ServicerContext
) -> {% if method.server_streaming %}AsyncGenerator["{{ method.py_output_message_type }}", None]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}

{% endif %}
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

{% endfor %}

__proto_path__ = "{{ service.proto_path }}"

@property
def __rpc_methods__(self):
return {
{% for method in service.methods %}
"{{ method.proto_name }}":
{% if not method.client_streaming and not method.server_streaming %}
grpc.unary_unary_rpc_method_handler(
{% elif method.client_streaming and method.server_streaming %}
grpc.stream_stream_rpc_method_handler(
{% elif method.client_streaming %}
grpc.stream_unary_rpc_method_handler(
{% else %}
grpc.unary_stream_rpc_method_handler(
{% endif %}
self.{{ method.py_name }},
request_deserializer={{ method.py_input_message_type }}.FromString,
response_serializer={{ method.py_input_message_type }}.SerializeToString,
),
{% endfor %}
}

{% endfor %}
{% endif %}

{% for i in output_file.imports|sort %}
{{ i }}
Expand Down