Skip to content

Commit

Permalink
Feat: Implemeneted updates to serialize for async
Browse files Browse the repository at this point in the history
  • Loading branch information
felixnext committed May 25, 2024
1 parent 6a39141 commit de21584
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
24 changes: 24 additions & 0 deletions functown/serialization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Copyright (c) 2023, Felix Geilert
"""

import asyncio
from abc import abstractmethod
from typing import Any, Dict, Tuple, Union, Optional

Expand Down Expand Up @@ -50,7 +51,30 @@ def serialize(
"""
raise NotImplementedError

async def __async_wrapper(self, func, *args, **kwargs):
# execute inner function
res = await func(*args, **kwargs)

# get request object
req = self._get("req", 0, *args, **kwargs)
if "req" in kwargs:
del kwargs["req"]
else:
args = args[1:]

# serialize result
data, mtype = self.serialize(req, res, *args, **kwargs)
if isinstance(mtype, ContentTypes):
mtype = mtype.value
return HttpResponse(
data, status_code=self._code, headers=self._headers, mimetype=mtype
)

def run(self, func, *args, **kwargs):
# check if async
if asyncio.iscoroutinefunction(func):
return self.__async_wrapper(func, *args, **kwargs)

# execute inner function
res = func(*args, **kwargs)

Expand Down
3 changes: 1 addition & 2 deletions functown/serialization/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(
def serialize(
self, req: HttpRequest, res: tp.Any, *args, **kwargs
) -> tuple[tp.Union[bytes, str], str]:
# FEAT: integrate async support [LIN:MED-568]
# check for request header
mime_raw = RequestArgHandler(req).get_header(
HeaderEnum.CONTENT_TYPE, required=False
Expand Down Expand Up @@ -114,7 +113,7 @@ def serialize(

# execute correct response
if use_json is True:
# DEBT: sub-objects are skipped instead of null values [LIN:MED-391]
# DEBT: sub-objects are skipped instead of null values
# dict = json_format.MessageToDict(
# res,
# # NOTE: this is helpful for gql parsing
Expand Down
24 changes: 24 additions & 0 deletions tests/serialization/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,27 @@ def main(req: HttpRequest) -> pb2.InformationList:
body_item = pb2.InformationList()
body_item.ParseFromString(res.get_body())
assert body_item == item


@pytest.mark.asyncio
async def test_hybridproto_response_async(json_data):
"""Tests protobuf response serialization."""
# generate
item = pb2.InformationList()
info = item.infos.add()
info.msg = "Hello World"
info.id = 1
info.score = 0.5
for i in range(3):
d = info.data.add()
d.msg = f"Hello World {i}"
d.type = pb2.Information.Importance.HIGH

@HybridProtoResponse
async def main(req: HttpRequest) -> pb2.InformationList:
return item

res = await main(req=HttpRequest("GET", "http://localhost", body=None))
assert isinstance(res, HttpResponse)
assert res.mimetype == "application/octet-stream"
assert res.get_body() == item.SerializeToString()

0 comments on commit de21584

Please sign in to comment.