diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 91f8e8071..bca3cdb9a 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -550,6 +550,7 @@ def schema(self): context=context, load_only=self._nested_normalized_option("load_only"), dump_only=self._nested_normalized_option("dump_only"), + required=self._nested_normalized_option("required"), ) return self._schema diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 78b386717..56eada560 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -227,6 +227,7 @@ def __init__(self, meta, ordered: bool = False): self.include = getattr(meta, "include", {}) self.load_only = getattr(meta, "load_only", ()) self.dump_only = getattr(meta, "dump_only", ()) + self.required = getattr(meta, "required", ()) self.unknown = getattr(meta, "unknown", RAISE) self.register = getattr(meta, "register", True) @@ -273,6 +274,7 @@ class AlbumSchema(Schema): :class:`fields.Function` fields. :param load_only: Fields to skip during serialization (write-only fields) :param dump_only: Fields to skip during deserialization (read-only fields) + :param required: Fields to be considered required. :param partial: Whether to ignore missing fields and not require any fields declared. Propagates down to ``Nested`` fields as well. If its value is an iterable, only missing fields listed in that iterable @@ -354,6 +356,7 @@ class Meta: of invalid items in a collection. - ``load_only``: Tuple or list of fields to exclude from serialized results. - ``dump_only``: Tuple or list of fields to exclude from deserialization + - ``required``: Tuple or list of fields to be considered required. - ``unknown``: Whether to exclude, include, or raise an error for unknown fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`. - ``register``: Whether to register the `Schema` with marshmallow's internal @@ -371,6 +374,7 @@ def __init__( context: typing.Dict = None, load_only: types.StrSequenceOrSet = (), dump_only: types.StrSequenceOrSet = (), + required: types.StrSequenceOrSet = (), partial: typing.Union[bool, types.StrSequenceOrSet] = False, unknown: str = None ): @@ -387,6 +391,7 @@ def __init__( self.ordered = self.opts.ordered self.load_only = set(load_only) or set(self.opts.load_only) self.dump_only = set(dump_only) or set(self.opts.dump_only) + self.required = set(required) or set(self.opts.required) self.partial = partial self.unknown = unknown or self.opts.unknown self.context = context or {} @@ -1028,13 +1033,15 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None: """Bind field to the schema, setting any necessary attributes on the field (e.g. parent and name). - Also set field load_only and dump_only values if field_name was + Also set field load_only, dump_only and required values if field_name was specified in ``class Meta``. """ if field_name in self.load_only: field_obj.load_only = True if field_name in self.dump_only: field_obj.dump_only = True + if field_name in self.required: + field_obj.required = True try: field_obj._bind_to_schema(field_name, self) except TypeError as error: diff --git a/tests/test_schema.py b/tests/test_schema.py index 85646754b..e33decbd3 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1394,6 +1394,61 @@ def test_dump_only(self, schema, data): assert "str_regular" in grand_child +def test_deeply_nested_required(): + class GrandChildSchema(Schema): + str_required = fields.String() + str_regular = fields.String() + + class ChildSchema(Schema): + str_required = fields.String() + str_regular = fields.String() + grand_child = fields.Nested(GrandChildSchema) + + class ParentSchema(Schema): + str_required = fields.String() + str_regular = fields.String() + child = fields.Nested(ChildSchema) + + schema = ParentSchema( + required=( + "str_required", + "child.str_required", + "child.grand_child.str_required", + ), + ) + + valid_data = { + "str_required": "Required", + "child": { + "str_required": "Required", + "grand_child": { + "str_required": "Required", + }, + }, + } + + # Assert no exception + schema.load(valid_data) + + data = valid_data.copy() + del data["str_required"] + with pytest.raises(ValidationError) as excinfo: + schema.load(data) + assert "str_required" in excinfo.value.messages + + data = valid_data.copy() + del data["child"]["str_required"] + with pytest.raises(ValidationError) as excinfo: + schema.load(data) + assert "str_required" in excinfo.value.messages["child"] + + data = valid_data.copy() + del data["child"]["grand_child"]["str_required"] + with pytest.raises(ValidationError) as excinfo: + schema.load(data) + assert "str_required" in excinfo.value.messages["child"]["grand_child"] + + class TestDeeplyNestedListLoadOnly: @pytest.fixture() def schema(self): @@ -1449,6 +1504,38 @@ def test_dump_only(self, schema, data): assert "str_regular" in child +def test_deeply_nested_list_required(): + class ChildSchema(Schema): + str_required = fields.String() + str_regular = fields.String() + + class ParentSchema(Schema): + str_required = fields.String() + str_regular = fields.String() + child = fields.List(fields.Nested(ChildSchema)) + + schema = ParentSchema( + required=("str_required", "child.str_required"), + ) + + valid_data = {"str_required": "Required", "child": [{"str_required": "Required"}]} + + # Assert no exception + schema.load(valid_data) + + data = valid_data.copy() + del data["str_required"] + with pytest.raises(ValidationError) as excinfo: + schema.load(data) + assert "str_required" in excinfo.value.messages + + data = valid_data.copy() + del data["child"][0]["str_required"] + with pytest.raises(ValidationError) as excinfo: + schema.load(data) + assert "str_required" in excinfo.value.messages["child"][0] + + def test_nested_constructor_only_and_exclude(): class GrandChildSchema(Schema): goo = fields.Field() @@ -2858,6 +2945,27 @@ class NoTldTestSchema(Schema): assert result == data_with_no_top_level_domain +def test_required_in_meta(): + class MySchema(Schema): + class Meta: + required = "str_required" + + str_required = fields.String() + str_regular = fields.String() + + data = { + "str_required": None, + "str_regular": "Regular String", + } + + schema = MySchema() + + with pytest.raises(ValidationError) as excinfo: + schema.load(data) + + assert "str_required" in excinfo.value.messages + + class TestFromDict: def test_generates_schema(self): MySchema = Schema.from_dict({"foo": fields.Str()})