diff --git a/pyproject.toml b/pyproject.toml index e948e9384..422a39c60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,10 +15,12 @@ packages = [ python = ">=3.6.2,<4.0" black = { version = ">=19.3b0", optional = true } dataclasses = { version = "^0.7", python = ">=3.6, <3.7" } +grpcio = { version = "^1.43.0", optional = true } grpclib = "^0.4.1" jinja2 = { version = "^2.11.2", optional = true } python-dateutil = "^2.8" + [tool.poetry.dev-dependencies] asv = "^0.4.2" black = "^21.11b0" @@ -43,7 +45,7 @@ protoc-gen-python_betterproto = "betterproto.plugin:main" [tool.poetry.extras] compiler = ["black", "jinja2"] - +grpcio = ["grpcio"] # Dev workflow tasks diff --git a/src/betterproto/grpc/grpcio_server.py b/src/betterproto/grpc/grpcio_server.py new file mode 100644 index 000000000..70b55289d --- /dev/null +++ b/src/betterproto/grpc/grpcio_server.py @@ -0,0 +1,34 @@ +from typing import Dict, TYPE_CHECKING +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + import grpc + + +class ServicerBase(ABC): + """ + Base class for async grpcio servers. + """ + + @property + @abstractmethod + def __rpc_methods__(self) -> Dict[str, "grpc.RpcMethodHandler"]: + ... + + @property + @abstractmethod + def __proto_path__(self) -> str: + ... + + +def register_servicers(server: "grpc.aio.Server", *servicers: ServicerBase): + from grpc import method_handlers_generic_handler + + server.add_generic_rpc_handlers( + tuple( + method_handlers_generic_handler( + servicer.__proto_path__, servicer.__rpc_handlers__ + ) + for servicer in servicers + ) + ) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 840140043..21d0ad89d 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -203,9 +203,15 @@ def comment(self) -> str: @dataclass -class PluginRequestCompiler: +class Options: + grpc_kind: str = "grpclib" + include_google: bool = False + +@dataclass +class PluginRequestCompiler: plugin_request_obj: CodeGeneratorRequest + options: Options output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) @property @@ -657,6 +663,10 @@ def __post_init__(self) -> None: def proto_name(self) -> str: return self.proto_obj.name + @property + def proto_path(self) -> str: + return self.parent.package + "." + self.proto_name + @property def py_name(self) -> str: return pythonize_class_name(self.proto_name) diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 21a2caf14..7e4701a39 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -21,6 +21,7 @@ FieldCompiler, MapEntryCompiler, MessageCompiler, + Options, OneOfFieldCompiler, OutputTemplate, PluginRequestCompiler, @@ -61,19 +62,29 @@ def _traverse( ) +def parse_options(plugin_options: List[str]) -> Options: + options = Options() + for option in plugin_options: + if option.startswith("grpc="): + options.grpc_kind = option.split("=", 1)[1] + if option == "INCLUDE_GOOGLE": + options.include_google = True + return options + def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: response = CodeGeneratorResponse() plugin_options = request.parameter.split(",") if request.parameter else [] response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL - request_data = PluginRequestCompiler(plugin_request_obj=request) + options = parse_options(plugin_options) + + request_data = PluginRequestCompiler( + plugin_request_obj=request, options=options + ) # Gather output packages for proto_file in request.proto_file: - if ( - proto_file.package == "google.protobuf" - and "INCLUDE_GOOGLE" not in plugin_options - ): + if proto_file.package == "google.protobuf" and options.include_google: # If not INCLUDE_GOOGLE, # skip re-compiling Google's well-known types continue diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index d27cff610..26f35d76a 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -15,10 +15,16 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no {% endif %} import betterproto +{% if output_file.parent_request.options.grpc_kind == "grpclib" %} from betterproto.grpc.grpclib_server import ServiceBase -{% if output_file.services %} +{% endif %} +{% if output_file.services and output_file.parent_request.options.grpc_kind == "grpclib" %} import grpclib {% endif %} +{% if output_file.services and output_file.parent_request.options.grpc_kind == "grpcio" %} +import grpc +from betterproto.grpc.grpcio_server import ServicerBase +{% endif %} {% if output_file.enums %}{% for enum in output_file.enums %} @@ -68,6 +74,9 @@ class {{ message.py_name }}(betterproto.Message): {% endfor %} + +{% if output_file.parent_request.options.grpc_kind == "grpclib" %} + {% for service in output_file.services %} class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} @@ -239,6 +248,61 @@ class {{ service.py_name }}Base(ServiceBase): } {% endfor %} +{% endif %} + +{% if output_file.parent_request.options.grpc_kind == "grpcio" %} +{% for service in output_file.services %} +class {{ service.py_name }}Base(ServicerBase): + {% if service.comment %} +{{ service.comment }} + + {% endif %} + + {% for method in service.methods %} + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + , request: "{{ method.py_input_message_type }}" + {%- else -%} + {# Client streaming: need a request iterator instead #} + , request_iterator: AsyncIterator["{{ method.py_input_message_type }}"] + {%- endif -%} + , context: grpc.aio.ServicerContext + ) -> {% if method.server_streaming %}AsyncGenerator["{{ method.py_output_message_type }}", None]{% else %}"{{ method.py_output_message_type }}"{% endif %}: + {% if method.comment %} +{{ method.comment }} + + {% endif %} + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + {% endfor %} + + __proto_path__ = "{{ service.proto_path }}" + + @property + def __rpc_methods__(self): + return { + {% for method in service.methods %} + "{{ method.proto_name }}": + {% if not method.client_streaming and not method.server_streaming %} + grpc.unary_unary_rpc_method_handler( + {% elif method.client_streaming and method.server_streaming %} + grpc.stream_stream_rpc_method_handler( + {% elif method.client_streaming %} + grpc.stream_unary_rpc_method_handler( + {% else %} + grpc.unary_stream_rpc_method_handler( + {% endif %} + self.{{ method.py_name }}, + request_deserializer={{ method.py_input_message_type }}.FromString, + response_serializer={{ method.py_input_message_type }}.SerializeToString, + ), + {% endfor %} + } + +{% endfor %} +{% endif %} {% for i in output_file.imports|sort %} {{ i }}