Skip to content

Commit

Permalink
Rename container into inner in container fields
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed May 29, 2019
1 parent fbddbd7 commit 9f4813d
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 72 deletions.
106 changes: 53 additions & 53 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
)


CONTAINER_MODIFIERS = ['only', 'exclude']
INNER_MODIFIERS = ['only', 'exclude']


class Field(FieldABC):
Expand Down Expand Up @@ -356,7 +356,7 @@ def _deserialize(self, value, attr, data, **kwargs):
return value

def get_modifiers(self):
return {attr: None for attr in CONTAINER_MODIFIERS}
return {attr: None for attr in INNER_MODIFIERS}

def set_modifiers(self, modifiers):
pass
Expand Down Expand Up @@ -516,7 +516,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
return self._load(value, data, partial=partial)

def get_modifiers(self):
return {attr: getattr(self, attr) for attr in CONTAINER_MODIFIERS}
return {attr: getattr(self, attr) for attr in INNER_MODIFIERS}

def set_modifiers(self, modifiers):
for attr, value in modifiers.items():
Expand Down Expand Up @@ -567,13 +567,13 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
class ContainerMixin:
"""Common methods for container fields"""

def init_container_modifiers(self, container):
for attr, value in container.get_modifiers().items():
def init_inner_modifiers(self, inner):
for attr, value in inner.get_modifiers().items():
setattr(self, attr, value)

