From c77f049da0a3e28053acab64ee1f077bf9b1e367 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Fri, 1 Dec 2023 16:14:42 +0100 Subject: [PATCH] feat: add language predictor for product --- doc/references/api.yml | 79 +++++++++++++++++++++++++++++++++++ robotoff/app/api.py | 40 ++++++++++++++++++ tests/integration/test_api.py | 43 ++++++++++++++++++- 3 files changed, 161 insertions(+), 1 deletion(-) diff --git a/doc/references/api.yml b/doc/references/api.yml index 0a3e7cb3f4..2a71dd824b 100644 --- a/doc/references/api.yml +++ b/doc/references/api.yml @@ -1086,6 +1086,85 @@ paths: "400": description: "An HTTP 400 is returned if the provided parameters are invalid" + /predict/lang/product: + get: + tags: + - Predict + summary: Predict the languages of the product + description: | + Return the most common languages present on the product images, based on word-level + language detection from product images. + + Language detection is not performed on the fly, but is based on predictions of type + `image_lang` stored in the `prediction` table. + + parameters: + - $ref: "#/components/parameters/barcode" + - $ref: "#/components/parameters/server_type" + in: query + required: false + description: | + the minimum probability for a language to be returned + schema: + type: number + default: 0.01 + minimum: 0 + maximum: 1 + responses: + "200": + description: | + The predicted languages, sorted by descending probability. + content: + application/json: + schema: + type: object + properties: + counts: + type: array + description: | + the number of words detected for each language, over all images, + sorted by descending count + items: + type: object + properties: + lang: + type: string + description: the predicted language (2-letter code). `null` if the language could not be detected. + example: "en" + count: + type: number + description: the number of words for which this language was detected over all images + example: 10 + percent: + type: array + description: | + the percentage of words detected for each language, over all images, + sorted by descending percentage + items: + type: object + properties: + lang: + type: string + description: the predicted language (2-letter code). `null` if the language could not be detected. + example: "en" + percent: + type: number + description: the percentage of words for which the language was detected over all images + minimum: 0 + maximum: 100 + example: 80.5 + image_ids: + type: array + description: | + the IDs of the images that were used to generate the predictions + items: + type: number + example: 1 + description: the ID of an image + "400": + description: "An HTTP 400 is returned if the provided parameters are invalid" + + components: schemas: LogoANNSearchResponse: diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 15b04ca233..22feba3aa5 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -7,6 +7,8 @@ import re import tempfile import uuid +from collections import defaultdict +from pathlib import Path from typing import Literal, Optional, cast import falcon @@ -685,6 +687,43 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): self._on_get_post(req, resp) +class ProductLanguagePredictorResource: + def on_get(self, req: falcon.Request, resp: falcon.Response): + """Predict the languages displayed on the product images, using + `image_lang` predictions as input.""" + barcode = req.get_param("barcode", required=True) + server_type = get_server_type_from_req(req) + counts: dict[str, int] = defaultdict(int) + image_ids: list[int] = [] + + for prediction_data, source_image in ( + Prediction.select(Prediction.data, Prediction.source_image) + .where( + Prediction.barcode == barcode, + Prediction.server_type == server_type.name, + Prediction.type == PredictionType.image_lang.name, + ) + .tuples() + .iterator() + ): + image_ids.append(int(Path(source_image).stem)) + for lang, lang_count in prediction_data["count"].items(): + counts[lang] += lang_count + + words_n = counts.pop("words") + sorted_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True) + counts_list = [{"count": count, "lang": lang} for lang, count in sorted_counts] + percent_list = [ + {"percent": (count * 100 / words_n), "lang": lang} + for lang, count in sorted_counts + ] + resp.media = { + "counts": counts_list, + "percent": percent_list, + "image_ids": sorted(image_ids), + } + + class UpdateDatasetResource: def on_post(self, req: falcon.Request, resp: falcon.Response): """Re-import the Product Opener product dump.""" @@ -1796,6 +1835,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): api.add_route("/api/v1/predict/category", CategoryPredictorResource()) api.add_route("/api/v1/predict/ingredient_list", IngredientListPredictorResource()) api.add_route("/api/v1/predict/lang", LanguagePredictorResource()) +api.add_route("/api/v1/predict/lang/product", ProductLanguagePredictorResource()) api.add_route("/api/v1/products/dataset", UpdateDatasetResource()) api.add_route("/api/v1/webhook/product", WebhookProductResource()) api.add_route("/api/v1/images", ImageCollection()) diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 18999eb5f4..66a351addc 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -10,7 +10,7 @@ from robotoff.models import AnnotationVote, LogoAnnotation, ProductInsight from robotoff.off import OFFAuthentication from robotoff.prediction.langid import LanguagePrediction -from robotoff.types import ProductIdentifier, ServerType +from robotoff.types import PredictionType, ProductIdentifier, ServerType from .models_utils import ( AnnotationVoteFactory, @@ -1250,3 +1250,44 @@ def test_predict_lang(client, mocker): ) assert result.status_code == 200 assert result.json == {"predictions": expected_predictions} + + +def test_predict_product_language(client, peewee_db): + barcode = "123456789" + prediction_data_1 = {"count": {"en": 10, "fr": 5, "es": 3, "words": 18}} + prediction_data_2 = {"count": {"en": 2, "fr": 3, "words": 5}} + + with peewee_db: + PredictionFactory( + barcode=barcode, + server_type=ServerType.off.name, + type=PredictionType.image_lang.name, + data=prediction_data_1, + source_image="/123/45678/2.jpg", + ) + PredictionFactory( + barcode=barcode, + server_type=ServerType.off.name, + type=PredictionType.image_lang.name, + data=prediction_data_2, + source_image="/123/45678/4.jpg", + ) + + # Send GET request to the API endpoint + result = client.simulate_get(f"/api/v1/predict/lang/product?barcode={barcode}") + + # Assert the response + assert result.status_code == 200 + assert result.json == { + "counts": [ + {"count": 12, "lang": "en"}, + {"count": 8, "lang": "fr"}, + {"count": 3, "lang": "es"}, + ], + "percent": [ + {"percent": 12 * 100 / 23, "lang": "en"}, + {"percent": 8 * 100 / 23, "lang": "fr"}, + {"percent": 3 * 100 / 23, "lang": "es"}, + ], + "image_ids": [2, 4], + }