diff --git a/lib/dl_core/dl_core_tests/unit/test_json_serializer.py b/lib/dl_core/dl_core_tests/unit/test_json_serializer.py index 8c196533e..d94854693 100644 --- a/lib/dl_core/dl_core_tests/unit/test_json_serializer.py +++ b/lib/dl_core/dl_core_tests/unit/test_json_serializer.py @@ -26,7 +26,7 @@ some_timedelta=datetime.timedelta(seconds=1320.0231), some_decimal=decimal.Decimal("12345" * 9 + "." + "54321" * 9), some_uuid=uuid.UUID("12345678123456781234567812345678"), - # some_bytes=b"Another one bites", TODO: currently not serializable + some_bytes=b"Another one bites", ) @@ -47,7 +47,7 @@ "value": "123451234512345123451234512345123451234512345.543215432154321543215432154321543215432154321", }, some_uuid={"__dl_type__": "uuid", "value": "12345678-1234-5678-1234-567812345678"}, - # some_bytes=b"Another one bites", + some_bytes={"__dl_type__": "bytes", "value": "QW5vdGhlciBvbmUgYml0ZXM="}, ) @@ -61,9 +61,7 @@ def test_json_serialization(): def test_json_tricky_serialization(): - data = SAMPLE_DATA - dumped = json.dumps(data, cls=RedisDatalensDataJSONEncoder) - tricky_data = dict(normal=data, abnormal=json.loads(dumped)) + tricky_data = dict(normal=SAMPLE_DATA, abnormal=EXPECTED_DUMP) tricky_data_dumped = json.dumps(tricky_data, cls=RedisDatalensDataJSONEncoder) tricky_roundtrip = json.loads(tricky_data_dumped, cls=RedisDatalensDataJSONDecoder) assert tricky_roundtrip["normal"] == tricky_data["normal"], tricky_roundtrip diff --git a/lib/dl_model_tools/dl_model_tools/serialization.py b/lib/dl_model_tools/dl_model_tools/serialization.py index 2da25d82d..b9693f83a 100644 --- a/lib/dl_model_tools/dl_model_tools/serialization.py +++ b/lib/dl_model_tools/dl_model_tools/serialization.py @@ -7,6 +7,7 @@ from __future__ import annotations import abc +import base64 import datetime import decimal import json @@ -152,6 +153,19 @@ def from_jsonable(value: TJSONLike) -> uuid.UUID: return uuid.UUID(value) +class BytesSerializer(TypeSerializer[bytes]): + typename = "bytes" + + @staticmethod + def to_jsonable(value: bytes) -> TJSONLike: + return base64.b64encode(value).decode("ascii") + + @staticmethod + def from_jsonable(value: TJSONLike) -> bytes: + assert isinstance(value, str) + return base64.b64decode(value, validate=True) + + COMMON_SERIALIZERS: list[Type[TypeSerializer]] = [ DateSerializer, DatetimeSerializer, @@ -159,6 +173,7 @@ def from_jsonable(value: TJSONLike) -> uuid.UUID: TimedeltaSerializer, DecimalSerializer, UUIDSerializer, + BytesSerializer, ] assert len(set(cls.typename for cls in COMMON_SERIALIZERS)) == len(COMMON_SERIALIZERS), "uniqueness check"