diff --git a/doc/references/api.yml b/doc/references/api.yml index 5672ca52dc..fb9f34afd7 100644 --- a/doc/references/api.yml +++ b/doc/references/api.yml @@ -64,7 +64,7 @@ paths: - $ref: "#/components/parameters/count" - $ref: "#/components/parameters/server_type" - $ref: "#/components/parameters/insight_types" - - $ref: "#/components/parameters/country" + - $ref: "#/components/parameters/countries" - $ref: "#/components/parameters/brands" - $ref: "#/components/parameters/value_tag" - $ref: "#/components/parameters/page" @@ -116,7 +116,7 @@ paths: - $ref: "#/components/parameters/count" - $ref: "#/components/parameters/server_type" - $ref: "#/components/parameters/insight_types" - - $ref: "#/components/parameters/country" + - $ref: "#/components/parameters/countries" - $ref: "#/components/parameters/brands" - $ref: "#/components/parameters/value_tag" - $ref: "#/components/parameters/page" @@ -156,7 +156,7 @@ paths: - $ref: "#/components/parameters/count" - $ref: "#/components/parameters/server_type" - $ref: "#/components/parameters/insight_types" - - $ref: "#/components/parameters/country" + - $ref: "#/components/parameters/countries" - $ref: "#/components/parameters/brands" - $ref: "#/components/parameters/value_tag" - $ref: "#/components/parameters/page" @@ -185,7 +185,7 @@ paths: minimum: 1 - $ref: "#/components/parameters/server_type" - $ref: "#/components/parameters/insight_type" - - $ref: "#/components/parameters/country" + - $ref: "#/components/parameters/countries" - $ref: "#/components/parameters/page" - $ref: "#/components/parameters/reserved_barcode" - $ref: "#/components/parameters/campaigns" @@ -264,7 +264,7 @@ paths: summary: Get a random insight parameters: - $ref: "#/components/parameters/insight_type" - - $ref: "#/components/parameters/country" + - $ref: "#/components/parameters/countries" - $ref: "#/components/parameters/value_tag" - $ref: "#/components/parameters/server_type" - $ref: "#/components/parameters/count" @@ -1319,13 +1319,13 @@ components: description: Filter by insight type schema: type: string - country: - name: country + countries: + name: countries in: query - description: Filter by country tag + description: Comma separated list, filter by country value (2-letter code) schema: type: string - example: en:france + example: uk brands: name: brands in: query diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 347d451317..465e5c33d1 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -16,6 +16,7 @@ from falcon.media.validators import jsonschema from falcon_cors import CORS from falcon_multipart.middleware import MultipartMiddleware +from openfoodfacts import Country from PIL import Image from sentry_sdk.integrations.falcon import FalconIntegration @@ -124,6 +125,38 @@ def get_server_type_from_req( raise falcon.HTTPBadRequest(f"invalid `server_type`: {server_type_str}") +COUNTRY_NAME_TO_ENUM = {item.value: item for item in Country} + + +def get_countries_from_req(req: falcon.Request) -> Optional[list[Country]]: + """Parse `countries` query string from request.""" + countries_str: Optional[list[str]] = req.get_param_as_list("countries") + countries: Optional[list[Country]] = None + + if countries_str is None: + # `country` parameter is deprecated + country: Optional[str] = req.get_param("country") + + if country: + if country in COUNTRY_NAME_TO_ENUM: + country = COUNTRY_NAME_TO_ENUM[country] + countries = [country] + else: + countries = None + else: + try: + countries = [ + Country.get_from_2_letter_code(country_str) + for country_str in countries_str + ] + except KeyError: + raise falcon.HTTPBadRequest( + description=f"invalid `countries` value: {countries_str}" + ) + + return countries + + def _get_skip_voted_on( auth: Optional[OFFAuthentication], device_id: str ) -> SkipVotedOn: @@ -184,13 +217,13 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): "insight_types", required=False ) barcode: Optional[str] = req.get_param("barcode") - country: Optional[str] = req.get_param("country") annotated: Optional[bool] = req.get_param_as_bool("annotated") annotation: Optional[int] = req.get_param_as_int("annotation") value_tag: str = req.get_param("value_tag") brands = req.get_param_as_list("brands") or None predictor = req.get_param("predictor") server_type = get_server_type_from_req(req) + countries: Optional[list[str]] = get_countries_from_req(req) if keep_types: # Limit the number of types to prevent slow SQL queries @@ -204,7 +237,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): get_insights, server_type=server_type, keep_types=keep_types, - country=country, + countries=countries, value_tag=value_tag, brands=brands, annotated=annotated, @@ -232,18 +265,18 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): response: JSONType = {} insight_type: Optional[str] = req.get_param("type") - country: Optional[str] = req.get_param("country") value_tag: Optional[str] = req.get_param("value_tag") count: int = req.get_param_as_int("count", min_value=1, default=25) predictor = req.get_param("predictor") server_type = get_server_type_from_req(req) + countries: Optional[list[str]] = get_countries_from_req(req) keep_types = [insight_type] if insight_type else None get_insights_ = functools.partial( get_insights, server_type=server_type, keep_types=keep_types, - country=country, + countries=countries, value_tag=value_tag, order_by="random", predictor=predictor, @@ -1295,13 +1328,13 @@ def get_questions_resource_on_get( "insight_types", required=False ) keep_types = filter_question_insight_types(keep_types) - country: Optional[str] = req.get_param("country") value_tag: str = req.get_param("value_tag") brands = req.get_param_as_list("brands") or None reserved_barcode: Optional[bool] = req.get_param_as_bool( "reserved_barcode", default=False ) server_type = get_server_type_from_req(req) + countries: Optional[list[str]] = get_countries_from_req(req) # filter by annotation campaigns campaigns: Optional[list[str]] = req.get_param_as_list("campaigns") or None @@ -1330,7 +1363,7 @@ def get_questions_resource_on_get( get_insights, server_type=server_type, keep_types=keep_types, - country=country, + countries=countries, value_tag=value_tag, brands=brands, order_by=order_by, @@ -1520,7 +1553,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): page: int = req.get_param_as_int("page", min_value=1, default=1) count: int = req.get_param_as_int("count", min_value=1, default=25) insight_type: str = req.get_param("type") - country: Optional[str] = req.get_param("country") + countries: Optional[list[str]] = get_countries_from_req(req) reserved_barcode: Optional[bool] = req.get_param_as_bool( "reserved_barcode", default=False ) @@ -1540,7 +1573,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): keep_types=[insight_type] if insight_type else None, group_by_value_tag=True, limit=count, - country=country, + countries=countries, automatically_processable=False, reserved_barcode=reserved_barcode, campaigns=campaigns, diff --git a/robotoff/app/core.py b/robotoff/app/core.py index fba69e3c43..b48f8c093a 100644 --- a/robotoff/app/core.py +++ b/robotoff/app/core.py @@ -4,6 +4,7 @@ from typing import Iterable, Literal, NamedTuple, Optional, Union import peewee +from openfoodfacts import Country from peewee import JOIN, SQL, fn from robotoff.app import events @@ -70,7 +71,7 @@ def get_insights( barcode: Optional[str] = None, server_type: ServerType = ServerType.off, keep_types: Optional[list[str]] = None, - country: Optional[str] = None, + countries: Optional[list[Country]] = None, brands: Optional[list[str]] = None, annotated: Optional[bool] = False, annotation: Optional[int] = None, @@ -97,7 +98,8 @@ def get_insights( ServerType.off :param keep_types: only keep insights that have any of the these types, defaults to None - :param country: only keep insights with this country, defaults to None + :param countries: only keep insights with `country` in this list of + countries, defaults to None :param brands: only keep insights that have any of these brands, defaults to None :param annotated: only keep annotated (True), not annotated (False @@ -155,8 +157,10 @@ def get_insights( if keep_types is not None: where_clauses.append(ProductInsight.type.in_(keep_types)) - if country is not None: - where_clauses.append(ProductInsight.countries.contains(country)) + if countries is not None: + where_clauses.append( + ProductInsight.countries.contains_any([c.value for c in countries]) + ) if brands: where_clauses.append(ProductInsight.brands.contains_any(brands)) diff --git a/tests/integration/test_core_integration.py b/tests/integration/test_core_integration.py index c3550480e7..5bd1ca9c5b 100644 --- a/tests/integration/test_core_integration.py +++ b/tests/integration/test_core_integration.py @@ -1,4 +1,5 @@ import pytest +from openfoodfacts.types import Country from robotoff.app.core import ( get_image_predictions, @@ -214,7 +215,7 @@ def test_get_unanswered_questions_list(): assert len(insight_data_items5) == 2 product6 = ProductInsightFactory(value_tag="en:raisins", countries="en:india") - insight_data6 = get_insights(country="en:india") + insight_data6 = get_insights(countries=[Country.in_]) insight_data_items6 = [item.to_dict() for item in insight_data6] assert insight_data_items6[0]["id"] == product6.id assert insight_data_items6[0]["value_tag"] == "en:raisins"