Skip to content

Commit

Permalink
feat: add endpoint to predict language
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Nov 30, 2023
1 parent 79b300c commit 686e180
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 5 deletions.
58 changes: 57 additions & 1 deletion doc/references/api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,63 @@ paths:

"400":
description: "An HTTP 400 is returned if the provided parameters are invalid"


/predict/lang:
get:
tags:
- Predict
summary: Predict the language of a text
parameters:
- name: text
in: query
required: true
description: The text whose language should be predicted
schema:
type: string
example: "hello world"
- name: k
in: query
required: false
description: |
the number of predictions to return
schema:
type: integer
default: 10
minimum: 1
- name: threshold
in: query
required: false
description: |
the minimum probability for a language to be returned
schema:
type: number
default: 0.01
minimum: 0
maximum: 1
responses:
"200":
description: the predicted languages
content:
application/json:
schema:
type: object
properties:
predictions:
type: array
description: a list of predicted languages, sorted by descending probability
items:
type: object
properties:
lang:
type: string
description: the predicted language (2-letter code)
example: "en"
confidence:
type: number
description: the probability of the predicted language
example: 0.9
"400":
description: "An HTTP 400 is returned if the provided parameters are invalid"

components:
schemas:
Expand Down
28 changes: 25 additions & 3 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import tempfile
import uuid
from typing import Literal, Optional
from typing import Literal, Optional, cast

import falcon
import orjson
Expand Down Expand Up @@ -36,6 +36,7 @@
get_predictions,
save_annotation,
update_logo_annotations,
validate_params,
)
from robotoff.app.middleware import DBConnectionMiddleware
from robotoff.elasticsearch import get_es_client
Expand Down Expand Up @@ -67,6 +68,7 @@
)
from robotoff.prediction import ingredient_list
from robotoff.prediction.category import predict_category
from robotoff.prediction.langid import predict_lang
from robotoff.prediction.object_detection import ObjectDetectionModelRegistry
from robotoff.products import get_image_id, get_product, get_product_dataset_etag
from robotoff.taxonomy import is_prefixed_value, match_taxonomized_value
Expand Down Expand Up @@ -98,8 +100,6 @@

settings.init_sentry(integrations=[FalconIntegration()])

es_client = get_es_client()

TRANSLATION_STORE = TranslationStore()
TRANSLATION_STORE.load()

Expand Down Expand Up @@ -650,6 +650,26 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
resp.media = dataclasses.asdict(output)


class LanguagePredictorResource:
    """Falcon resource exposing language prediction for a text."""

    def on_get(self, req: falcon.Request, resp: falcon.Response):
        """Predict language of a text.

        The `text`, `k` and `threshold` query parameters are read from the
        request and validated against
        `schema.LanguagePredictorResourceParams`. The predictions returned by
        `predict_lang` are serialized as dicts under the `predictions` key of
        the response body.
        """
        # Collect the raw (string or None) query parameters, in the same
        # order as they are declared on the schema.
        raw_params = {
            name: req.get_param(name) for name in ("text", "k", "threshold")
        }
        # `validate_params` is annotated as returning a plain BaseModel:
        # narrow it so that attribute access below type-checks.
        validated = cast(
            schema.LanguagePredictorResourceParams,
            validate_params(raw_params, schema.LanguagePredictorResourceParams),
        )
        predictions = predict_lang(validated.text, validated.k, validated.threshold)
        serialized = list(map(dataclasses.asdict, predictions))
        resp.media = {"predictions": serialized}


