Skip to content

Commit

Permalink
fix: replace parameter country by countries
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Aug 7, 2023
1 parent 9c2d1c6 commit 56b0804
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 22 deletions.
18 changes: 9 additions & 9 deletions doc/references/api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ paths:
- $ref: "#/components/parameters/count"
- $ref: "#/components/parameters/server_type"
- $ref: "#/components/parameters/insight_types"
- $ref: "#/components/parameters/country"
- $ref: "#/components/parameters/countries"
- $ref: "#/components/parameters/brands"
- $ref: "#/components/parameters/value_tag"
- $ref: "#/components/parameters/page"
Expand Down Expand Up @@ -116,7 +116,7 @@ paths:
- $ref: "#/components/parameters/count"
- $ref: "#/components/parameters/server_type"
- $ref: "#/components/parameters/insight_types"
- $ref: "#/components/parameters/country"
- $ref: "#/components/parameters/countries"
- $ref: "#/components/parameters/brands"
- $ref: "#/components/parameters/value_tag"
- $ref: "#/components/parameters/page"
Expand Down Expand Up @@ -156,7 +156,7 @@ paths:
- $ref: "#/components/parameters/count"
- $ref: "#/components/parameters/server_type"
- $ref: "#/components/parameters/insight_types"
- $ref: "#/components/parameters/country"
- $ref: "#/components/parameters/countries"
- $ref: "#/components/parameters/brands"
- $ref: "#/components/parameters/value_tag"
- $ref: "#/components/parameters/page"
Expand Down Expand Up @@ -185,7 +185,7 @@ paths:
minimum: 1
- $ref: "#/components/parameters/server_type"
- $ref: "#/components/parameters/insight_type"
- $ref: "#/components/parameters/country"
- $ref: "#/components/parameters/countries"
- $ref: "#/components/parameters/page"
- $ref: "#/components/parameters/reserved_barcode"
- $ref: "#/components/parameters/campaigns"
Expand Down Expand Up @@ -264,7 +264,7 @@ paths:
summary: Get a random insight
parameters:
- $ref: "#/components/parameters/insight_type"
- $ref: "#/components/parameters/country"
- $ref: "#/components/parameters/countries"
- $ref: "#/components/parameters/value_tag"
- $ref: "#/components/parameters/server_type"
- $ref: "#/components/parameters/count"
Expand Down Expand Up @@ -1319,13 +1319,13 @@ components:
description: Filter by insight type
schema:
type: string
country:
name: country
countries:
name: countries
in: query
description: Filter by country tag
description: Comma separated list, filter by country value (2-letter code)
schema:
type: string
example: en:france
example: uk
brands:
name: brands
in: query
Expand Down
49 changes: 41 additions & 8 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from falcon.media.validators import jsonschema
from falcon_cors import CORS
from falcon_multipart.middleware import MultipartMiddleware
from openfoodfacts import Country
from PIL import Image
from sentry_sdk.integrations.falcon import FalconIntegration

Expand Down Expand Up @@ -124,6 +125,38 @@ def get_server_type_from_req(
raise falcon.HTTPBadRequest(f"invalid `server_type`: {server_type_str}")


COUNTRY_NAME_TO_ENUM = {item.value: item for item in Country}


def get_countries_from_req(req: falcon.Request) -> Optional[list[Country]]:
"""Parse `countries` query string from request."""
countries_str: Optional[list[str]] = req.get_param_as_list("countries")
countries: Optional[list[Country]] = None

if countries_str is None:
# `country` parameter is deprecated
country: Optional[str] = req.get_param("country")

if country:
if country in COUNTRY_NAME_TO_ENUM:
country = COUNTRY_NAME_TO_ENUM[country]
countries = [country]
else:
countries = None
else:
try:
countries = [
Country.get_from_2_letter_code(country_str)
for country_str in countries_str
]
except KeyError:
raise falcon.HTTPBadRequest(
description=f"invalid `countries` value: {countries_str}"
)

return countries


def _get_skip_voted_on(
auth: Optional[OFFAuthentication], device_id: str
) -> SkipVotedOn:
Expand Down Expand Up @@ -184,13 +217,13 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
"insight_types", required=False
)
barcode: Optional[str] = req.get_param("barcode")
country: Optional[str] = req.get_param("country")
annotated: Optional[bool] = req.get_param_as_bool("annotated")
annotation: Optional[int] = req.get_param_as_int("annotation")
value_tag: str = req.get_param("value_tag")
brands = req.get_param_as_list("brands") or None
predictor = req.get_param("predictor")
server_type = get_server_type_from_req(req)
countries: Optional[list[str]] = get_countries_from_req(req)

