Skip to content

Commit

Permalink
add ProtoTag
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonyogo committed Feb 7, 2025
1 parent 77b68ca commit 8feca90
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 47 deletions.
2 changes: 1 addition & 1 deletion fast_grpc/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, grpc_context: grpc.ServicerContext, method, method_descriptor
self.input_type = method_descriptor.input_type._concrete_class
self.output_type = method_descriptor.output_type._concrete_class
self._start_time = time.time()
self._metadata = {}
self._metadata: dict[str, str] = {}

@property
def elapsed_time(self):
Expand Down
41 changes: 24 additions & 17 deletions fast_grpc/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Type, Sequence, Any

import grpc
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import get_args, get_origin
from jinja2 import Template
from google.protobuf.descriptor import (
Expand All @@ -17,7 +17,7 @@
)

from fast_grpc.service import Service, MethodMode
from fast_grpc.types import Empty, ProtoTag
from fast_grpc.types import Empty, ProtoTag, PYTHON_TO_PROTOBUF_TYPES
from fast_grpc.utils import protoc_compile, camel_to_snake

_base_types = {
Expand All @@ -37,6 +37,11 @@
syntax = "proto3";
package {{ proto_define.package }};
{% for depend in proto_define.dependencies %}
{%- if depend %}
import "{{ depend }}";
{%- endif %}
{%- endfor %}
{% for enum in proto_define.enums.values() %}
enum {{ enum.name }} {
{% for field in enum.fields -%}
Expand Down Expand Up @@ -139,6 +144,7 @@ class ProtoDefine(BaseModel):
services: list[ProtoService]
messages: dict[Any, ProtoStruct]
enums: dict[Any, ProtoStruct]
dependencies: set[str] = Field(default_factory=set)

def render(self, proto_template) -> str:
template = Template(proto_template)
Expand All @@ -159,6 +165,8 @@ def generate_type_name(type_: type) -> str:
"""
if not isinstance(type_, type):
raise ValueError(f"'{type_}' must be a type")
if type_ in (bytes, int, float, bool, str, datetime.datetime):
return type_.__name__.capitalize()
origin = get_origin(type_)
args = get_args(type_)
if origin is None:
Expand All @@ -171,15 +179,12 @@ def generate_type_name(type_: type) -> str:
return "".join(type_names + [origin.__name__])
if issubclass(type_, IntEnum):
return type_.__name__
if not issubclass(type_, tuple(_base_types)):
raise ValueError(f"Unsupported type: {type_}")
return type_.__name__.capitalize()
else:
if issubclass(origin, Sequence):
return f"{generate_type_name(args[0])}List"
if issubclass(origin, dict):
return f"{generate_type_name(args[0])}{generate_type_name(args[1])}Dict"
raise ValueError(f"Unsupported type: {type_}")
raise ValueError(f"Unsupported type: {type_}")


class ProtoBuilder:
Expand Down Expand Up @@ -211,8 +216,8 @@ def convert_message(self, schema: Type[BaseModel]) -> ProtoStruct:
if schema in self._proto_define.messages:
return self._proto_define.messages[schema]
message = ProtoStruct(name=generate_type_name(schema), fields=[])
for i, (name, field) in enumerate(schema.model_fields.items(), 1):
type_name = self._get_type_name(field.annotation, field.metadata)
for i, name in enumerate(schema.model_fields.keys(), 1):
type_name = self._get_type_name(schema.__annotations__[name])
message.fields.append(ProtoField(name=name, type=type_name, index=i))
self._proto_define.messages[schema] = message
return message
Expand All @@ -236,14 +241,19 @@ def convert_enum(self, schema: Type[IntEnum]):
self._proto_define.enums[schema] = enum_struct
return enum_struct

def _get_type_name(self, type_: Any, metadata: Sequence[Any] = tuple()) -> str:
def _get_type_name(self, type_: Any) -> str:
origin = get_origin(type_)
args = get_args(type_)
for tag in metadata:
if isinstance(tag, ProtoTag):
return tag.name
if type_ in PYTHON_TO_PROTOBUF_TYPES:
tag = PYTHON_TO_PROTOBUF_TYPES[type_]
self._proto_define.dependencies.add(tag.package)
return tag.name
if origin is typing.Annotated:
return self._get_type_name(args[0], args[1:])
for tag in args[1:]:
if isinstance(tag, ProtoTag):
self._proto_define.dependencies.add(tag.package)
return tag.name
return self._get_type_name(args[0])
if origin is typing.Union:
_args = [i for i in args if i is not type(None)]
return self._get_type_name(_args[0])
Expand All @@ -254,15 +264,12 @@ def _get_type_name(self, type_: Any, metadata: Sequence[Any] = tuple()) -> str:
if issubclass(type_, IntEnum):
struct = self.convert_enum(type_)
return struct.name
if not issubclass(type_, tuple(_base_types)):
raise ValueError(f"Unsupported type: {type_}")
return _base_types[type_]
else:
if issubclass(origin, Sequence):
return f"repeated {self._get_type_name(args[0])}"
if issubclass(origin, dict):
return f"map <{self._get_type_name(args[0])}, {self._get_type_name(args[1])}>"
raise ValueError(f"Unsupported type: {type_}")
raise ValueError(f"Unsupported type: {type_}")


class ClientBuilder:
Expand Down
52 changes: 23 additions & 29 deletions fast_grpc/types.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,35 @@
# -*- coding: utf-8 -*-
from typing import (
Awaitable,
Callable,
TypeVar,
Sequence,
Tuple,
Union,
AsyncIterable,
Annotated,
)
from datetime import datetime
from typing import Annotated

from pydantic import BaseModel, Field
from typing_extensions import TypeAlias

from fast_grpc.context import ServiceContext

Request = TypeVar("Request")
Response = TypeVar("Response")
class ProtoTag:
__slots__ = ("name", "package")

Method = Callable[
[Request, ServiceContext], Union[AsyncIterable[Response], Awaitable[Response]]
]
MetadataType = Sequence[Tuple[str, Union[str, bytes]]]
def __init__(self, name: str, package: str = ""):
self.name = name
self.package = package


class Empty(BaseModel):
pass
# Python -> Protobuf
PYTHON_TO_PROTOBUF_TYPES = {
bytes: ProtoTag("bytes"),
int: ProtoTag("int32"),
float: ProtoTag("float"),
bool: ProtoTag("bool"),
str: ProtoTag("string"),
datetime: ProtoTag("string"),
}


class ProtoTag:
def __init__(self, name: str, package: str = ""):
self.name = name
self.package = package
Uint32 = Annotated[int, Field(ge=0, lt=2**32), ProtoTag(name="uint32")]
Uint64 = Annotated[int, Field(ge=0, lt=2**64), ProtoTag(name="uint64")]
Int32 = Annotated[int, Field(ge=-(2**31), lt=2**31), ProtoTag(name="int32")]
Int64 = Annotated[int, Field(ge=-(2**63), lt=2**63), ProtoTag(name="int64")]
Double = Annotated[float, ProtoTag(name="double")]


Uint32: TypeAlias = Annotated[int, Field(ge=0, lt=2**32), ProtoTag(name="uint32")]
Uint64: TypeAlias = Annotated[int, Field(ge=0, lt=2**64), ProtoTag(name="uint64")]
Int32: TypeAlias = Annotated[int, Field(ge=-(2**31), lt=2**31), ProtoTag(name="int32")]
Int64: TypeAlias = Annotated[int, Field(ge=-(2**63), lt=2**63), ProtoTag(name="int64")]
Double: TypeAlias = Annotated[float, ProtoTag(name="double")]
class Empty(BaseModel):
pass

0 comments on commit 8feca90

Please sign in to comment.