Skip to content

Commit

Permalink
Improve Optional types handling, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
romis2012 committed Oct 26, 2021
1 parent 87caa9b commit a1238ff
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pydantic_collections/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__title__ = 'pydantic-collections'
__version__ = '0.1.0'
__version__ = '0.1.1'

from ._base_collection_model import BaseCollectionModel

Expand Down
27 changes: 19 additions & 8 deletions pydantic_collections/_base_collection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,8 @@ def _validate_element(self, value, index):
if self.__config__.validate_assignment_strict:
if self.__el_field__.allow_none and value is None:
pass # pragma: no cover
elif not isinstance(value, self.__el_field__.type_):
error = ArbitraryTypeError(expected_arbitrary_type=self.__el_field__.type_)
raise ValidationError(
[ErrorWrapper(exc=error, loc='{} -> {}'.format('__root__', index))],
self.__class__,
)
else:
self._validate_element_type(self.__el_field__, value, index)

value, err = self.__el_field__.validate(
value,
Expand All @@ -90,6 +86,21 @@ def _validate_element(self, value, index):

return value

def _validate_element_type(self, field: ModelField, value: Any, index: int):
def get_field_types(fld: ModelField):
if fld.sub_fields:
for sub_field in fld.sub_fields:
yield from get_field_types(sub_field)
else:
yield fld.type_

if not isinstance(value, tuple(get_field_types(field))):
error = ArbitraryTypeError(expected_arbitrary_type=field.type_)
raise ValidationError(
[ErrorWrapper(exc=error, loc='{} -> {}'.format('__root__', index))],
self.__class__,
)

def __len__(self):
return len(self.__root__)

Expand All @@ -107,10 +118,10 @@ def __iter__(self) -> List[T]:
yield from self.__root__

def __repr__(self):
return '{}({!r})'.format(self.__class__.__name__, self.__root__)
return '{}({!r})'.format(self.__class__.__name__, self.__root__) # pragma: no cover

def __str__(self):
return repr(self)
return repr(self) # pragma: no cover

def insert(self, index, value):
self.__root__.insert(index, self._validate_element(value, index))
Expand Down
78 changes: 66 additions & 12 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest

from typing import Optional, Union
from datetime import datetime

from pydantic import BaseModel, ValidationError
Expand Down Expand Up @@ -32,7 +32,15 @@ class Config:
validate_assignment_strict = False


data = [
class OptionalIntCollection(BaseCollectionModel[Optional[int]]):
pass


class IntOrOptionalDatetimeCollection(BaseCollectionModel[Union[int, Optional[datetime]]]):
pass


user_data = [
{
'id': 1,
'name': 'Bender',
Expand All @@ -47,11 +55,11 @@ class Config:


def test_collection_validation_serialization():
user0 = User(**data[0])
user1 = User(**data[1])
user0 = User(**user_data[0])
user1 = User(**user_data[1])

users = UserCollection(data)
assert len(users) == len(data)
users = UserCollection(user_data)
assert len(users) == len(user_data)
assert users[0] == user0
assert users[1] == user1

Expand All @@ -61,34 +69,80 @@ def test_collection_validation_serialization():
for (u1, u2) in zip(users, users2):
assert u1 == u2

users3 = UserCollection.parse_obj(users.dict())
for (u1, u2) in zip(users, users3):
assert u1 == u2


def test_collection_sort():
users = UserCollection(data)
users = UserCollection(user_data)
reversed_users = users.sort(key=lambda u: u.id, reverse=True)
assert reversed_users[0] == users[1]
assert reversed_users[1] == users[0]


def test_collection_assignment_validation():
users = UserCollection()
for item in data:
for item in user_data:
users.append(User(**item))

with pytest.raises(ValidationError):
users.append(data[0]) # noqa
users.append(user_data[0]) # noqa

with pytest.raises(ValidationError):
users[0] = data[0]
users[0] = user_data[0]

weak_users = WeakUserCollection()
for d in data:
for d in user_data:
weak_users.append(d) # noqa

for user in weak_users:
assert user.__class__ is User

for (u1, u2) in zip(weak_users, data):
for (u1, u2) in zip(weak_users, user_data):
assert u1 == User(**u2)

with pytest.raises(ValidationError):
weak_users.append('user') # noqa


def test_optional_collection():
data = [1, None]
c = OptionalIntCollection()
for el in data:
c.append(el)

for (item1, item2) in zip(c, data):
assert item1 == item2


def test_union_collection():
data = [1, datetime.utcnow(), None]
c = IntOrOptionalDatetimeCollection()
for el in data:
c.append(el)

for (item1, item2) in zip(c, data):
assert item1 == item2

with pytest.raises(ValidationError):
c.append('data') # noqa


def test_collection_sequence_methods():
users = UserCollection()
for item in user_data:
users.append(User(**item))

assert len(users) == len(user_data)

user0 = User(**user_data[0])
users.insert(0, user0)
assert users[0] == user0
assert len(users) == len(user_data) + 1

users[-1] = user0
assert users[-1] == user0

users.clear()
assert len(users) == 0

0 comments on commit a1238ff

Please sign in to comment.