Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let data_key and attribute default to field name #1897

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def _bind_to_schema(self, field_name, schema):
"""
self.parent = self.parent or schema
self.name = self.name or field_name
self.data_key = self.data_key if self.data_key is not None else field_name
self.root = self.root or (
self.parent.root if isinstance(self.parent, FieldABC) else self.parent
)
Expand Down Expand Up @@ -688,7 +689,7 @@ def __init__(
@property
def _field_data_key(self):
only_field = self.schema.fields[self.field_name]
return only_field.data_key or self.field_name
return only_field.data_key

def _serialize(self, nested_obj, attr, obj, **kwargs):
ret = super()._serialize(nested_obj, attr, obj, **kwargs)
Expand Down
21 changes: 5 additions & 16 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,7 @@ def _serialize(
value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)
if value is missing:
continue
key = field_obj.data_key if field_obj.data_key is not None else attr_name
ret[key] = value
ret[field_obj.data_key] = value
return ret

def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None):
Expand Down Expand Up @@ -634,9 +633,7 @@ def _deserialize(
else:
partial_is_collection = is_collection(partial)
for attr_name, field_obj in self.load_fields.items():
field_name = (
field_obj.data_key if field_obj.data_key is not None else attr_name
)
field_name = field_obj.data_key
raw_value = data.get(field_name, missing)
Copy link
Member

@deckar01 deckar01 Oct 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error: Argument 1 to "get" of "Mapping" has incompatible type "Optional[str]"; expected "str"

The problem is that Fields.__init__ assigns an Optional[str] type to data_key. _bind_to_schema guarantees that field_obj.data_key is a str, but the type checker doesn't know that _deserialize is guaranteed to run after _bind_to_schema. I guess the type annotation on get enforces the argument type, but AFAICT, passing None doesn't actually raise an error at runtime. 🤷

It wants you to ensure data_key is never None. The only thing I can think of is to make data_key a @property that raises an exception if the field is unbound.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we agree on the reason behind the error.

It would be a shame to "complicate" the code, even slightly, to please the linter.

I suppose we could add comments to skip the type check on each failing line, but I'd rather find a way to centralize that by somehow forcing the type of self.data_key once and for all.

We could use a cast:

self.data_key = typing.cast(str, data_key)

but doing this in __init__ before field binding would be wrong.

Or we do it just before the type check error:

field_name = typing.cast(str, field_obj.data_key)

Still not ideal because:

  • It is not done everywhere, only where the error was triggered. Other errors may occur later here or in other libs.
  • It is executed on each deserialization, although IIUC the overhead is negligible.

I just pushed a commit doing that.

if raw_value is missing:
# Ignore missing field if we're allowed to.
Expand Down Expand Up @@ -669,10 +666,7 @@ def _deserialize(
key = field_obj.attribute or attr_name
set_value(ret_d, key, value)
if unknown != EXCLUDE:
fields = {
field_obj.data_key if field_obj.data_key is not None else field_name
for field_name, field_obj in self.load_fields.items()
}
fields = {field_obj.data_key for field_obj in self.load_fields.values()}
for key in set(data) - fields:
value = data[key]
if unknown == INCLUDE:
Expand Down Expand Up @@ -986,10 +980,7 @@ def _init_fields(self) -> None:
if not field_obj.load_only:
dump_fields[field_name] = field_obj

dump_data_keys = [
field_obj.data_key if field_obj.data_key is not None else name
for name, field_obj in dump_fields.items()
]
dump_data_keys = [field_obj.data_key for field_obj in dump_fields.values()]
if len(dump_data_keys) != len(set(dump_data_keys)):
data_keys_duplicates = {
x for x in dump_data_keys if dump_data_keys.count(x) > 1
Expand Down Expand Up @@ -1110,9 +1101,7 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool)
continue
raise ValueError(f'"{field_name}" field does not exist.') from error

data_key = (
field_obj.data_key if field_obj.data_key is not None else field_name
)
data_key = field_obj.data_key
if many:
for idx, item in enumerate(data):
try:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ class MySchema(Schema):
result = MySchema().dump({"name": "Monty", "foo": 42})
assert result == {"_NaMe": "Monty"}

def test_data_key_defaults_to_field_name(self):
    """Check ``data_key`` on the fields of an instantiated schema.

    A field declared with an explicit ``data_key`` must keep it, and a
    field declared without one must report its own field name.
    """

    class MySchema(Schema):
        field_1 = fields.String(data_key="field_one")
        field_2 = fields.String()

    bound_fields = MySchema().fields
    # Maps each declared field name to the data_key it should expose.
    expected_data_keys = {"field_1": "field_one", "field_2": "field_2"}
    for field_name, expected_key in expected_data_keys.items():
        assert bound_fields[field_name].data_key == expected_key


class TestParentAndName:
class MySchema(Schema):
Expand Down