From e9e0a8ce234e407980ddf27f1ae4d9d1f14ce31d Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Mon, 25 Mar 2019 16:46:54 +0100 Subject: [PATCH 1/5] add unique field prevalidation --- orm/models.py | 29 ++++++++++++++++++++++++----- tests/test_models.py | 19 +++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/orm/models.py b/orm/models.py index e50e77f..77e0105 100644 --- a/orm/models.py +++ b/orm/models.py @@ -110,13 +110,13 @@ def filter(self, **kwargs): model_cls = self.model_cls if related_parts: - # Add any implied select_related + #  Add any implied select_related related_str = "__".join(related_parts) if related_str not in select_related: select_related.append(related_str) # Walk the relationships to the actual model class - # against which the comparison is being made. + #  against which the comparison is being made. for part in related_parts: model_cls = model_cls.fields[part].to @@ -162,7 +162,10 @@ async def all(self, **kwargs): expr = self.build_select_expression() rows = await self.database.fetch_all(expr) - return [self.model_cls.from_row(row, select_related=self._select_related) for row in rows] + return [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] async def get(self, **kwargs): if kwargs: @@ -186,6 +189,22 @@ async def create(self, **kwargs): ) kwargs = validator.validate(kwargs) + # Verify constraints on unique fields. + unique = {key: kwargs[key] for key, value in fields.items() if value.unique} + unique_error_messages = [] + for key, value in unique.items(): + expr = self.table.select() + column = getattr(self.table.c, key) + expr = expr.where(column == value) + row = await self.database.fetch_one(query=expr) + if row is not None: + text = f"{self.table.name} with {key}='{value}' already exists" + message = typesystem.Message(text=text, code="unique_exists") + unique_error_messages.append(message) + + if unique_error_messages: + raise typesystem.ValidationError(messages=unique_error_messages) + # Build the insert expression. expr = self.table.insert() expr = expr.values(**kwargs) @@ -261,8 +280,8 @@ def from_row(cls, row, select_related=[]): # Instantiate any child instances first. for related in select_related: - if '__' in related: - first_part, remainder = related.split('__', 1) + if "__" in related: + first_part, remainder = related.split("__", 1) model_cls = cls.fields[first_part].to item[first_part] = model_cls.from_row(row, select_related=[remainder]) else: diff --git a/tests/test_models.py b/tests/test_models.py index 1df30cf..42f3e50 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,6 +6,7 @@ import databases import orm +import typesystem DATABASE_URL = "sqlite:///test.db" database = databases.Database(DATABASE_URL, force_rollback=True) @@ -32,6 +33,15 @@ class Product(orm.Model): in_stock = orm.Boolean(default=False) +class Post(orm.Model): + __tablename__ = "post" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + slug = orm.String(max_length=100, unique=True) + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -137,3 +147,12 @@ async def test_model_filter(): products = await Product.objects.all(name__icontains="T") assert len(products) == 2 + + +@async_adapter +async def test_validate_unique(): + slug = "hello-world" + async with database: + await Post.objects.create(slug=slug) + with pytest.raises(typesystem.ValidationError): + await Post.objects.create(slug=slug) From 1a0af398e59a91d9bd87f6d89e4f4f72e2890d59 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Tue, 26 Mar 2019 18:26:40 +0100 Subject: [PATCH 2/5] add unique prevalidation in update + check for pk --- orm/models.py | 84 +++++++++++++++++++++++++++++++++----------- tests/test_models.py | 8 +++++ 2 files changed, 71 insertions(+), 21 deletions(-) diff --git a/orm/models.py b/orm/models.py index 77e0105..5f1457c 100644 --- a/orm/models.py +++ b/orm/models.py @@ -48,6 +48,55 @@ def __new__( return new_model +class ModelObject(typesystem.Object): + # NOTE: to be updated if/when typesystem handles validation of unique fieldss. + + def __init__(self, *args, queryset: "QuerySet" = None, **kwargs): + super().__init__(*args, **kwargs) + self.queryset = queryset + + async def validate_unique(self, value: dict): + if self.queryset is None: + return + + unique_properties = { + key: value[key] for key, val in self.properties.items() if val.unique + } + + error_messages = [] + + for key, val in unique_properties.items(): + expr = self.queryset.table.select() + column = getattr(self.queryset.table.c, key) + expr = expr.where(column == val) + row = await self.queryset.database.fetch_one(query=expr) + if row is not None: + text = f"{self.queryset.table.name} with {key}='{val}' already exists" + message = typesystem.Message(text=text, code="unique_exists") + error_messages.append(message) + + if error_messages: + raise typesystem.ValidationError(messages=error_messages) + + async def validate(self, value: dict, strict: bool = False) -> typing.Any: + messages = [] + + try: + validated = super().validate(value, strict=strict) + except typesystem.ValidationError as exc: + messages.extend(exc.messages()) + + try: + await self.validate_unique(value) + except typesystem.ValidationError as exc: + messages.extend(exc.messages()) + + if messages: + raise typesystem.ValidationError(messages=messages) + + return validated + + class QuerySet: def __init__(self, model_cls=None, filter_clauses=None, select_related=None): self.model_cls = model_cls @@ -184,26 +233,13 @@ async def create(self, **kwargs): # Validate the keyword arguments. fields = self.model_cls.fields required = [key for key, value in fields.items() if not value.has_default()] - validator = typesystem.Object( - properties=fields, required=required, additional_properties=False + validator = ModelObject( + properties=fields, + required=required, + additional_properties=False, + queryset=self, ) - kwargs = validator.validate(kwargs) - - # Verify constraints on unique fields. - unique = {key: kwargs[key] for key, value in fields.items() if value.unique} - unique_error_messages = [] - for key, value in unique.items(): - expr = self.table.select() - column = getattr(self.table.c, key) - expr = expr.where(column == value) - row = await self.database.fetch_one(query=expr) - if row is not None: - text = f"{self.table.name} with {key}='{value}' already exists" - message = typesystem.Message(text=text, code="unique_exists") - unique_error_messages.append(message) - - if unique_error_messages: - raise typesystem.ValidationError(messages=unique_error_messages) + kwargs = await validator.validate(kwargs) # Build the insert expression. expr = self.table.insert() @@ -234,10 +270,16 @@ def pk(self, value): setattr(self, self.__pkname__, value) async def update(self, **kwargs): + # Prevent primary key from being updated. + if "pk" in kwargs or self.__pkname__ in kwargs: + raise ValueError( + f"the primary key of a model instance cannot be updated" + ) + # Validate the keyword arguments. fields = {key: field for key, field in self.fields.items() if key in kwargs} - validator = typesystem.Object(properties=fields) - kwargs = validator.validate(kwargs) + validator = ModelObject(properties=fields, queryset=self.objects) + kwargs = await validator.validate(kwargs) # Build the update expression. pk_column = getattr(self.__table__.c, self.__pkname__) diff --git a/tests/test_models.py b/tests/test_models.py index 42f3e50..5552bef 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -100,6 +100,9 @@ async def test_model_crud(): assert user.pk is not None assert users == [user] + with pytest.raises(ValueError): + await user.update(pk=42) + await user.delete() users = await User.objects.all() assert users == [] @@ -154,5 +157,10 @@ async def test_validate_unique(): slug = "hello-world" async with database: await Post.objects.create(slug=slug) + with pytest.raises(typesystem.ValidationError): await Post.objects.create(slug=slug) + + other = await Post.objects.create(slug="world-hello") + with pytest.raises(typesystem.ValidationError): + await other.update(slug=slug) From 6cdbdc81ae00546caaca50b5920babdb9e1b030f Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Tue, 26 Mar 2019 18:30:35 +0100 Subject: [PATCH 3/5] fix test coverage --- orm/models.py | 5 +---- tests/test_models.py | 3 +++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/orm/models.py b/orm/models.py index 5f1457c..2729242 100644 --- a/orm/models.py +++ b/orm/models.py @@ -51,14 +51,11 @@ def __new__( class ModelObject(typesystem.Object): # NOTE: to be updated if/when typesystem handles validation of unique fieldss. - def __init__(self, *args, queryset: "QuerySet" = None, **kwargs): + def __init__(self, *args, queryset: "QuerySet", **kwargs): super().__init__(*args, **kwargs) self.queryset = queryset async def validate_unique(self, value: dict): - if self.queryset is None: - return - unique_properties = { key: value[key] for key, val in self.properties.items() if val.unique } diff --git a/tests/test_models.py b/tests/test_models.py index 5552bef..5f41bea 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -91,6 +91,9 @@ async def test_model_crud(): assert user.pk is not None assert users == [user] + with pytest.raises(typesystem.ValidationError): + await User.objects.create(foo="is not a User field") + lookup = await User.objects.get() assert lookup == user From d7dacf19963600f27da0bf3a50f267343461dcdd Mon Sep 17 00:00:00 2001 From: Florimond Manca Date: Tue, 26 Mar 2019 18:31:07 +0100 Subject: [PATCH 4/5] Fix typo --- orm/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orm/models.py b/orm/models.py index 2729242..50fa8a6 100644 --- a/orm/models.py +++ b/orm/models.py @@ -49,7 +49,7 @@ def __new__( class ModelObject(typesystem.Object): - # NOTE: to be updated if/when typesystem handles validation of unique fieldss. + # NOTE: to be updated if/when typesystem handles validation of unique fields. def __init__(self, *args, queryset: "QuerySet", **kwargs): super().__init__(*args, **kwargs) From a0a16030bbaa9c2dc9aac05bd21a3528833e26b2 Mon Sep 17 00:00:00 2001 From: Florimond Manca Date: Tue, 26 Mar 2019 18:32:05 +0100 Subject: [PATCH 5/5] fix double spaces --- orm/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orm/models.py b/orm/models.py index 50fa8a6..60ad47c 100644 --- a/orm/models.py +++ b/orm/models.py @@ -162,7 +162,7 @@ def filter(self, **kwargs): select_related.append(related_str) # Walk the relationships to the actual model class - #  against which the comparison is being made. + # against which the comparison is being made. for part in related_parts: model_cls = model_cls.fields[part].to @@ -336,7 +336,7 @@ def from_row(cls, row, select_related=[]): def __setattr__(self, key, value): if key in self.fields: - #  Setting a relationship to a raw pk value should set a + # Setting a relationship to a raw pk value should set a # fully-fledged relationship instance, with just the pk loaded. value = self.fields[key].expand_relationship(value) super().__setattr__(key, value)