Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extra argument to load/dump for extensions #1725

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
---------

3.11.0 (Unreleased)
*******************

Features:

- Add ``extra`` argument to ``Schema.load`` and ``Schema.dump`` as an extension
point for libraries built on marshmallow

3.10.0 (2020-12-19)
*******************

Expand Down
58 changes: 50 additions & 8 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,17 @@ def _call_and_store(getter_func, data, *, field_name, error_store, index=None):
return value

def _serialize(
self, obj: typing.Union[_T, typing.Iterable[_T]], *, many: bool = False
self,
obj: typing.Union[_T, typing.Iterable[_T]],
*,
many: bool = False,
extra: typing.Optional[typing.Mapping[str, typing.Any]] = None
):
"""Serialize ``obj``.

:param obj: The object(s) to serialize.
:param bool many: `True` if ``data`` should be serialized as a collection.
:param extra: Extra data for serialization. Ignored by ``marshmallow``.
:return: A dictionary of the serialized data

.. versionchanged:: 1.0.0
Expand All @@ -527,13 +532,22 @@ def _serialize(
ret[key] = value
return ret

def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None):
def dump(
self,
obj: typing.Any,
*,
many: typing.Optional[bool] = None,
extra: typing.Optional[typing.Mapping[str, typing.Any]] = None
):
"""Serialize an object to native Python data types according to this
Schema's fields.

:param obj: The object to serialize.
:param many: Whether to serialize `obj` as a collection. If `None`, the value
for `self.many` is used.
:param extra: Additional data which will be passed down to
_serialize. Libraries built on ``marshmallow`` may use this to
customize serialization.
:return: A dict of serialized data
:rtype: dict

Expand All @@ -556,7 +570,7 @@ def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None):
else:
processed_obj = obj

result = self._serialize(processed_obj, many=many)
result = self._serialize(processed_obj, many=many, extra=extra)

if self._has_processors(POST_DUMP):
result = self._invoke_dump_processors(
Expand All @@ -566,13 +580,21 @@ def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None):
return result

def dumps(
self, obj: typing.Any, *args, many: typing.Optional[bool] = None, **kwargs
self,
obj: typing.Any,
*args,
many: typing.Optional[bool] = None,
extra: typing.Optional[typing.Mapping[str, typing.Any]] = None,
**kwargs
):
"""Same as :meth:`dump`, except return a JSON-encoded string.

:param obj: The object to serialize.
:param many: Whether to serialize `obj` as a collection. If `None`, the value
for `self.many` is used.
:param extra: Additional data which will be passed down to
_serialize. Libraries built on ``marshmallow`` may use this to
customize serialization.
:return: A ``json`` string
:rtype: str

Expand All @@ -582,7 +604,7 @@ def dumps(
A :exc:`ValidationError <marshmallow.exceptions.ValidationError>` is raised
if ``obj`` is invalid.
"""
serialized = self.dump(obj, many=many)
serialized = self.dump(obj, many=many, extra=extra)
return self.opts.render_module.dumps(serialized, *args, **kwargs)

