From a1238ff05bdfb5f065a33aede6297636125d61c8 Mon Sep 17 00:00:00 2001 From: Roman Snegirev Date: Tue, 26 Oct 2021 10:44:47 +0300 Subject: [PATCH] Improve Optional types handling, add tests --- pydantic_collections/__init__.py | 2 +- .../_base_collection_model.py | 27 +++++-- tests/test_collections.py | 78 ++++++++++++++++--- 3 files changed, 86 insertions(+), 21 deletions(-) diff --git a/pydantic_collections/__init__.py b/pydantic_collections/__init__.py index 1cdf5fa..5d157cc 100644 --- a/pydantic_collections/__init__.py +++ b/pydantic_collections/__init__.py @@ -1,5 +1,5 @@ __title__ = 'pydantic-collections' -__version__ = '0.1.0' +__version__ = '0.1.1' from ._base_collection_model import BaseCollectionModel diff --git a/pydantic_collections/_base_collection_model.py b/pydantic_collections/_base_collection_model.py index 57f75ee..1fbf51a 100644 --- a/pydantic_collections/_base_collection_model.py +++ b/pydantic_collections/_base_collection_model.py @@ -65,12 +65,8 @@ def _validate_element(self, value, index): if self.__config__.validate_assignment_strict: if self.__el_field__.allow_none and value is None: pass # pragma: no cover - elif not isinstance(value, self.__el_field__.type_): - error = ArbitraryTypeError(expected_arbitrary_type=self.__el_field__.type_) - raise ValidationError( - [ErrorWrapper(exc=error, loc='{} -> {}'.format('__root__', index))], - self.__class__, - ) + else: + self._validate_element_type(self.__el_field__, value, index) value, err = self.__el_field__.validate( value, @@ -90,6 +86,21 @@ def _validate_element(self, value, index): return value + def _validate_element_type(self, field: ModelField, value: Any, index: int): + def get_field_types(fld: ModelField): + if fld.sub_fields: + for sub_field in fld.sub_fields: + yield from get_field_types(sub_field) + else: + yield fld.type_ + + if not isinstance(value, tuple(get_field_types(field))): + error = ArbitraryTypeError(expected_arbitrary_type=field.type_) + raise ValidationError( + [ErrorWrapper(exc=error, loc='{} -> {}'.format('__root__', index))], + self.__class__, + ) + def __len__(self): return len(self.__root__) @@ -107,10 +118,10 @@ def __iter__(self) -> List[T]: yield from self.__root__ def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, self.__root__) + return '{}({!r})'.format(self.__class__.__name__, self.__root__) # pragma: no cover def __str__(self): - return repr(self) + return repr(self) # pragma: no cover def insert(self, index, value): self.__root__.insert(index, self._validate_element(value, index)) diff --git a/tests/test_collections.py b/tests/test_collections.py index 2916195..7be128c 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -1,5 +1,5 @@ import pytest - +from typing import Optional, Union from datetime import datetime from pydantic import BaseModel, ValidationError @@ -32,7 +32,15 @@ class Config: validate_assignment_strict = False -data = [ +class OptionalIntCollection(BaseCollectionModel[Optional[int]]): + pass + + +class IntOrOptionalDatetimeCollection(BaseCollectionModel[Union[int, Optional[datetime]]]): + pass + + +user_data = [ { 'id': 1, 'name': 'Bender', @@ -47,11 +55,11 @@ class Config: def test_collection_validation_serialization(): - user0 = User(**data[0]) - user1 = User(**data[1]) + user0 = User(**user_data[0]) + user1 = User(**user_data[1]) - users = UserCollection(data) - assert len(users) == len(data) + users = UserCollection(user_data) + assert len(users) == len(user_data) assert users[0] == user0 assert users[1] == user1 @@ -61,9 +69,13 @@ def test_collection_validation_serialization(): for (u1, u2) in zip(users, users2): assert u1 == u2 + users3 = UserCollection.parse_obj(users.dict()) + for (u1, u2) in zip(users, users3): + assert u1 == u2 + def test_collection_sort(): - users = UserCollection(data) + users = UserCollection(user_data) reversed_users = users.sort(key=lambda u: u.id, reverse=True) assert reversed_users[0] == users[1] assert reversed_users[1] == users[0] @@ -71,24 +83,66 @@ def test_collection_sort(): def test_collection_assignment_validation(): users = UserCollection() - for item in data: + for item in user_data: users.append(User(**item)) with pytest.raises(ValidationError): - users.append(data[0]) # noqa + users.append(user_data[0]) # noqa with pytest.raises(ValidationError): - users[0] = data[0] + users[0] = user_data[0] weak_users = WeakUserCollection() - for d in data: + for d in user_data: weak_users.append(d) # noqa for user in weak_users: assert user.__class__ is User - for (u1, u2) in zip(weak_users, data): + for (u1, u2) in zip(weak_users, user_data): assert u1 == User(**u2) with pytest.raises(ValidationError): weak_users.append('user') # noqa + + +def test_optional_collection(): + data = [1, None] + c = OptionalIntCollection() + for el in data: + c.append(el) + + for (item1, item2) in zip(c, data): + assert item1 == item2 + + +def test_union_collection(): + data = [1, datetime.utcnow(), None] + c = IntOrOptionalDatetimeCollection() + for el in data: + c.append(el) + + for (item1, item2) in zip(c, data): + assert item1 == item2 + + with pytest.raises(ValidationError): + c.append('data') # noqa + + +def test_collection_sequence_methods(): + users = UserCollection() + for item in user_data: + users.append(User(**item)) + + assert len(users) == len(user_data) + + user0 = User(**user_data[0]) + users.insert(0, user0) + assert users[0] == user0 + assert len(users) == len(user_data) + 1 + + users[-1] = user0 + assert users[-1] == user0 + + users.clear() + assert len(users) == 0