Skip to content

Commit

Permalink
refactor(proofs): reorganise schema inheritance (#343)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn authored Jun 27, 2024
1 parent 64aac06 commit 7e17a68
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 70 deletions.
16 changes: 8 additions & 8 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from app.schemas import (
LocationCreate,
LocationFilter,
PriceCreate,
PriceCreateWithValidation,
PriceFilter,
PriceUpdate,
PriceUpdateWithValidation,
ProductCreate,
ProductFilter,
ProductFull,
ProofBasicUpdatableFields,
ProofFilter,
ProofUpdate,
UserCreate,
)

Expand Down Expand Up @@ -300,7 +300,7 @@ def get_price_by_id(db: Session, id: int) -> Price | None:


def create_price(
db: Session, price: PriceCreate, user: UserCreate, source: str = None
db: Session, price: PriceCreateWithValidation, user: UserCreate, source: str = None
) -> Price:
db_price = Price(**price.model_dump(), owner=user.user_id, source=source)
db.add(db_price)
Expand Down Expand Up @@ -345,7 +345,9 @@ def delete_price(db: Session, db_price: Price) -> bool:
return True


def update_price(db: Session, price: Price, new_values: PriceUpdate) -> Price:
def update_price(
db: Session, price: Price, new_values: PriceUpdateWithValidation
) -> Price:
new_values_cleaned = new_values.model_dump(exclude_unset=True)
for key in new_values_cleaned:
setattr(price, key, new_values_cleaned[key])
Expand Down Expand Up @@ -489,9 +491,7 @@ def set_proof_location(db: Session, proof: Proof, location: Location) -> Proof:
return proof


def update_proof(
db: Session, proof: Proof, new_values: ProofBasicUpdatableFields
) -> Proof:
def update_proof(db: Session, proof: Proof, new_values: ProofUpdate) -> Proof:
new_values_cleaned = new_values.model_dump(exclude_unset=True)
for key in new_values_cleaned:
setattr(proof, key, new_values_cleaned[key])
Expand Down
4 changes: 2 additions & 2 deletions app/routers/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_prices(
status_code=status.HTTP_201_CREATED,
)
def create_price(
price: schemas.PriceCreate,
price: schemas.PriceCreateWithValidation,
background_tasks: BackgroundTasks,
current_user: schemas.UserCreate = Depends(get_current_user),
app_name: str | None = None,
Expand Down Expand Up @@ -77,7 +77,7 @@ def create_price(
)
def update_price(
price_id: int,
price_new_values: schemas.PriceUpdate,
price_new_values: schemas.PriceUpdateWithValidation,
current_user: schemas.UserCreate = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Price:
Expand Down
2 changes: 1 addition & 1 deletion app/routers/proofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_user_proof_by_id(
)
def update_proof(
proof_id: int,
proof_new_values: schemas.ProofBasicUpdatableFields,
proof_new_values: schemas.ProofUpdate,
current_user: schemas.UserCreate = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Proof:
Expand Down
90 changes: 52 additions & 38 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,28 @@ class LocationFull(LocationCreate):

# Proof
# ------------------------------------------------------------------------------
class ProofFull(BaseModel):
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
class ProofBase(BaseModel):
model_config = ConfigDict(
from_attributes=True, arbitrary_types_allowed=True, extra="forbid"
)

id: int
type: ProofTypeEnum | None = None
currency: CurrencyEnum | None = Field(
description="currency of the price, as a string. "
"The currency must be a valid currency code. "
"See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.",
examples=["EUR", "USD"],
)
date: datetime.date | None = Field(
description="date of the proof.", examples=["2024-01-01"]
)


class ProofCreate(ProofBase):
# file_path is str | null because we can mask the file path in the response
# if the proof is not public
file_path: str | None
mimetype: str
type: ProofTypeEnum | None = None
price_count: int = Field(
description="number of prices for this proof.", examples=[15], default=0
)
location_osm_id: int | None = Field(
gt=0,
description="ID of the location in OpenStreetMap: the store where the product was bought.",
Expand All @@ -223,17 +233,20 @@ class ProofFull(BaseModel):
"information about the store using the ID.",
examples=["NODE", "WAY", "RELATION"],
)
date: datetime.date | None = Field(
description="date of the proof.", examples=["2024-01-01"]
)
currency: CurrencyEnum | None = Field(
description="currency of the price, as a string. "
"The currency must be a valid currency code. "
"See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.",
examples=["EUR", "USD"],
owner: str


@partial_model
class ProofUpdate(ProofBase):
pass


class ProofFull(ProofCreate):
id: int
price_count: int = Field(
description="number of prices for this proof.", examples=[15], default=0
)
location_id: int | None
owner: str
# source: str | None = Field(
# description="Source (App name)",
# examples=["web app", "mobile app"],
Expand All @@ -249,15 +262,6 @@ class ProofFullWithRelations(ProofFull):
location: LocationFull | None


class ProofBasicUpdatableFields(BaseModel):
type: ProofTypeEnum | None = None
currency: CurrencyEnum | None = None
date: datetime.date | None = None

class Config:
extra = "forbid"


# Price
# ------------------------------------------------------------------------------
class PriceBase(BaseModel):
Expand All @@ -278,7 +282,7 @@ class PriceBase(BaseModel):
price_without_discount: float | None = Field(
default=None,
description="price of the product without discount, without its currency, taxes included. "
"If the product is not discounted, this field must be null. ",
"If the product is not discounted, this field must be null.",
examples=[2.99],
)
price_per: PricePerEnum | None = Field(
Expand All @@ -300,6 +304,14 @@ class PriceBase(BaseModel):
description="date when the product was bought.", examples=["2024-01-01"]
)


class PriceBaseWithValidation(PriceBase):
"""A version of `PriceBase` with validations.
These validations are not done in the `PriceCreate` or `PriceUpdate` model
because they are time-consuming and only necessary when creating or
updating a price from the API.
"""

@model_validator(mode="after")
def check_price_discount(self): # type: ignore
"""
Expand Down Expand Up @@ -395,6 +407,19 @@ class PriceCreate(PriceBase):
examples=[15],
)


class PriceCreateWithValidation(PriceBaseWithValidation, PriceCreate):
@field_validator("category_tag")
def category_tag_is_valid(cls, v: str | None) -> str | None:
if v is not None:
v = v.lower()
category_taxonomy = get_taxonomy("category")
if v not in category_taxonomy:
raise ValueError(
f"Invalid category tag: category '{v}' does not exist in the taxonomy"
)
return v

@field_validator("labels_tags")
def labels_tags_is_valid(cls, v: list[str] | None) -> list[str] | None:
if v is not None:
Expand Down Expand Up @@ -423,17 +448,6 @@ def origins_tags_is_valid(cls, v: list[str] | None) -> list[str] | None:
)
return v

@field_validator("category_tag")
def category_tag_is_valid(cls, v: str | None) -> str | None:
if v is not None:
v = v.lower()
category_taxonomy = get_taxonomy("category")
if v not in category_taxonomy:
raise ValueError(
f"Invalid category tag: category '{v}' does not exist in the taxonomy"
)
return v

@model_validator(mode="after")
def product_code_and_category_tag_are_exclusive(self): # type: ignore
"""Validator that checks that `product_code` and `category_tag` are
Expand Down Expand Up @@ -464,7 +478,7 @@ def set_price_per_to_null_if_barcode(self): # type: ignore


@partial_model
class PriceUpdate(PriceBase):
class PriceUpdateWithValidation(PriceBaseWithValidation):
pass


Expand Down
32 changes: 21 additions & 11 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from app.api import app
from app.db import Base, engine, get_db, session
from app.models import Session as SessionModel
from app.schemas import LocationFull, PriceCreate, ProductFull, ProofFilter, UserCreate
from app.schemas import (
LocationFull,
PriceCreateWithValidation,
ProductFull,
ProofFilter,
UserCreate,
)

Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
Expand Down Expand Up @@ -125,7 +131,7 @@ def override_get_db():
created=datetime.datetime.now(),
updated=datetime.datetime.now(),
)
PRICE_1 = PriceCreate(
PRICE_1 = PriceCreateWithValidation(
product_code="8001505005707",
product_name="PATE NOCCIOLATA BIO 700G",
# category="en:tomatoes",
Expand All @@ -137,7 +143,7 @@ def override_get_db():
location_osm_type="NODE",
date="2023-10-31",
)
PRICE_2 = PriceCreate(
PRICE_2 = PriceCreateWithValidation(
product_code="8001505005707",
product_name="PATE NOCCIOLATA BIO 700G",
price=2.5,
Expand All @@ -148,7 +154,7 @@ def override_get_db():
location_osm_type="NODE",
date="2023-10-31",
)
PRICE_3 = PriceCreate(
PRICE_3 = PriceCreateWithValidation(
product_code="8001505005707",
product_name="PATE NOCCIOLATA BIO 700G",
price=2.5,
Expand Down Expand Up @@ -1093,13 +1099,17 @@ def test_update_proof(
assert response.json()["date"] == "2024-01-01"

# with authentication and proof owner but extra fields
PROOF_UPDATE_PARTIAL_WRONG = {**PROOF_UPDATE_PARTIAL, "owner": 1}
response = client.patch(
f"/api/v1/proofs/{proof.id}",
headers={"Authorization": f"Bearer {user_session.token}"},
json=jsonable_encoder(PROOF_UPDATE_PARTIAL_WRONG),
)
assert response.status_code == 422
PROOF_UPDATE_PARTIAL_WRONG_LIST = [
{**PROOF_UPDATE_PARTIAL, "owner": 1}, # extra field
{**PROOF_UPDATE_PARTIAL, "type": "TEST"}, # wrong type
]
for PROOF_UPDATE_PARTIAL_WRONG in PROOF_UPDATE_PARTIAL_WRONG_LIST:
response = client.patch(
f"/api/v1/proofs/{proof.id}",
headers={"Authorization": f"Bearer {user_session.token}"},
json=jsonable_encoder(PROOF_UPDATE_PARTIAL_WRONG),
)
assert response.status_code == 422


def test_update_proof_moderator(
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import pydantic
import pytest

from app.schemas import CurrencyEnum, LocationOSMEnum, PriceCreate
from app.schemas import CurrencyEnum, LocationOSMEnum, PriceCreateWithValidation


class TestPriceCreate:
class TestPriceCreateWithValidation:
def test_simple_price_with_barcode(self):
price = PriceCreate(
price = PriceCreateWithValidation(
product_code="5414661000456",
location_osm_id=123,
location_osm_type=LocationOSMEnum.NODE,
Expand All @@ -24,7 +24,7 @@ def test_simple_price_with_barcode(self):
assert price.date == datetime.date.fromisoformat("2021-01-01")

def test_simple_price_with_category(self):
price = PriceCreate(
price = PriceCreateWithValidation(
category_tag="en:Fresh-apricots",
labels_tags=["en:Organic", "fr:AB-agriculture-biologique"],
origins_tags=["en:California", "en:Sweden"],
Expand All @@ -40,7 +40,7 @@ def test_simple_price_with_category(self):

def test_simple_price_with_invalid_taxonomized_values(self):
with pytest.raises(pydantic.ValidationError, match="Invalid category tag"):
PriceCreate(
PriceCreateWithValidation(
category_tag="en:unknown-category",
location_osm_id=123,
location_osm_type=LocationOSMEnum.NODE,
Expand All @@ -50,7 +50,7 @@ def test_simple_price_with_invalid_taxonomized_values(self):
)

with pytest.raises(pydantic.ValidationError, match="Invalid label tag"):
PriceCreate(
PriceCreateWithValidation(
category_tag="en:carrots",
labels_tags=["en:invalid"],
location_osm_id=123,
Expand All @@ -61,7 +61,7 @@ def test_simple_price_with_invalid_taxonomized_values(self):
)

with pytest.raises(pydantic.ValidationError, match="Invalid origin tag"):
PriceCreate(
PriceCreateWithValidation(
category_tag="en:carrots",
origins_tags=["en:invalid"],
location_osm_id=123,
Expand All @@ -76,7 +76,7 @@ def test_simple_price_with_product_code_and_labels_tags_raise(self):
pydantic.ValidationError,
match="`labels_tags` can only be set for products without barcode",
):
PriceCreate(
PriceCreateWithValidation(
product_code="5414661000456",
labels_tags=["en:Organic", "fr:AB-agriculture-biologique"],
location_osm_id=123,
Expand All @@ -91,7 +91,7 @@ def test_price_discount_raise(self):
pydantic.ValidationError,
match="`price_is_discounted` must be true if `price_without_discount` is filled",
):
PriceCreate(
PriceCreateWithValidation(
product_code="5414661000456",
location_osm_id=123,
location_osm_type=LocationOSMEnum.NODE,
Expand All @@ -105,7 +105,7 @@ def test_price_discount_raise(self):
pydantic.ValidationError,
match="`price_without_discount` must be greater than `price`",
):
PriceCreate(
PriceCreateWithValidation(
product_code="5414661000456",
location_osm_id=123,
location_osm_type=LocationOSMEnum.NODE,
Expand Down

0 comments on commit 7e17a68

Please sign in to comment.