diff --git a/fast_grpc/context.py b/fast_grpc/context.py index d6be0d2..8f20851 100644 --- a/fast_grpc/context.py +++ b/fast_grpc/context.py @@ -12,7 +12,7 @@ def __init__(self, grpc_context: grpc.ServicerContext, method, method_descriptor self.input_type = method_descriptor.input_type._concrete_class self.output_type = method_descriptor.output_type._concrete_class self._start_time = time.time() - self._metadata = {} + self._metadata: dict[str, str] = {} @property def elapsed_time(self): diff --git a/fast_grpc/proto.py b/fast_grpc/proto.py index f740e37..9d24b93 100644 --- a/fast_grpc/proto.py +++ b/fast_grpc/proto.py @@ -6,7 +6,7 @@ from typing import Type, Sequence, Any import grpc -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import get_args, get_origin from jinja2 import Template from google.protobuf.descriptor import ( @@ -17,7 +17,7 @@ ) from fast_grpc.service import Service, MethodMode -from fast_grpc.types import Empty, ProtoTag +from fast_grpc.types import Empty, ProtoTag, PYTHON_TO_PROTOBUF_TYPES from fast_grpc.utils import protoc_compile, camel_to_snake _base_types = { @@ -37,6 +37,11 @@ syntax = "proto3"; package {{ proto_define.package }}; +{% for depend in proto_define.dependencies %} +{%- if depend %} +import "{{ depend }}"; +{%- endif %} +{%- endfor %} {% for enum in proto_define.enums.values() %} enum {{ enum.name }} { {% for field in enum.fields -%} @@ -139,6 +144,7 @@ class ProtoDefine(BaseModel): services: list[ProtoService] messages: dict[Any, ProtoStruct] enums: dict[Any, ProtoStruct] + dependencies: set[str] = Field(default_factory=set) def render(self, proto_template) -> str: template = Template(proto_template) @@ -159,6 +165,8 @@ def generate_type_name(type_: type) -> str: """ if not isinstance(type_, type): raise ValueError(f"'{type_}' must be a type") + if type_ in (bytes, int, float, bool, str, datetime.datetime): + return type_.__name__.capitalize() origin = get_origin(type_) args = get_args(type_) if origin is None: @@ -171,15 +179,12 @@ def generate_type_name(type_: type) -> str: return "".join(type_names + [origin.__name__]) if issubclass(type_, IntEnum): return type_.__name__ - if not issubclass(type_, tuple(_base_types)): - raise ValueError(f"Unsupported type: {type_}") - return type_.__name__.capitalize() else: if issubclass(origin, Sequence): return f"{generate_type_name(args[0])}List" if issubclass(origin, dict): return f"{generate_type_name(args[0])}{generate_type_name(args[1])}Dict" - raise ValueError(f"Unsupported type: {type_}") + raise ValueError(f"Unsupported type: {type_}") class ProtoBuilder: @@ -211,8 +216,8 @@ def convert_message(self, schema: Type[BaseModel]) -> ProtoStruct: if schema in self._proto_define.messages: return self._proto_define.messages[schema] message = ProtoStruct(name=generate_type_name(schema), fields=[]) - for i, (name, field) in enumerate(schema.model_fields.items(), 1): - type_name = self._get_type_name(field.annotation, field.metadata) + for i, name in enumerate(schema.model_fields.keys(), 1): + type_name = self._get_type_name(schema.__annotations__[name]) message.fields.append(ProtoField(name=name, type=type_name, index=i)) self._proto_define.messages[schema] = message return message @@ -236,14 +241,19 @@ def convert_enum(self, schema: Type[IntEnum]): self._proto_define.enums[schema] = enum_struct return enum_struct - def _get_type_name(self, type_: Any, metadata: Sequence[Any] = tuple()) -> str: + def _get_type_name(self, type_: Any) -> str: origin = get_origin(type_) args = get_args(type_) - for tag in metadata: - if isinstance(tag, ProtoTag): - return tag.name + if type_ in PYTHON_TO_PROTOBUF_TYPES: + tag = PYTHON_TO_PROTOBUF_TYPES[type_] + self._proto_define.dependencies.add(tag.package) + return tag.name if origin is typing.Annotated: - return self._get_type_name(args[0], args[1:]) + for tag in args[1:]: + if isinstance(tag, ProtoTag): + self._proto_define.dependencies.add(tag.package) + return tag.name + return self._get_type_name(args[0]) if origin is typing.Union: _args = [i for i in args if i is not type(None)] return self._get_type_name(_args[0]) @@ -254,15 +264,12 @@ def _get_type_name(self, type_: Any, metadata: Sequence[Any] = tuple()) -> str: if issubclass(type_, IntEnum): struct = self.convert_enum(type_) return struct.name - if not issubclass(type_, tuple(_base_types)): - raise ValueError(f"Unsupported type: {type_}") - return _base_types[type_] else: if issubclass(origin, Sequence): return f"repeated {self._get_type_name(args[0])}" if issubclass(origin, dict): return f"map <{self._get_type_name(args[0])}, {self._get_type_name(args[1])}>" - raise ValueError(f"Unsupported type: {type_}") + raise ValueError(f"Unsupported type: {type_}") class ClientBuilder: diff --git a/fast_grpc/types.py b/fast_grpc/types.py index 7490739..ac46caa 100644 --- a/fast_grpc/types.py +++ b/fast_grpc/types.py @@ -1,41 +1,35 @@ # -*- coding: utf-8 -*- -from typing import ( - Awaitable, - Callable, - TypeVar, - Sequence, - Tuple, - Union, - AsyncIterable, - Annotated, -) +from datetime import datetime +from typing import Annotated from pydantic import BaseModel, Field -from typing_extensions import TypeAlias -from fast_grpc.context import ServiceContext -Request = TypeVar("Request") -Response = TypeVar("Response") +class ProtoTag: + __slots__ = ("name", "package") -Method = Callable[ - [Request, ServiceContext], Union[AsyncIterable[Response], Awaitable[Response]] -] -MetadataType = Sequence[Tuple[str, Union[str, bytes]]] + def __init__(self, name: str, package: str = ""): + self.name = name + self.package = package -class Empty(BaseModel): - pass +# Python -> Protobuf +PYTHON_TO_PROTOBUF_TYPES = { + bytes: ProtoTag("bytes"), + int: ProtoTag("int32"), + float: ProtoTag("float"), + bool: ProtoTag("bool"), + str: ProtoTag("string"), + datetime: ProtoTag("string"), +} -class ProtoTag: - def __init__(self, name: str, package: str = ""): - self.name = name - self.package = package +Uint32 = Annotated[int, Field(ge=0, lt=2**32), ProtoTag(name="uint32")] +Uint64 = Annotated[int, Field(ge=0, lt=2**64), ProtoTag(name="uint64")] +Int32 = Annotated[int, Field(ge=-(2**31), lt=2**31), ProtoTag(name="int32")] +Int64 = Annotated[int, Field(ge=-(2**63), lt=2**63), ProtoTag(name="int64")] +Double = Annotated[float, ProtoTag(name="double")] -Uint32: TypeAlias = Annotated[int, Field(ge=0, lt=2**32), ProtoTag(name="uint32")] -Uint64: TypeAlias = Annotated[int, Field(ge=0, lt=2**64), ProtoTag(name="uint64")] -Int32: TypeAlias = Annotated[int, Field(ge=-(2**31), lt=2**31), ProtoTag(name="int32")] -Int64: TypeAlias = Annotated[int, Field(ge=-(2**63), lt=2**63), ProtoTag(name="int64")] -Double: TypeAlias = Annotated[float, ProtoTag(name="double")] +class Empty(BaseModel): + pass