Let data_key and attribute default to field name #1897

Open · wants to merge 7 commits into base: dev
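This PR makes a bound field's `data_key` and `attribute` default to the field name instead of `None`. A minimal sketch of the resulting behavior, based on the tests added in `tests/test_fields.py` (the `ArtistSchema` name is illustrative, and this assumes the branch from this PR is installed):

```python
from marshmallow import Schema, fields

class ArtistSchema(Schema):
    name = fields.String()
    label = fields.String(data_key="record_label")

schema_fields = ArtistSchema().fields
# Defaults are filled in when the schema class is created, instead of staying None.
assert schema_fields["name"].data_key == "name"
assert schema_fields["name"].attribute == "name"
# An explicit data_key still takes precedence; attribute still defaults to the field name.
assert schema_fields["label"].data_key == "record_label"
assert schema_fields["label"].attribute == "label"
```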
30 changes: 25 additions & 5 deletions src/marshmallow/fields.py
@@ -235,7 +235,7 @@ def __init__(
def __repr__(self) -> str:
return (
"<fields.{ClassName}(dump_default={self.dump_default!r}, "
"attribute={self.attribute!r}, "
"attribute={self.attribute!r}, data_key={self.data_key!r}, "
"validate={self.validate}, required={self.required}, "
"load_only={self.load_only}, dump_only={self.dump_only}, "
"load_default={self.load_default}, allow_none={self.allow_none}, "
@@ -368,6 +368,11 @@ def deserialize(

# Methods for concrete classes to override.

def _set_name(self, field_name):
self.name = self.name or field_name
self.data_key = self.data_key if self.data_key is not None else field_name
self.attribute = self.attribute if self.attribute is not None else field_name

def _bind_to_schema(self, field_name, schema):
"""Update field with values from its parent schema. Called by
:meth:`Schema._bind_field <marshmallow.Schema._bind_field>`.
@@ -376,7 +381,6 @@ def _bind_to_schema(self, field_name, schema):
:param Schema|Field schema: Parent object.
"""
self.parent = self.parent or schema
self.name = self.name or field_name
self.root = self.root or (
self.parent.root if isinstance(self.parent, FieldABC) else self.parent
)
@@ -688,7 +692,7 @@ def __init__(
@property
def _field_data_key(self):
only_field = self.schema.fields[self.field_name]
return only_field.data_key or self.field_name
return only_field.data_key

def _serialize(self, nested_obj, attr, obj, **kwargs):
ret = super()._serialize(nested_obj, attr, obj, **kwargs)
@@ -742,6 +746,10 @@ def __init__(self, cls_or_instance: typing.Union[Field, type], **kwargs):
self.only = self.inner.only
self.exclude = self.inner.exclude

def _set_name(self, field_name):
super()._set_name(field_name)
self.inner._set_name(field_name)

def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
self.inner = copy.deepcopy(self.inner)
@@ -818,6 +826,11 @@ def __init__(self, tuple_fields, *args, **kwargs):

self.validate_length = Length(equal=len(self.tuple_fields))

def _set_name(self, field_name):
super()._set_name(field_name)
for field in self.tuple_fields:
field._set_name(field_name)

def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
new_tuple_fields = []
@@ -1541,6 +1554,13 @@ def __init__(
self.only = self.value_field.only
self.exclude = self.value_field.exclude

def _set_name(self, field_name):
super()._set_name(field_name)
if self.value_field:
self.value_field._set_name(field_name)
if self.key_field:
self.key_field._set_name(field_name)

def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
if self.value_field:
@@ -1973,8 +1993,8 @@ class Inferred(Field):
Users should not need to use this class directly.
"""

def __init__(self):
super().__init__()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# We memoize the fields to avoid creating and binding new fields
# every time on serialization.
self._field_cache = {}
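The core of the change is the new `Field._set_name` hook, which container fields (`List`, `Tuple`, `Mapping`) forward to their inner fields. A simplified, standalone sketch of just the defaulting rule (not the marshmallow implementation itself):

```python
# Standalone sketch of the rule in _set_name: name, data_key, and attribute
# all fall back to the declared field name unless set explicitly.
class FieldStub:
    def __init__(self, data_key=None, attribute=None):
        self.name = None
        self.data_key = data_key
        self.attribute = attribute

    def _set_name(self, field_name):
        self.name = self.name or field_name
        self.data_key = self.data_key if self.data_key is not None else field_name
        self.attribute = self.attribute if self.attribute is not None else field_name

f = FieldStub(data_key="ext_id")
f._set_name("id")
assert (f.name, f.data_key, f.attribute) == ("id", "ext_id", "id")
```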
81 changes: 41 additions & 40 deletions src/marshmallow/schema.py
@@ -38,6 +38,23 @@
_T = typing.TypeVar("_T")


def _set_field_name(field_obj, field_name):
try:
field_obj._set_name(field_name)
except TypeError as error:
# Field declared as a class, not an instance. Ignore type checking because
# we handle unsupported arg types, i.e. this is dead code from
# the type checker's perspective.
if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC):
msg = (
'Field for "{}" must be declared as a '
"Field instance, not a class. "
'Did you mean "fields.{}()"?'.format(field_name, field_obj.__name__)
)
raise TypeError(msg) from error
raise error


def _get_fields(attrs, ordered=False):
"""Get fields from a class. If ordered=True, fields will sorted by creation index.

@@ -51,6 +68,9 @@ def _get_fields(attrs, ordered=False):
]
if ordered:
fields.sort(key=lambda pair: pair[1]._creation_index)
# Set field name on each field
for field_name, field_value in fields:
_set_field_name(field_value, field_name)
return fields


@@ -111,6 +131,8 @@ def __new__(mcs, name, bases, attrs):
# get_declared_fields
klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)
# Add fields specified in the `include` class Meta option
for field_name, field_obj in klass.opts.include.items():
_set_field_name(field_obj, field_name)
cls_fields += list(klass.opts.include.items())

dict_cls = OrderedDict if ordered else dict
@@ -520,8 +542,7 @@ def _serialize(
value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)
if value is missing:
continue
key = field_obj.data_key if field_obj.data_key is not None else attr_name
ret[key] = value
ret[field_obj.data_key] = value
return ret

def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None):
@@ -634,10 +655,8 @@ def _deserialize(
else:
partial_is_collection = is_collection(partial)
for attr_name, field_obj in self.load_fields.items():
field_name = (
field_obj.data_key if field_obj.data_key is not None else attr_name
)
raw_value = data.get(field_name, missing)
data_key = typing.cast(str, field_obj.data_key)
raw_value = data.get(data_key, missing)
if raw_value is missing:
# Ignore missing field if we're allowed to.
if partial is True or (
@@ -647,7 +666,7 @@
d_kwargs = {}
# Allow partial loading of nested schemas.
if partial_is_collection:
prefix = field_name + "."
prefix = data_key + "."
len_prefix = len(prefix)
sub_partial = [
f[len_prefix:] for f in partial if f.startswith(prefix)
@@ -656,23 +675,20 @@
else:
d_kwargs["partial"] = partial
getter = lambda val: field_obj.deserialize(
val, field_name, data, **d_kwargs
val, data_key, data, **d_kwargs
)
value = self._call_and_store(
getter_func=getter,
data=raw_value,
field_name=field_name,
field_name=data_key,
error_store=error_store,
index=index,
)
if value is not missing:
key = field_obj.attribute or attr_name
set_value(ret_d, key, value)
attribute = typing.cast(str, field_obj.attribute)
set_value(ret_d, attribute, value)
if unknown != EXCLUDE:
fields = {
field_obj.data_key if field_obj.data_key is not None else field_name
for field_name, field_obj in self.load_fields.items()
}
fields = {field_obj.data_key for field_obj in self.load_fields.values()}
for key in set(data) - fields:
value = data[key]
if unknown == INCLUDE:
@@ -975,7 +991,10 @@ def _init_fields(self) -> None:

fields_dict = self.dict_class()
for field_name in field_names:
field_obj = self.declared_fields.get(field_name, ma_fields.Inferred())
field_obj = self.declared_fields.get(
field_name,
ma_fields.Inferred(attribute=field_name, data_key=field_name),
)
self._bind_field(field_name, field_obj)
fields_dict[field_name] = field_obj

@@ -986,10 +1005,7 @@ def _init_fields(self) -> None:
if not field_obj.load_only:
dump_fields[field_name] = field_obj

dump_data_keys = [
field_obj.data_key if field_obj.data_key is not None else name
for name, field_obj in dump_fields.items()
]
dump_data_keys = [field_obj.data_key for field_obj in dump_fields.values()]
if len(dump_data_keys) != len(set(dump_data_keys)):
data_keys_duplicates = {
x for x in dump_data_keys if dump_data_keys.count(x) > 1
@@ -1000,7 +1016,7 @@ def _init_fields(self) -> None:
"Check the following field names and "
"data_key arguments: {}".format(list(data_keys_duplicates))
)
load_attributes = [obj.attribute or name for name, obj in load_fields.items()]
load_attributes = [obj.attribute for obj in load_fields.values()]
if len(load_attributes) != len(set(load_attributes)):
attributes_duplicates = {
x for x in load_attributes if load_attributes.count(x) > 1
@@ -1034,20 +1050,7 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
field_obj.load_only = True
if field_name in self.dump_only:
field_obj.dump_only = True
try:
field_obj._bind_to_schema(field_name, self)
except TypeError as error:
# Field declared as a class, not an instance. Ignore type checking because
# we handle unsupported arg types, i.e. this is dead code from
# the type checker's perspective.
if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC):
msg = (
'Field for "{}" must be declared as a '
"Field instance, not a class. "
'Did you mean "fields.{}()"?'.format(field_name, field_obj.__name__)
)
raise TypeError(msg) from error
raise error
field_obj._bind_to_schema(field_name, self)
self.on_bind_field(field_name, field_obj)

@lru_cache(maxsize=8)
@@ -1110,13 +1113,11 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool)
continue
raise ValueError(f'"{field_name}" field does not exist.') from error

data_key = (
field_obj.data_key if field_obj.data_key is not None else field_name
)
data_key = field_obj.data_key
if many:
for idx, item in enumerate(data):
try:
value = item[field_obj.attribute or field_name]
value = item[field_obj.attribute]
except KeyError:
pass
else:
@@ -1131,7 +1132,7 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool)
data[idx].pop(field_name, None)
else:
try:
value = data[field_obj.attribute or field_name]
value = data[field_obj.attribute]
except KeyError:
pass
else:
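On the schema side, the `data_key if data_key is not None else field_name` and `attribute or name` fallbacks disappear because every bound field now carries concrete values. The externally visible (de)serialization behavior should be unchanged; a quick round-trip for illustration (the `UserSchema` here is hypothetical and assumes this branch):

```python
from marshmallow import Schema, fields

class UserSchema(Schema):
    name = fields.String()
    email = fields.Email(data_key="emailAddress")

schema = UserSchema()
# Dump keys come straight from field.data_key ...
dumped = schema.dump({"name": "Monty", "email": "monty@python.org"})
assert dumped == {"name": "Monty", "emailAddress": "monty@python.org"}
# ... and load looks raw input up by the same key, storing under field.attribute.
loaded = schema.load({"name": "Monty", "emailAddress": "monty@python.org"})
assert loaded == {"name": "Monty", "email": "monty@python.org"}
```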
20 changes: 19 additions & 1 deletion tests/test_fields.py
@@ -32,7 +32,7 @@ def test_repr(self):
default = "œ∑´"
field = fields.Field(dump_default=default, attribute=None)
assert repr(field) == (
"<fields.Field(dump_default={0!r}, attribute=None, "
"<fields.Field(dump_default={0!r}, attribute=None, data_key=None, "
"validate=None, required=False, "
"load_only=False, dump_only=False, "
"load_default={missing}, allow_none=False, "
@@ -92,6 +92,24 @@ class MySchema(Schema):
result = MySchema().dump({"name": "Monty", "foo": 42})
assert result == {"_NaMe": "Monty"}

def test_data_key_defaults_to_field_name(self):
class MySchema(Schema):
field_1 = fields.String(data_key="field_one")
field_2 = fields.String()

schema_fields = MySchema().fields
assert schema_fields["field_1"].data_key == "field_one"
assert schema_fields["field_2"].data_key == "field_2"

def test_attribute_defaults_to_field_name(self):
class MySchema(Schema):
field_1 = fields.String(attribute="field_one")
field_2 = fields.String()

schema_fields = MySchema().fields
assert schema_fields["field_1"].attribute == "field_one"
assert schema_fields["field_2"].attribute == "field_2"


class TestParentAndName:
class MySchema(Schema):
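The `test_repr` update reflects that `data_key` now appears in a field's repr next to `attribute`. Roughly, for an unbound field (output abbreviated, on this branch):

```python
from marshmallow import fields

field = fields.Field(dump_default="x", attribute=None)
print(repr(field))
# <fields.Field(dump_default='x', attribute=None, data_key=None,
#  validate=None, required=False, load_only=False, dump_only=False, ...)>
```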
7 changes: 3 additions & 4 deletions tests/test_schema.py
@@ -553,11 +553,10 @@ def test_function_field(serialized_user, user):


def test_fields_must_be_declared_as_instances(user):
class BadUserSchema(Schema):
name = fields.String

with pytest.raises(TypeError, match="must be declared as a Field instance"):
BadUserSchema().dump(user)

class BadUserSchema(Schema):
name = fields.String


# regression test
Expand Down