From a1d0ee10bb7570286cabd25ee89e525f5333b7d1 Mon Sep 17 00:00:00 2001 From: DerTiedemann Date: Wed, 13 Mar 2024 13:35:53 +0100 Subject: [PATCH 1/4] fix: byte encoder for kafka messages refactored common logic into new codec class --- mlserver/codecs/json.py | 61 ++++++++++++++++++++++++++++++++++++++ mlserver/kafka/message.py | 29 +++--------------- mlserver/rest/responses.py | 56 ++++------------------------------ 3 files changed, 70 insertions(+), 76 deletions(-) create mode 100644 mlserver/codecs/json.py diff --git a/mlserver/codecs/json.py b/mlserver/codecs/json.py new file mode 100644 index 000000000..7f1371a2d --- /dev/null +++ b/mlserver/codecs/json.py @@ -0,0 +1,61 @@ +# seperate file to side step circular dependecy on the decode_str function + +from typing import Any, Union +import json + +try: + import orjson +except ImportError: + orjson = None # type: ignore + +from .string import decode_str + + +# originally taken from: mlserver/rest/responses.py +class _BytesJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, bytes): + # If we get a bytes payload, try to decode it back to a string on a + # "best effort" basis + return decode_str(obj) + + return super().default(self, obj) + + +def _encode_object_to_bytes(obj: Any) -> str: + """ + Add compatibility with `bytes` payloads to `orjson` + """ + if isinstance(obj, bytes): + # If we get a bytes payload, try to decode it back to a string on a + # "best effort" basis + return decode_str(obj) + + raise TypeError + + +def encode_to_json_bytes(v: Any) -> bytes: + """encodes a dict into json bytes, can deal with byte like values gracefully""" + if orjson is None: + # Original implementation of starlette's JSONResponse, using our + # custom encoder (capable of "encoding" bytes). + # Original implementation can be seen here: + # https://github.com/encode/starlette/blob/ + # f53faba229e3fa2844bc3753e233d9c1f54cca52/starlette/responses.py#L173-L180 + return json.dumps( + v, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + cls=_BytesJSONEncoder, + ).encode("utf-8") + + return orjson.dumps(v, default=_encode_object_to_bytes) + + +def decode_from_bytelike_json_to_dict(v: Union[bytes, str]) -> dict: + if orjson is None: + return json.loads(v) + + return orjson.loads(v) diff --git a/mlserver/kafka/message.py b/mlserver/kafka/message.py index 45c50e4a0..2815bc990 100644 --- a/mlserver/kafka/message.py +++ b/mlserver/kafka/message.py @@ -1,28 +1,7 @@ -import json - -from typing import Dict, Optional, List, Tuple, Union +from typing import Dict, Optional, List, Tuple from pydantic import BaseModel - -try: - import orjson -except ImportError: - orjson = None # type: ignore - - -def _encode_value(v: dict) -> bytes: - if orjson is None: - dumped = json.dumps(v) - return dumped.encode("utf-8") - - return orjson.dumps(v) - - -def _decode_value(v: Union[bytes, str]) -> dict: - if orjson is None: - return json.loads(v) - - return orjson.loads(v) +from ..codecs.json import encode_to_json_bytes, decode_from_bytelike_json_to_dict def _encode_headers(h: Dict[str, str]) -> List[Tuple[str, bytes]]: @@ -48,7 +27,7 @@ def from_types( @classmethod def from_kafka_record(cls, kafka_record) -> "KafkaMessage": key = kafka_record.key - value = _decode_value(kafka_record.value) + value = decode_from_bytelike_json_to_dict(kafka_record.value) headers = _decode_headers(kafka_record.headers) return KafkaMessage(key=key, value=value, headers=headers) @@ -61,7 +40,7 @@ def encoded_key(self) -> bytes: @property def encoded_value(self) -> bytes: - return _encode_value(self.value) + return encode_to_json_bytes(self.value) @property def encoded_headers(self) -> List[Tuple[str, bytes]]: diff --git a/mlserver/rest/responses.py b/mlserver/rest/responses.py index c55835064..c9348a22d 100644 --- a/mlserver/rest/responses.py +++ b/mlserver/rest/responses.py @@ -1,64 +1,18 @@ -import json - from typing import Any from starlette.responses import JSONResponse as _JSONResponse -from ..codecs.string import decode_str - -try: - import orjson -except ImportError: - orjson = None # type: ignore - - -class BytesJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, bytes): - # If we get a bytes payload, try to decode it back to a string on a - # "best effort" basis - return decode_str(obj) - - return super().default(self, obj) +from ..codecs.json import encode_to_json_bytes class Response(_JSONResponse): """ - Custom Response class to use `orjson` if present. - Otherwise, it'll fall back to the standard JSONResponse. + Custom Response that will use the encode_to_json_bytes function to + encode given content to json based on library availability. + See mlserver/codecs/utils.py for more details """ media_type = "application/json" def render(self, content: Any) -> bytes: - if orjson is None: - # Original implementation of starlette's JSONResponse, using our - # custom encoder (capable of "encoding" bytes). - # Original implementation can be seen here: - # https://github.com/encode/starlette/blob/ - # f53faba229e3fa2844bc3753e233d9c1f54cca52/starlette/responses.py#L173-L180 - return json.dumps( - content, - ensure_ascii=False, - allow_nan=False, - indent=None, - separators=(",", ":"), - cls=BytesJSONEncoder, - ).encode("utf-8") - - # This is equivalent to the ORJSONResponse implementation in FastAPI: - # https://github.com/tiangolo/fastapi/blob/ - # 864643ef7608d28ac4ed321835a7fb4abe3dfc13/fastapi/responses.py#L32-L34 - return orjson.dumps(content, default=_encode_bytes) - - -def _encode_bytes(obj: Any) -> str: - """ - Add compatibility with `bytes` payloads to `orjson` - """ - if isinstance(obj, bytes): - # If we get a bytes payload, try to decode it back to a string on a - # "best effort" basis - return decode_str(obj) - - raise TypeError + return encode_to_json_bytes(content) From a7172a98402a0cb1821297b45a236f7fce6d8d0b Mon Sep 17 00:00:00 2001 From: DerTiedemann Date: Wed, 13 Mar 2024 14:09:08 +0100 Subject: [PATCH 2/4] chore: add tests --- tests/codecs/test_json.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/codecs/test_json.py diff --git a/tests/codecs/test_json.py b/tests/codecs/test_json.py new file mode 100644 index 000000000..7c6530494 --- /dev/null +++ b/tests/codecs/test_json.py @@ -0,0 +1,39 @@ +import pytest + +from typing import Any, Union + +from mlserver.codecs.json import decode_from_bytelike_json_to_dict, encode_to_json_bytes + + +@pytest.mark.parametrize( + "input, expected", + [ + (b"{}", dict()), + ("{}", dict()), + ('{"hello":"world"}', {"hello": "world"}), + (b'{"hello":"world"}', {"hello": "world"}), + (b'{"hello":"' + "world".encode("utf-8") + b'"}', {"hello": "world"}), + ( + b'{"hello":"' + "world".encode("utf-8") + b'", "foo": { "bar": "baz" } }', + {"hello": "world", "foo": {"bar": "baz"}}, + ), + ], +) +def test_decode_input(input: Union[str, bytes], expected): + assert expected == decode_from_bytelike_json_to_dict(input) + + +@pytest.mark.parametrize( + "expected, input", + [ + (b"{}", dict()), + (b'{"hello":"world"}', {"hello": "world"}), + (b'{"hello":"' + "world".encode("utf-8") + b'"}', {"hello": "world"}), + ( + b'{"hello":"' + "world".encode("utf-8") + b'","foo":{"bar":"baz"}}', + {"hello": b"world", "foo": {"bar": "baz"}}, + ), + ], +) +def test_encode_input(input: Any, expected: bytes): + assert expected == encode_to_json_bytes(input) From f748fdbe7324dcc6b019426b48b6b77f647e9159 Mon Sep 17 00:00:00 2001 From: Jan Max Tiedemann Date: Wed, 8 May 2024 11:14:21 +0200 Subject: [PATCH 3/4] fix: address pr comments --- mlserver/codecs/json.py | 2 +- tests/codecs/test_json.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlserver/codecs/json.py b/mlserver/codecs/json.py index 7f1371a2d..9afbf7e4b 100644 --- a/mlserver/codecs/json.py +++ b/mlserver/codecs/json.py @@ -1,4 +1,4 @@ -# seperate file to side step circular dependecy on the decode_str function +# seperate file to side step circular dependency on the decode_str function from typing import Any, Union import json diff --git a/tests/codecs/test_json.py b/tests/codecs/test_json.py index 7c6530494..b3d98d689 100644 --- a/tests/codecs/test_json.py +++ b/tests/codecs/test_json.py @@ -19,11 +19,12 @@ ), ], ) -def test_decode_input(input: Union[str, bytes], expected): +def test_decode_input(input: Union[str, bytes], expected: dict): assert expected == decode_from_bytelike_json_to_dict(input) @pytest.mark.parametrize( + # input and expected are flipped here for easier CTRL+C / V "expected, input", [ (b"{}", dict()), From a98ff692d34a692313850d3cf23d75dba580ef41 Mon Sep 17 00:00:00 2001 From: Sherif Akoush Date: Thu, 4 Jul 2024 12:03:12 +0100 Subject: [PATCH 4/4] fix lint --- mlserver/rest/responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlserver/rest/responses.py b/mlserver/rest/responses.py index 19b3fa0f5..c59d01645 100644 --- a/mlserver/rest/responses.py +++ b/mlserver/rest/responses.py @@ -29,4 +29,4 @@ def __init__(self, data: BaseModel, *args, **kwargs): def encode(self) -> bytes: as_dict = self.data.model_dump() - return self._pre + encode_to_json_bytes(as_dict) + self._sep \ No newline at end of file + return self._pre + encode_to_json_bytes(as_dict) + self._sep