From de21584058f032ca4310a75361c61ae847e51875 Mon Sep 17 00:00:00 2001 From: Felix Geilert Date: Sat, 25 May 2024 08:58:27 +0200 Subject: [PATCH] Feat: Implemeneted updates to serialize for async --- functown/serialization/base.py | 24 ++++++++++++++++++++++++ functown/serialization/hybrid.py | 3 +-- tests/serialization/test_hybrid.py | 24 ++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/functown/serialization/base.py b/functown/serialization/base.py index 5324af9..468589b 100644 --- a/functown/serialization/base.py +++ b/functown/serialization/base.py @@ -3,6 +3,7 @@ Copyright (c) 2023, Felix Geilert """ +import asyncio from abc import abstractmethod from typing import Any, Dict, Tuple, Union, Optional @@ -50,7 +51,30 @@ def serialize( """ raise NotImplementedError + async def __async_wrapper(self, func, *args, **kwargs): + # execute inner function + res = await func(*args, **kwargs) + + # get request object + req = self._get("req", 0, *args, **kwargs) + if "req" in kwargs: + del kwargs["req"] + else: + args = args[1:] + + # serialize result + data, mtype = self.serialize(req, res, *args, **kwargs) + if isinstance(mtype, ContentTypes): + mtype = mtype.value + return HttpResponse( + data, status_code=self._code, headers=self._headers, mimetype=mtype + ) + def run(self, func, *args, **kwargs): + # check if async + if asyncio.iscoroutinefunction(func): + return self.__async_wrapper(func, *args, **kwargs) + # execute inner function res = func(*args, **kwargs) diff --git a/functown/serialization/hybrid.py b/functown/serialization/hybrid.py index d7aec43..d8097e9 100644 --- a/functown/serialization/hybrid.py +++ b/functown/serialization/hybrid.py @@ -66,7 +66,6 @@ def __init__( def serialize( self, req: HttpRequest, res: tp.Any, *args, **kwargs ) -> tuple[tp.Union[bytes, str], str]: - # FEAT: integrate async support [LIN:MED-568] # check for request header mime_raw = RequestArgHandler(req).get_header( HeaderEnum.CONTENT_TYPE, required=False @@ -114,7 +113,7 @@ def serialize( # execute correct response if use_json is True: - # DEBT: sub-objects are skipped instead of null values [LIN:MED-391] + # DEBT: sub-objects are skipped instead of null values # dict = json_format.MessageToDict( # res, # # NOTE: this is helpful for gql parsing diff --git a/tests/serialization/test_hybrid.py b/tests/serialization/test_hybrid.py index 711365d..c714509 100644 --- a/tests/serialization/test_hybrid.py +++ b/tests/serialization/test_hybrid.py @@ -84,3 +84,27 @@ def main(req: HttpRequest) -> pb2.InformationList: body_item = pb2.InformationList() body_item.ParseFromString(res.get_body()) assert body_item == item + + +@pytest.mark.asyncio +async def test_hybridproto_response_async(json_data): + """Tests protobuf response serialization.""" + # generate + item = pb2.InformationList() + info = item.infos.add() + info.msg = "Hello World" + info.id = 1 + info.score = 0.5 + for i in range(3): + d = info.data.add() + d.msg = f"Hello World {i}" + d.type = pb2.Information.Importance.HIGH + + @HybridProtoResponse + async def main(req: HttpRequest) -> pb2.InformationList: + return item + + res = await main(req=HttpRequest("GET", "http://localhost", body=None)) + assert isinstance(res, HttpResponse) + assert res.mimetype == "application/octet-stream" + assert res.get_body() == item.SerializeToString()