class UpdateDatasetResource:
def on_post(self, req: falcon.Request, resp: falcon.Response):
"""Re-import the Product Opener product dump."""
Expand Down Expand Up @@ -1150,6 +1170,7 @@ def on_get(
"""
count = req.get_param_as_int("count", min_value=1, max_value=500, default=100)
server_type = get_server_type_from_req(req)
es_client = get_es_client()

if logo_id is None:
logo_embeddings = list(
Expand Down Expand Up @@ -1759,6 +1780,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
api.add_route("/api/v1/predict/ocr_prediction", OCRPredictionPredictorResource())
api.add_route("/api/v1/predict/category", CategoryPredictorResource())
api.add_route("/api/v1/predict/ingredient_list", IngredientListPredictorResource())
api.add_route("/api/v1/predict/lang", LanguagePredictorResource())
api.add_route("/api/v1/products/dataset", UpdateDatasetResource())
api.add_route("/api/v1/webhook/product", WebhookProductResource())
api.add_route("/api/v1/images", ImageCollection())
Expand Down
24 changes: 23 additions & 1 deletion robotoff/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from enum import Enum
from typing import Iterable, Literal, NamedTuple, Optional, Union

import falcon
import peewee
from openfoodfacts.types import COUNTRY_CODE_TO_NAME, Country
from peewee import JOIN, SQL, fn
from pydantic import BaseModel, ValidationError

from robotoff.app import events
from robotoff.insights.annotate import (
Expand All @@ -27,7 +29,7 @@
)
from robotoff.off import OFFAuthentication
from robotoff.taxonomy import match_taxonomized_value
from robotoff.types import InsightAnnotation, ServerType
from robotoff.types import InsightAnnotation, JSONType, ServerType
from robotoff.utils import get_logger
from robotoff.utils.text import get_tag

Expand Down Expand Up @@ -580,3 +582,23 @@ def filter_question_insight_types(keep_types: Optional[list[str]]):
set(keep_types) & set(QuestionFormatterFactory.get_available_types())
)
return keep_types


def validate_params(params: JSONType, schema: type[BaseModel]) -> BaseModel:
    """Validate the parameters passed to a Falcon resource.

    Either returns a validated params object or raises a
    falcon.HTTPBadRequest.

    Parameters whose value is `None` are dropped before validation, so an
    absent parameter is handled by the schema as a missing field rather
    than as an explicit `None` value.

    :param params: the input parameters to validate, as a dict
    :param schema: the pydantic model class to use for validation
    :return: the validated parameters, as an instance of `schema`
    :raises falcon.HTTPBadRequest: if the parameters are invalid; the
        description lists the validation errors reported by pydantic
    """
    # Remove None values from the params dict
    provided = {key: value for key, value in params.items() if value is not None}
    try:
        return schema.model_validate(provided)
    except ValidationError as e:
        # Documentation URLs are excluded to keep the error message short
        errors = e.errors(include_url=False)
        plural = "s" if len(errors) > 1 else ""
        description = f"{len(errors)} validation error{plural}: {errors}"
        # Chain explicitly so the original validation error is kept as the
        # cause of the HTTP error
        raise falcon.HTTPBadRequest(description=description) from e
17 changes: 17 additions & 0 deletions robotoff/app/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from typing import Annotated

from pydantic import BaseModel, Field

from robotoff.types import JSONType, NeuralCategoryClassifierModel, ServerType

IMAGE_PREDICTION_IMPORTER_SCHEMA: JSONType = {
Expand Down Expand Up @@ -172,3 +176,16 @@
},
"required": ["annotations"],
}


class LanguagePredictorResourceParams(BaseModel):
    # Validation schema for the parameters of the language predictor
    # resource.

    # Text to analyse: required, at least one character long.
    text: str = Field(
        ..., min_length=1, description="the text to predict language of"
    )
    # Number of predictions to return: defaults to 10, must be >= 1.
    k: int = Field(
        default=10, ge=1, description="the number of predictions to return"
    )
    # Minimum confidence threshold: defaults to 0.01, must be within [0, 1].
    threshold: float = Field(
        default=0.01, ge=0, le=1, description="the minimum confidence threshold"
    )
45 changes: 45 additions & 0 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from robotoff.app.api import api
from robotoff.models import AnnotationVote, LogoAnnotation, ProductInsight
from robotoff.off import OFFAuthentication
from robotoff.prediction.langid import LanguagePrediction
from robotoff.types import ProductIdentifier, ServerType

from .models_utils import (
Expand Down Expand Up @@ -1205,3 +1206,47 @@ def test_logo_annotation_collection_pagination(client, peewee_db):
"truffle cake-00",
"truffle cake-01",
]


def test_predict_lang_invalid_params(client, mocker):
    """Check that invalid query parameters produce an HTTP 400 response."""
    url = "/api/v1/predict/lang"
    mocker.patch("robotoff.app.api.predict_lang", return_value=[])

    # Case 1: the `text` parameter is not provided
    missing_text_body = {
        "description": "1 validation error: [{'type': 'missing', 'loc': ('text',), 'msg': 'Field required', 'input': {'k': '2'}}]",
        "title": "400 Bad Request",
    }
    response = client.simulate_get(url, params={"k": 2})
    assert response.status_code == 400
    assert response.json == missing_text_body

    # Case 2: `k` is not an integer and `threshold` is greater than 1
    invalid_values_body = {
        "description": "2 validation errors: [{'type': 'int_parsing', 'loc': ('k',), 'msg': 'Input should be a valid integer, unable to parse string as an integer', 'input': 'invalid'}, {'type': 'less_than_equal', 'loc': ('threshold',), 'msg': 'Input should be less than or equal to 1', 'input': '1.05', 'ctx': {'le': 1.0}}]",
        "title": "400 Bad Request",
    }
    query = {"text": "test", "k": "invalid", "threshold": 1.05}
    response = client.simulate_get(url, params=query)
    assert response.status_code == 400
    assert response.json == invalid_values_body


def test_predict_lang(client, mocker):
    """Check that predictions are returned as dicts under `predictions`."""
    # Replace the predictor with fixed predictions
    mocked_predictions = [
        LanguagePrediction("en", 0.9),
        LanguagePrediction("fr", 0.1),
    ]
    mocker.patch("robotoff.app.api.predict_lang", return_value=mocked_predictions)

    query = {"text": "hello", "k": 2}
    response = client.simulate_get("/api/v1/predict/lang", params=query)

    assert response.status_code == 200
    assert response.json == {
        "predictions": [
            {"lang": "en", "confidence": 0.9},
            {"lang": "fr", "confidence": 0.1},
        ]
    }

0 comments on commit 686e180

Please sign in to comment.