if keep_types:
# Limit the number of types to prevent slow SQL queries
Expand All @@ -204,7 +237,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
get_insights,
server_type=server_type,
keep_types=keep_types,
country=country,
countries=countries,
value_tag=value_tag,
brands=brands,
annotated=annotated,
Expand Down Expand Up @@ -232,18 +265,18 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
response: JSONType = {}

insight_type: Optional[str] = req.get_param("type")
country: Optional[str] = req.get_param("country")
value_tag: Optional[str] = req.get_param("value_tag")
count: int = req.get_param_as_int("count", min_value=1, default=25)
predictor = req.get_param("predictor")
server_type = get_server_type_from_req(req)
countries: Optional[list[str]] = get_countries_from_req(req)

keep_types = [insight_type] if insight_type else None
get_insights_ = functools.partial(
get_insights,
server_type=server_type,
keep_types=keep_types,
country=country,
countries=countries,
value_tag=value_tag,
order_by="random",
predictor=predictor,
Expand Down Expand Up @@ -1295,13 +1328,13 @@ def get_questions_resource_on_get(
"insight_types", required=False
)
keep_types = filter_question_insight_types(keep_types)
country: Optional[str] = req.get_param("country")
value_tag: str = req.get_param("value_tag")
brands = req.get_param_as_list("brands") or None
reserved_barcode: Optional[bool] = req.get_param_as_bool(
"reserved_barcode", default=False
)
server_type = get_server_type_from_req(req)
countries: Optional[list[str]] = get_countries_from_req(req)
# filter by annotation campaigns
campaigns: Optional[list[str]] = req.get_param_as_list("campaigns") or None

Expand Down Expand Up @@ -1330,7 +1363,7 @@ def get_questions_resource_on_get(
get_insights,
server_type=server_type,
keep_types=keep_types,
country=country,
countries=countries,
value_tag=value_tag,
brands=brands,
order_by=order_by,
Expand Down Expand Up @@ -1520,7 +1553,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
page: int = req.get_param_as_int("page", min_value=1, default=1)
count: int = req.get_param_as_int("count", min_value=1, default=25)
insight_type: str = req.get_param("type")
country: Optional[str] = req.get_param("country")
countries: Optional[list[str]] = get_countries_from_req(req)
reserved_barcode: Optional[bool] = req.get_param_as_bool(
"reserved_barcode", default=False
)
Expand All @@ -1540,7 +1573,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
keep_types=[insight_type] if insight_type else None,
group_by_value_tag=True,
limit=count,
country=country,
countries=countries,
automatically_processable=False,
reserved_barcode=reserved_barcode,
campaigns=campaigns,
Expand Down
12 changes: 8 additions & 4 deletions robotoff/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Iterable, Literal, NamedTuple, Optional, Union

import peewee
from openfoodfacts import Country
from peewee import JOIN, SQL, fn

from robotoff.app import events
Expand Down Expand Up @@ -70,7 +71,7 @@ def get_insights(
barcode: Optional[str] = None,
server_type: ServerType = ServerType.off,
keep_types: Optional[list[str]] = None,
country: Optional[str] = None,
countries: Optional[list[Country]] = None,
brands: Optional[list[str]] = None,
annotated: Optional[bool] = False,
annotation: Optional[int] = None,
Expand All @@ -97,7 +98,8 @@ def get_insights(
ServerType.off
:param keep_types: only keep insights that have any of the these types,
defaults to None
:param country: only keep insights with this country, defaults to None
:param countries: only keep insights with `country` in this list of
countries, defaults to None
:param brands: only keep insights that have any of these brands, defaults
to None
:param annotated: only keep annotated (True), not annotated (False
Expand Down Expand Up @@ -155,8 +157,10 @@ def get_insights(
if keep_types is not None:
where_clauses.append(ProductInsight.type.in_(keep_types))

if country is not None:
where_clauses.append(ProductInsight.countries.contains(country))
if countries is not None:
where_clauses.append(
ProductInsight.countries.contains_any([c.value for c in countries])
)

if brands:
where_clauses.append(ProductInsight.brands.contains_any(brands))
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_core_integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from openfoodfacts.types import Country

from robotoff.app.core import (
get_image_predictions,
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_get_unanswered_questions_list():
assert len(insight_data_items5) == 2

product6 = ProductInsightFactory(value_tag="en:raisins", countries="en:india")
insight_data6 = get_insights(country="en:india")
insight_data6 = get_insights(countries=[Country.in_])
insight_data_items6 = [item.to_dict() for item in insight_data6]
assert insight_data_items6[0]["id"] == product6.id
assert insight_data_items6[0]["value_tag"] == "en:raisins"
Expand Down

0 comments on commit 56b0804

Please sign in to comment.