def set_container_modifiers(self, container):
container.set_modifiers({
attr: getattr(self, attr) for attr in CONTAINER_MODIFIERS
def set_inner_modifiers(self, inner):
inner.set_modifiers({
attr: getattr(self, attr) for attr in INNER_MODIFIERS
})


Expand All @@ -599,29 +599,29 @@ class List(ContainerMixin, Field):
def __init__(self, cls_or_instance, **kwargs):
super().__init__(**kwargs)
try:
self.container = resolve_field_instance(cls_or_instance)
self.inner = resolve_field_instance(cls_or_instance)
except FieldInstanceResolutionError:
raise ValueError(
'The list elements must be a subclass or instance of '
'marshmallow.base.FieldABC.',
)
self.init_container_modifiers(self.container)
self.init_inner_modifiers(self.inner)

def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
self.container = copy.deepcopy(self.container)
self.container.parent = self
self.container.name = field_name
self.set_container_modifiers(self.container)
self.inner = copy.deepcopy(self.inner)
self.inner.parent = self
self.inner.name = field_name
self.set_inner_modifiers(self.inner)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
if utils.is_collection(value):
return [
self.container._serialize(each, attr, obj, **kwargs) for each in value
self.inner._serialize(each, attr, obj, **kwargs) for each in value
]
return [self.container._serialize(value, attr, obj, **kwargs)]
return [self.inner._serialize(value, attr, obj, **kwargs)]

def _deserialize(self, value, attr, data, **kwargs):
if not utils.is_collection(value):
Expand All @@ -631,7 +631,7 @@ def _deserialize(self, value, attr, data, **kwargs):
errors = {}
for idx, each in enumerate(value):
try:
result.append(self.container.deserialize(each))
result.append(self.inner.deserialize(each))
except ValidationError as error:
if error.valid_data is not None:
result.append(error.valid_data)
Expand Down Expand Up @@ -682,27 +682,27 @@ def __init__(self, tuple_fields, *args, **kwargs):
)

self.validate_length = Length(equal=len(self.tuple_fields))
for container in self.tuple_fields:
self.init_container_modifiers(container)
for inner in self.tuple_fields:
self.init_inner_modifiers(inner)

def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
new_tuple_fields = []
for container in self.tuple_fields:
container = copy.deepcopy(container)
container.parent = self
container.name = field_name
new_tuple_fields.append(container)
self.set_container_modifiers(container)
for inner in self.tuple_fields:
inner = copy.deepcopy(inner)
inner.parent = self
inner.name = field_name
new_tuple_fields.append(inner)
self.set_inner_modifiers(inner)
self.tuple_fields = new_tuple_fields

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None

return tuple(
container._serialize(each, attr, obj, **kwargs)
for container, each in zip(self.tuple_fields, value)
inner._serialize(each, attr, obj, **kwargs)
for inner, each in zip(self.tuple_fields, value)
)

def _deserialize(self, value, attr, data, **kwargs):
Expand All @@ -714,9 +714,9 @@ def _deserialize(self, value, attr, data, **kwargs):
result = []
errors = {}

for idx, (container, each) in enumerate(zip(self.tuple_fields, value)):
for idx, (inner, each) in enumerate(zip(self.tuple_fields, value)):
try:
result.append(container.deserialize(each))
result.append(inner.deserialize(each))
except ValidationError as error:
if error.valid_data is not None:
result.append(error.valid_data)
Expand Down Expand Up @@ -1302,66 +1302,66 @@ class Mapping(ContainerMixin, Field):
def __init__(self, keys=None, values=None, **kwargs):
super().__init__(**kwargs)
if keys is None:
self.key_container = None
self.key_inner = None
else:
try:
self.key_container = resolve_field_instance(keys)
self.key_inner = resolve_field_instance(keys)
except FieldInstanceResolutionError:
raise ValueError(
'"keys" must be a subclass or instance of '
'marshmallow.base.FieldABC.',
)

if values is None:
self.value_container = None
self.value_inner = None
else:
try:
self.value_container = resolve_field_instance(values)
self.value_inner = resolve_field_instance(values)
except FieldInstanceResolutionError:
raise ValueError(
'"values" must be a subclass or instance of '
'marshmallow.base.FieldABC.',
)
self.init_container_modifiers(self.value_container)
self.init_inner_modifiers(self.value_inner)

def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
if self.value_container:
self.value_container = copy.deepcopy(self.value_container)
self.value_container.parent = self
self.value_container.name = field_name
self.set_container_modifiers(self.value_container)
if self.key_container:
self.key_container = copy.deepcopy(self.key_container)
self.key_container.parent = self
self.key_container.name = field_name
if self.value_inner:
self.value_inner = copy.deepcopy(self.value_inner)
self.value_inner.parent = self
self.value_inner.name = field_name
self.set_inner_modifiers(self.value_inner)
if self.key_inner:
self.key_inner = copy.deepcopy(self.key_inner)
self.key_inner.parent = self
self.key_inner.name = field_name

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
if not self.value_container and not self.key_container:
if not self.value_inner and not self.key_inner:
return value
if not isinstance(value, _Mapping):
self.fail('invalid')

#  Serialize keys
if self.key_container is None:
if self.key_inner is None:
keys = {k: k for k in value.keys()}
else:
keys = {
k: self.key_container._serialize(k, None, None, **kwargs)
k: self.key_inner._serialize(k, None, None, **kwargs)
for k in value.keys()
}

#  Serialize values
result = self.mapping_type()
if self.value_container is None:
if self.value_inner is None:
for k, v in value.items():
if k in keys:
result[keys[k]] = v
else:
for k, v in value.items():
result[keys[k]] = self.value_container._serialize(
result[keys[k]] = self.value_inner._serialize(
v, None, None, **kwargs
)

Expand All @@ -1370,32 +1370,32 @@ def _serialize(self, value, attr, obj, **kwargs):
def _deserialize(self, value, attr, data, **kwargs):
if not isinstance(value, _Mapping):
self.fail('invalid')
if not self.value_container and not self.key_container:
if not self.value_inner and not self.key_inner:
return value

errors = collections.defaultdict(dict)

#  Deserialize keys
if self.key_container is None:
if self.key_inner is None:
keys = {k: k for k in value.keys()}
else:
keys = {}
for key in value.keys():
try:
keys[key] = self.key_container.deserialize(key)
keys[key] = self.key_inner.deserialize(key)
except ValidationError as error:
errors[key]['key'] = error.messages

#  Deserialize values
result = self.mapping_type()
if self.value_container is None:
if self.value_inner is None:
for k, v in value.items():
if k in keys:
result[keys[k]] = v
else:
for key, val in value.items():
try:
deser_val = self.value_container.deserialize(val)
deser_val = self.value_inner.deserialize(val)
except ValidationError as error:
errors[key]['value'] = error.messages
if error.valid_data is not None and key in keys:
Expand Down
38 changes: 19 additions & 19 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,32 +122,32 @@ def test_unbound_field_root_returns_none(self):
assert inner_field.root is None

def test_list_field_inner_parent_and_name(self, schema):
assert schema.fields['bar'].container.parent == schema.fields['bar']
assert schema.fields['bar'].container.name == 'bar'
assert schema.fields['bar'].inner.parent == schema.fields['bar']
assert schema.fields['bar'].inner.name == 'bar'

def test_tuple_field_inner_parent_and_name(self, schema):
for container in schema.fields['baz'].tuple_fields:
assert container.parent == schema.fields['baz']
assert container.name == 'baz'
for inner in schema.fields['baz'].tuple_fields:
assert inner.parent == schema.fields['baz']
assert inner.name == 'baz'

def test_simple_field_root(self, schema):
assert schema.fields['foo'].root == schema
assert schema.fields['bar'].root == schema

def test_list_field_inner_root(self, schema):
assert schema.fields['bar'].container.root == schema
assert schema.fields['bar'].inner.root == schema

def test_tuple_field_inner_root(self, schema):
for container in schema.fields['baz'].tuple_fields:
assert container.root == schema
for inner in schema.fields['baz'].tuple_fields:
assert inner.root == schema

def test_list_root_inheritance(self, schema):
class OtherSchema(TestParentAndName.MySchema):
pass

schema2 = OtherSchema()
assert schema.fields['bar'].container.root == schema
assert schema2.fields['bar'].container.root == schema2
assert schema.fields['bar'].inner.root == schema
assert schema2.fields['bar'].inner.root == schema2

def test_dict_root_inheritance(self):
class MySchema(Schema):
Expand All @@ -158,10 +158,10 @@ class OtherSchema(MySchema):

schema = MySchema()
schema2 = OtherSchema()
assert schema.fields['foo'].key_container.root == schema
assert schema.fields['foo'].value_container.root == schema
assert schema2.fields['foo'].key_container.root == schema2
assert schema2.fields['foo'].value_container.root == schema2
assert schema.fields['foo'].key_inner.root == schema
assert schema.fields['foo'].value_inner.root == schema
assert schema2.fields['foo'].key_inner.root == schema2
assert schema2.fields['foo'].value_inner.root == schema2


class TestMetadata:
Expand Down Expand Up @@ -257,7 +257,7 @@ class Family(Schema):
children = fields.List(fields.Nested(Child))

schema = Family(**{param: ['children.name']})
assert getattr(schema.fields['children'].container, param) == {'name'}
assert getattr(schema.fields['children'].inner, param) == {'name'}

@pytest.mark.parametrize('param', ('only', 'exclude'))
def test_list_nested_only_and_exclude_merged_with_nested(self, param):
Expand All @@ -275,7 +275,7 @@ class Family(Schema):
'only': {'name'},
'exclude': {'name', 'surname', 'age'},
}[param]
assert getattr(schema.fields['children'].container, param) == expected
assert getattr(schema.fields['children'].inner, param) == expected

class TestTupleNested:

Expand Down Expand Up @@ -306,7 +306,7 @@ class Family(Schema):
(
fields.Nested(Child, **{param: ('name', 'surname')}),
fields.Nested(Child, **{param: ('name', 'surname')}),
)
),
)

schema = Family(**{param: ['children.name', 'children.age']})
Expand All @@ -330,7 +330,7 @@ class Family(Schema):
children = fields.Dict(values=fields.Nested(Child))

schema = Family(**{param: ['children.name']})
assert getattr(schema.fields['children'].value_container, param) == {'name'}
assert getattr(schema.fields['children'].value_inner, param) == {'name'}

@pytest.mark.parametrize('param', ('only', 'exclude'))
def test_dict_nested_only_and_exclude_merged_with_nested(self, param):
Expand All @@ -348,4 +348,4 @@ class Family(Schema):
'only': {'name'},
'exclude': {'name', 'surname', 'age'},
}[param]
assert getattr(schema.fields['children'].value_container, param) == expected
assert getattr(schema.fields['children'].value_inner, param) == expected

0 comments on commit 9f4813d

Please sign in to comment.