Skip to content

Commit

Permalink
Give user registered types priority when encoding / decoding JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
soceanainn committed Nov 18, 2024
1 parent 3aed166 commit 21430db
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
58 changes: 41 additions & 17 deletions kombu/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,20 @@ class JSONEncoder(json.JSONEncoder):
"""Kombu custom json encoder."""

def default(self, o):
for t, (marker, encoder) in _encoders.items():
if isinstance(o, t):
return (
encoder(o) if marker is None else _as(marker, encoder(o))
)

reducer = getattr(o, "__json__", None)
if reducer is not None:
return reducer()

if isinstance(o, textual_types):
return str(o)

for t, (marker, encoder) in _encoders.items():
for t, (marker, encoder) in _default_encoders.items():
if isinstance(o, t):
return (
encoder(o) if marker is None else _as(marker, encoder(o))
Expand Down Expand Up @@ -66,7 +72,7 @@ def dumps(
def object_hook(o: dict):
"""Hook function to perform custom deserialization."""
if o.keys() == {"__type__", "__value__"}:
decoder = _decoders.get(o["__type__"])
decoder = _decoders.get(o["__type__"]) or _default_decoders.get(o["__type__"])
if decoder:
return decoder(o["__value__"])
else:
Expand Down Expand Up @@ -97,6 +103,16 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
T = TypeVar("T")
EncodedT = TypeVar("EncodedT")

# Separate user registered types from Kombu registered types to allow us to give preference to user types
_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {}

_default_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_default_decoders: dict[str, DecoderT] = {
"bytes": lambda o: o.encode("utf-8"),
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
}


def register_type(
t: type[T],
Expand All @@ -110,32 +126,40 @@ def register_type(
is not placed in an envelope, so `decoder` is unnecessary. Decoding must
instead be handled outside this library.
"""
_encoders[t] = (marker, encoder)
if marker is not None:
_decoders[marker] = decoder
_register_type(t, marker, encoder, decoder, is_default_encoder=False)


_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {
"bytes": lambda o: o.encode("utf-8"),
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
}
def _register_type(
t: type[T],
marker: str | None,
encoder: Callable[[T], EncodedT],
decoder: Callable[[EncodedT], T] = lambda d: d,
is_default_encoder: bool = True,
):
if is_default_encoder:
_default_encoders[t] = (marker, encoder)
if marker is not None:
_default_decoders[marker] = decoder
else:
_encoders[t] = (marker, encoder)
if marker is not None:
_decoders[marker] = decoder


def _register_default_types():
# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
register_type(datetime, "datetime", datetime.isoformat,
datetime.fromisoformat)
register_type(
_register_type(datetime, "datetime", datetime.isoformat,
datetime.fromisoformat)
_register_type(
date,
"date",
lambda o: o.isoformat(),
lambda o: datetime.fromisoformat(o).date(),
lambda o: datetime.fromisoformat(o).date()
)
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
register_type(Decimal, "decimal", str, Decimal)
register_type(
_register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
_register_type(Decimal, "decimal", str, Decimal)
_register_type(
uuid.UUID,
"uuid",
lambda o: {"hex": o.hex},
Expand Down
9 changes: 9 additions & 0 deletions t/unit/utils/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ def test_register_type_overrides_defaults(self):
loaded_value = loads(dumps({'u': value}))
assert loaded_value == {'u': "custom"}

def test_register_type_takes_priority(self):
class MyDecimal(Decimal):
pass

register_type(MyDecimal, "mydecimal", str, MyDecimal)
original = {'md': MyDecimal('3314132.13363235235324234123213213214134')}
loaded_value = loads(dumps(original))
assert original == loaded_value

def test_register_type_with_new_type(self):
# Guaranteed never before seen type
@dataclass()
Expand Down

0 comments on commit 21430db

Please sign in to comment.