def _deserialize(
Expand All @@ -596,6 +618,7 @@ def _deserialize(
many: bool = False,
partial=False,
unknown=RAISE,
extra=None,
index=None
) -> typing.Union[_T, typing.List[_T]]:
"""Deserialize ``data``.
Expand All @@ -611,6 +634,7 @@ def _deserialize(
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
:param int index: Index of the item being serialized (for storing errors) if
serializing a collection, otherwise `None`.
:param dict extra: Extra data for deserialization, ignored by ``marshmallow``.
:return: A dictionary of the deserialized data.
"""
index_errors = self.opts.index_errors
Expand Down Expand Up @@ -702,7 +726,8 @@ def load(
*,
many: typing.Optional[bool] = None,
partial: typing.Optional[typing.Union[bool, types.StrSequenceOrSet]] = None,
unknown: typing.Optional[str] = None
unknown: typing.Optional[str] = None,
extra: typing.Optional[typing.Mapping[str, typing.Any]] = None
):
"""Deserialize a data structure to an object defined by this Schema's fields.

Expand All @@ -716,6 +741,9 @@ def load(
:param unknown: Whether to exclude, include, or raise an error for unknown
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
If `None`, the value for `self.unknown` is used.
:param extra: Additional data which will be passed down to
_deserialize. Libraries built on ``marshmallow`` may use this to
customize deserialization.
:return: Deserialized data

.. versionadded:: 1.0.0
Expand All @@ -725,7 +753,12 @@ def load(
if invalid data are passed.
"""
return self._do_load(
data, many=many, partial=partial, unknown=unknown, postprocess=True
data,
many=many,
partial=partial,
unknown=unknown,
extra=extra,
postprocess=True,
)

def loads(
Expand All @@ -735,6 +768,7 @@ def loads(
many: typing.Optional[bool] = None,
partial: typing.Optional[typing.Union[bool, types.StrSequenceOrSet]] = None,
unknown: typing.Optional[str] = None,
extra: typing.Optional[typing.Mapping[str, typing.Any]] = None,
**kwargs
):
"""Same as :meth:`load`, except it takes a JSON string as input.
Expand All @@ -749,6 +783,9 @@ def loads(
:param unknown: Whether to exclude, include, or raise an error for unknown
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
If `None`, the value for `self.unknown` is used.
:param extra: Additional data which will be passed down to
_deserialize. Libraries built on ``marshmallow`` may use this to
customize deserialization.
:return: Deserialized data

.. versionadded:: 1.0.0
Expand All @@ -758,7 +795,7 @@ def loads(
if invalid data are passed.
"""
data = self.opts.render_module.loads(json_data, **kwargs)
return self.load(data, many=many, partial=partial, unknown=unknown)
return self.load(data, many=many, partial=partial, unknown=unknown, extra=extra)

def _run_validator(
self,
Expand Down Expand Up @@ -819,6 +856,7 @@ def _do_load(
many: typing.Optional[bool] = None,
partial: typing.Optional[typing.Union[bool, types.StrSequenceOrSet]] = None,
unknown: typing.Optional[str] = None,
extra: typing.Optional[typing.Mapping[str, typing.Any]] = None,
postprocess: bool = True
):
"""Deserialize `data`, returning the deserialized result.
Expand All @@ -834,6 +872,9 @@ def _do_load(
:param unknown: Whether to exclude, include, or raise an error for unknown
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
If `None`, the value for `self.unknown` is used.
:param extra: Additional data which will be passed down to
_deserialize. Libraries built on ``marshmallow`` may use this to
customize deserialization.
:param postprocess: Whether to run post_load methods..
:return: Deserialized data
"""
Expand Down Expand Up @@ -864,6 +905,7 @@ def _do_load(
many=many,
partial=partial,
unknown=unknown,
extra=extra,
)
# Run field-level validation
self._invoke_field_validators(
Expand Down
29 changes: 29 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2912,3 +2912,32 @@ class DefinitelyUniqueSchema(Schema):

SchemaClass = class_registry.get_class(DefinitelyUniqueSchema.__name__)
assert SchemaClass is DefinitelyUniqueSchema


def test_extra_args_passthrough():
class MySchema(Schema):
last_deserialize_side_effect = None
last_serialize_side_effect = None

def _deserialize(self, *args, extra, **kwargs):
if "side_effect" in extra:
MySchema.last_deserialize_side_effect = extra["side_effect"]
super()._deserialize(*args, extra=extra, **kwargs)

def _serialize(self, *args, extra, **kwargs):
if "side_effect" in extra:
MySchema.last_serialize_side_effect = extra["side_effect"]
super()._serialize(*args, extra=extra, **kwargs)

name = fields.Str()
email = fields.Email()

data = {"name": "Mick", "email": "[email protected]"}

loaded = MySchema().load(data, extra={"side_effect": "foo"})
assert MySchema.last_deserialize_side_effect == "foo"
assert MySchema.last_serialize_side_effect is None

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd clear last_deserialize_side_effect here to make next test more selective.

MySchema().dump(loaded, extra={"side_effect": "bar"})
assert MySchema.last_deserialize_side_effect == "foo"
assert MySchema.last_serialize_side_effect == "bar"