From 8baa565c740ea94290a64e38579026e930568d39 Mon Sep 17 00:00:00 2001 From: Pierre Slamich Date: Mon, 6 May 2019 14:15:57 +0200 Subject: [PATCH 1/4] Logo annotations for packaging --- data/ocr/logo_annotation_packagings.txt | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 data/ocr/logo_annotation_packagings.txt diff --git a/data/ocr/logo_annotation_packagings.txt b/data/ocr/logo_annotation_packagings.txt new file mode 100644 index 0000000000..65d4ab2daa --- /dev/null +++ b/data/ocr/logo_annotation_packagings.txt @@ -0,0 +1,2 @@ +alu 41||en:Aluminium +Acier Recyclable||en:Recyclable steel From ebba5c92a900d72339b84c08073ccfea5dd9ce71 Mon Sep 17 00:00:00 2001 From: Pierre Slamich Date: Fri, 24 Jul 2020 11:59:33 +0200 Subject: [PATCH 2/4] Update logo_annotation_packagings.txt --- data/ocr/logo_annotation_packagings.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/ocr/logo_annotation_packagings.txt b/data/ocr/logo_annotation_packagings.txt index 65d4ab2daa..574c6e0c85 100644 --- a/data/ocr/logo_annotation_packagings.txt +++ b/data/ocr/logo_annotation_packagings.txt @@ -1,2 +1,2 @@ -alu 41||en:Aluminium -Acier Recyclable||en:Recyclable steel +en:Aluminium||alu 41 +en:Recyclable steel||Acier Recyclable From 9dfecaaf89996a8a490be640a452417a44bb30c9 Mon Sep 17 00:00:00 2001 From: Pierre Slamich Date: Thu, 3 Feb 2022 11:02:06 +0100 Subject: [PATCH 3/4] Update logo_annotation_packagings.txt --- data/ocr/logo_annotation_packagings.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/data/ocr/logo_annotation_packagings.txt b/data/ocr/logo_annotation_packagings.txt index 574c6e0c85..ef85007a40 100644 --- a/data/ocr/logo_annotation_packagings.txt +++ b/data/ocr/logo_annotation_packagings.txt @@ -1,2 +1,3 @@ en:Aluminium||alu 41 en:Recyclable steel||Acier Recyclable +en:Plastic bottle||Plastic bottle From 0ed35673acdf9693523e2543e9f649ebe1b218ef Mon Sep 17 00:00:00 2001 From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com> Date: Thu, 3 Feb 2022 10:02:57 +0000 Subject: [PATCH 4/4] Format code with black and isort --- robotoff.py | 31 +- robotoff/app/api.py | 276 ++++---- robotoff/app/core.py | 69 +- robotoff/cli/annotate.py | 76 +- robotoff/cli/batch.py | 35 +- robotoff/cli/insights.py | 26 +- robotoff/cli/run.py | 19 +- robotoff/elasticsearch/category/dump.py | 39 +- robotoff/elasticsearch/category/match.py | 34 +- robotoff/elasticsearch/category/predict.py | 46 +- robotoff/elasticsearch/product/dump.py | 46 +- robotoff/ingredients.py | 174 ++--- robotoff/insights/annotate.py | 178 ++--- robotoff/insights/data.py | 6 +- robotoff/insights/extraction.py | 103 ++- robotoff/insights/importer.py | 666 +++++++++--------- robotoff/insights/normalize.py | 12 +- robotoff/insights/ocr/__init__.py | 4 +- robotoff/insights/ocr/brand.py | 63 +- robotoff/insights/ocr/core.py | 48 +- robotoff/insights/ocr/dataclass.py | 245 ++++--- robotoff/insights/ocr/expiration_date.py | 48 +- robotoff/insights/ocr/image_flag.py | 33 +- robotoff/insights/ocr/image_orientation.py | 10 +- robotoff/insights/ocr/label.py | 386 +++++----- robotoff/insights/ocr/nutrient.py | 53 +- robotoff/insights/ocr/packager_code.py | 60 +- robotoff/insights/ocr/product_weight.py | 53 +- robotoff/insights/ocr/store.py | 41 +- robotoff/insights/ocr/trace.py | 19 +- robotoff/insights/question.py | 119 ++-- robotoff/insights/validator.py | 30 +- robotoff/ml/category_classifier.py | 238 ++++--- robotoff/ml/object_detection/core.py | 198 +++--- 
robotoff/ml/object_detection/download.py | 47 +- .../ml/object_detection/utils/dataset_util.py | 4 +- .../object_detection/utils/label_map_util.py | 45 +- robotoff/ml/object_detection/utils/ops.py | 21 +- .../object_detection/utils/standard_fields.py | 175 ++--- .../utils/string_int_label_map_pb2.py | 226 +++--- .../utils/visualization_utils.py | 552 +++++++++------ robotoff/models.py | 24 +- robotoff/off.py | 99 ++- robotoff/products.py | 165 +++-- robotoff/scheduler.py | 164 +++-- robotoff/settings.py | 60 +- robotoff/slack.py | 142 ++-- robotoff/taxonomy.py | 97 +-- robotoff/utils/__init__.py | 44 +- robotoff/utils/cache.py | 21 +- robotoff/utils/es.py | 32 +- robotoff/utils/i18n.py | 16 +- robotoff/utils/text.py | 4 +- robotoff/workers/client.py | 19 +- robotoff/workers/listener.py | 19 +- robotoff/workers/tasks.py | 68 +- tests/insights/ocr/test_brand.py | 35 +- tests/insights/ocr/test_image_orientation.py | 50 +- tests/insights/ocr/test_product_weight.py | 23 +- tests/insights/test_annotate.py | 30 +- tests/insights/test_importer.py | 18 +- tests/insights/test_question.py | 16 +- tests/test_ingredients.py | 35 +- tests/test_taxonomy.py | 37 +- 64 files changed, 3151 insertions(+), 2621 deletions(-) diff --git a/robotoff.py b/robotoff.py index 5823883cf4..2ec9f73681 100644 --- a/robotoff.py +++ b/robotoff.py @@ -9,45 +9,50 @@ def cli(): @click.command() -@click.argument('service') +@click.argument("service") def run(service: str): from robotoff.cli.run import run as run_ + run_(service) @click.command() -@click.argument('input_') -@click.option('--insight-type', '-t', required=True) -@click.option('--output', '-o') +@click.argument("input_") +@click.option("--insight-type", "-t", required=True) +@click.option("--output", "-o") def generate_ocr_insights(input_: str, insight_type: str, output: str): from robotoff.cli import insights + insights.run_from_ocr_archive(input_, insight_type, output) @click.command() -@click.option('--insight-type') -@click.option('--country') +@click.option("--insight-type") +@click.option("--country") def annotate(insight_type: Optional[str], country: Optional[str]): from robotoff.cli import annotate as annotate_ + annotate_.run(insight_type, country) @click.command() -@click.option('--insight-type', required=True) -@click.option('--dry/--no-dry', default=True) -@click.option('-f', '--filter', 'filter_clause') +@click.option("--insight-type", required=True) +@click.option("--dry/--no-dry", default=True) +@click.option("-f", "--filter", "filter_clause") def batch_annotate(insight_type: str, dry: bool, filter_clause: str): from robotoff.cli import batch + batch.run(insight_type, dry, filter_clause) @click.command() -@click.argument('output') +@click.argument("output") def predict_category(output: str): + from robotoff import settings from robotoff.elasticsearch.category.predict import predict_from_dataset - from robotoff.utils import dump_jsonl from robotoff.products import ProductDataset - from robotoff import settings + from robotoff.utils import dump_jsonl + dataset = ProductDataset(settings.JSONL_DATASET_PATH) dump_jsonl(output, predict_from_dataset(dataset)) @@ -59,5 +64,5 @@ def predict_category(output: str): cli.add_command(predict_category) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 37e7e22b44..d7a0338bfa 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -1,36 +1,30 @@ +import dataclasses import io import itertools from typing import List, Optional -import 
dataclasses - import falcon +import sentry_sdk from falcon_cors import CORS from falcon_multipart.middleware import MultipartMiddleware - from PIL import Image +from sentry_sdk.integrations.wsgi import SentryWsgiMiddleware from robotoff import settings -from robotoff.app.core import (get_insights, - get_random_insight, - save_insight) +from robotoff.app.core import get_insights, get_random_insight, save_insight from robotoff.app.middleware import DBConnectionMiddleware -from robotoff.ingredients import generate_corrections, generate_corrected_text +from robotoff.ingredients import generate_corrected_text, generate_corrections from robotoff.insights._enum import InsightType -from robotoff.insights.question import QuestionFormatterFactory, \ - QuestionFormatter +from robotoff.insights.question import QuestionFormatter, QuestionFormatterFactory from robotoff.ml.object_detection import ObjectDetectionModelRegistry from robotoff.products import get_product_dataset_etag -from robotoff.taxonomy import TAXONOMY_STORES, TaxonomyType, Taxonomy -from robotoff.utils import get_logger, get_image_from_url +from robotoff.taxonomy import TAXONOMY_STORES, Taxonomy, TaxonomyType +from robotoff.utils import get_image_from_url, get_logger from robotoff.utils.es import get_es_client from robotoff.utils.i18n import TranslationStore from robotoff.utils.types import JSONType from robotoff.workers.client import send_ipc_event -import sentry_sdk -from sentry_sdk.integrations.wsgi import SentryWsgiMiddleware - logger = get_logger() es_client = get_es_client() @@ -43,8 +37,7 @@ def init_sentry(app): if settings.SENTRY_DSN: - sentry_sdk.init( - dsn=settings.SENTRY_DSN) + sentry_sdk.init(dsn=settings.SENTRY_DSN) return SentryWsgiMiddleware(app) return app @@ -56,38 +49,37 @@ def on_get(self, req, resp, barcode): insights = [i.serialize() for i in get_insights(barcode=barcode)] if not insights: - response['status'] = "no_insights" + response["status"] = "no_insights" else: - response['insights'] = insights - response['status'] = "found" + response["insights"] = insights + response["status"] = "found" resp.media = response class RandomInsightResource: def on_get(self, req, resp): - insight_type = req.get_param('type') or None - country = req.get_param('country') or None + insight_type = req.get_param("type") or None + country = req.get_param("country") or None response = {} insight = get_random_insight(insight_type, country) if not insight: - response['status'] = "no_insights" + response["status"] = "no_insights" else: - response['insight'] = insight.serialize() - response['status'] = "found" + response["insight"] = insight.serialize() + response["status"] = "found" resp.media = response class AnnotateInsightResource: def on_post(self, req, resp): - insight_id = req.get_param('insight_id', required=True) - annotation = req.get_param_as_int('annotation', required=True, - min=-1, max=1) + insight_id = req.get_param("insight_id", required=True) + annotation = req.get_param_as_int("annotation", required=True, min=-1, max=1) - update = req.get_param_as_bool('update') + update = req.get_param_as_bool("update") if update is None: update = True @@ -95,111 +87,119 @@ def on_post(self, req, resp): annotation_result = save_insight(insight_id, annotation, update=update) resp.media = { - 'status': annotation_result.status, - 'description': annotation_result.description, + "status": annotation_result.status, + "description": annotation_result.description, } class IngredientSpellcheckResource: def on_post(self, req, resp): - text = 
req.get_param('text', required=True) + text = req.get_param("text", required=True) corrections = generate_corrections(es_client, text, confidence=1) - term_corrections = list(itertools.chain - .from_iterable((c.term_corrections - for c in corrections))) + term_corrections = list( + itertools.chain.from_iterable((c.term_corrections for c in corrections)) + ) resp.media = { - 'corrections': [dataclasses.asdict(c) for c in corrections], - 'text': text, - 'corrected': generate_corrected_text(term_corrections, text), + "corrections": [dataclasses.asdict(c) for c in corrections], + "text": text, + "corrected": generate_corrected_text(term_corrections, text), } class UpdateDatasetResource: def on_post(self, req, resp): - send_ipc_event('download_dataset') + send_ipc_event("download_dataset") resp.media = { - 'status': 'scheduled', + "status": "scheduled", } def on_get(self, req, resp): resp.media = { - 'etag': get_product_dataset_etag(), + "etag": get_product_dataset_etag(), } class InsightImporterResource: def on_post(self, req, resp): logger.info("New insight import request") - insight_type = req.get_param('type', required=True) + insight_type = req.get_param("type", required=True) if insight_type not in (t.name for t in InsightType): - raise falcon.HTTPBadRequest(description="unknown insight type: " - "'{}'".format(insight_type)) + raise falcon.HTTPBadRequest( + description="unknown insight type: " "'{}'".format(insight_type) + ) - content = req.get_param('file', required=True) + content = req.get_param("file", required=True) logger.info("Insight type: '{}'".format(insight_type)) lines = [l for l in io.TextIOWrapper(content.file)] - send_ipc_event('import_insights', { - 'insight_type': insight_type, - 'items': lines, - }) + send_ipc_event( + "import_insights", + { + "insight_type": insight_type, + "items": lines, + }, + ) logger.info("Import scheduled") resp.media = { - 'status': 'scheduled', + "status": "scheduled", } class ImageImporterResource: def on_post(self, req, resp): - barcode = req.get_param('barcode', required=True) - image_url = req.get_param('image_url', required=True) - ocr_url = req.get_param('ocr_url', required=True) - server_domain = req.get_param('server_domain', required=True) + barcode = req.get_param("barcode", required=True) + image_url = req.get_param("image_url", required=True) + ocr_url = req.get_param("ocr_url", required=True) + server_domain = req.get_param("server_domain", required=True) if server_domain != settings.OFF_SERVER_DOMAIN: logger.info("Rejecting image import from {}".format(server_domain)) resp.media = { - 'status': 'rejected', + "status": "rejected", } return - send_ipc_event('import_image', { - 'barcode': barcode, - 'image_url': image_url, - 'ocr_url': ocr_url, - }) + send_ipc_event( + "import_image", + { + "barcode": barcode, + "image_url": image_url, + "ocr_url": ocr_url, + }, + ) resp.media = { - 'status': 'scheduled', + "status": "scheduled", } class ImagePredictorResource: def on_get(self, req, resp): - image_url = req.get_param('image_url', required=True) - models: List[str] = req.get_param_as_list('models') + image_url = req.get_param("image_url", required=True) + models: List[str] = req.get_param_as_list("models") available_models = ObjectDetectionModelRegistry.get_available_models() if models is None: - models = ['nutrition-table'] + models = ["nutrition-table"] else: for model_name in models: if model_name not in available_models: raise falcon.HTTPBadRequest( "invalid_model", "unknown model {}, available models: {}" - "".format(model_name, 
', '.join(available_models))) + "".format(model_name, ", ".join(available_models)), + ) - output_image = req.get_param_as_bool('output_image') + output_image = req.get_param_as_bool("output_image") if output_image is None: output_image = False @@ -208,7 +208,8 @@ def on_get(self, req, resp): raise falcon.HTTPBadRequest( "invalid_request", "a single model must be specified with the `models` parameter " - "when `output_image` is True") + "when `output_image` is True", + ) image = get_image_from_url(image_url) @@ -223,21 +224,18 @@ def on_get(self, req, resp): result = model.detect_from_image(image, output_image=output_image) if output_image: - self.image_response(result.boxed_image, - resp) + self.image_response(result.boxed_image, resp) return else: predictions[model_name] = result.to_json() - resp.media = { - 'predictions': predictions - } + resp.media = {"predictions": predictions} @staticmethod def image_response(image: Image.Image, resp: falcon.Response) -> None: - resp.content_type = 'image/jpeg' + resp.content_type = "image/jpeg" fp = io.BytesIO() - image.save(fp, 'JPEG') + image.save(fp, "JPEG") resp.stream_len = fp.tell() fp.seek(0) resp.stream = fp @@ -245,54 +243,63 @@ def image_response(image: Image.Image, resp: falcon.Response) -> None: class WebhookProductResource: def on_post(self, req, resp): - barcode = req.get_param('barcode', required=True) - action = req.get_param('action', required=True) - server_domain = req.get_param('server_domain', required=True) + barcode = req.get_param("barcode", required=True) + action = req.get_param("action", required=True) + server_domain = req.get_param("server_domain", required=True) if server_domain != settings.OFF_SERVER_DOMAIN: logger.info("Rejecting webhook event from {}".format(server_domain)) resp.media = { - 'status': 'rejected', + "status": "rejected", } return - logger.info("New webhook event received for product {} (action: {}, " - "domain: {})".format(barcode, action, server_domain)) + logger.info( + "New webhook event received for product {} (action: {}, " + "domain: {})".format(barcode, action, server_domain) + ) - if action not in ('updated', 'deleted'): - raise falcon.HTTPBadRequest(title="invalid_action", - description="action must be one of " - "`deleted`, `updated`") - - if action == 'updated': - send_ipc_event('product_updated', { - 'barcode': barcode, - }) - - elif action == 'deleted': - send_ipc_event('product_deleted', { - 'barcode': barcode, - }) + if action not in ("updated", "deleted"): + raise falcon.HTTPBadRequest( + title="invalid_action", + description="action must be one of " "`deleted`, `updated`", + ) + + if action == "updated": + send_ipc_event( + "product_updated", + { + "barcode": barcode, + }, + ) + + elif action == "deleted": + send_ipc_event( + "product_deleted", + { + "barcode": barcode, + }, + ) resp.media = { - 'status': 'scheduled', + "status": "scheduled", } class ProductQuestionsResource: def on_get(self, req, resp, barcode): response = {} - count: int = req.get_param_as_int('count', min=1) or 1 - lang: str = req.get_param('lang', default='en') + count: int = req.get_param_as_int("count", min=1) or 1 + lang: str = req.get_param("lang", default="en") keep_types = QuestionFormatterFactory.get_available_types() - insights = list(get_insights(barcode=barcode, - keep_types=keep_types, - count=count)) + insights = list( + get_insights(barcode=barcode, keep_types=keep_types, count=count) + ) if not insights: - response['questions'] = [] - response['status'] = "no_questions" + response["questions"] = [] + 
response["status"] = "no_questions" else: questions: List[JSONType] = [] @@ -302,8 +309,8 @@ def on_get(self, req, resp, barcode): question = formatter.format_question(insight, lang) questions.append(question.serialize()) - response['questions'] = questions - response['status'] = "found" + response["questions"] = questions + response["status"] = "found" resp.media = response @@ -311,12 +318,13 @@ def on_get(self, req, resp, barcode): class RandomQuestionsResource: def on_get(self, req, resp): response = {} - count: int = req.get_param_as_int('count', min=1) or 1 - lang: str = req.get_param('lang', default='en') + count: int = req.get_param_as_int("count", min=1) or 1 + lang: str = req.get_param("lang", default="en") keep_types: Optional[List[str]] = req.get_param_as_list( - 'insight_types', required=False) - country: Optional[str] = req.get_param('country') or None - brands = req.get_param_as_list('brands') or None + "insight_types", required=False + ) + country: Optional[str] = req.get_param("country") or None + brands = req.get_param_as_list("brands") or None if keep_types is None: keep_types = QuestionFormatterFactory.get_available_types() @@ -328,14 +336,15 @@ def on_get(self, req, resp): # Limit the number of brands to prevent slow SQL queries brands = brands[:10] - insights = list(get_insights(keep_types=keep_types, - count=count, - country=country, - brands=brands)) + insights = list( + get_insights( + keep_types=keep_types, count=count, country=country, brands=brands + ) + ) if not insights: - response['questions'] = [] - response['status'] = "no_questions" + response["questions"] = [] + response["status"] = "no_questions" else: questions: List[JSONType] = [] @@ -349,34 +358,29 @@ def on_get(self, req, resp): question = formatter.format_question(insight, lang) questions.append(question.serialize()) - response['questions'] = questions - response['status'] = "found" + response["questions"] = questions + response["status"] = "found" resp.media = response -cors = CORS(allow_all_origins=True, - allow_all_headers=True, - allow_all_methods=True) +cors = CORS(allow_all_origins=True, allow_all_headers=True, allow_all_methods=True) -api = falcon.API(middleware=[cors.middleware, - MultipartMiddleware(), - DBConnectionMiddleware()]) +api = falcon.API( + middleware=[cors.middleware, MultipartMiddleware(), DBConnectionMiddleware()] +) # Parse form parameters api.req_options.auto_parse_form_urlencoded = True -api.add_route('/api/v1/insights/{barcode}', ProductInsightResource()) -api.add_route('/api/v1/insights/random', RandomInsightResource()) -api.add_route('/api/v1/insights/annotate', AnnotateInsightResource()) -api.add_route('/api/v1/insights/import', InsightImporterResource()) -api.add_route('/api/v1/predict/ingredients/spellcheck', - IngredientSpellcheckResource()) -api.add_route('/api/v1/products/dataset', - UpdateDatasetResource()) -api.add_route('/api/v1/webhook/product', - WebhookProductResource()) -api.add_route('/api/v1/images/import', ImageImporterResource()) -api.add_route('/api/v1/images/predict', ImagePredictorResource()) -api.add_route('/api/v1/questions/{barcode}', ProductQuestionsResource()) -api.add_route('/api/v1/questions/random', RandomQuestionsResource()) +api.add_route("/api/v1/insights/{barcode}", ProductInsightResource()) +api.add_route("/api/v1/insights/random", RandomInsightResource()) +api.add_route("/api/v1/insights/annotate", AnnotateInsightResource()) +api.add_route("/api/v1/insights/import", InsightImporterResource()) 
+api.add_route("/api/v1/predict/ingredients/spellcheck", IngredientSpellcheckResource()) +api.add_route("/api/v1/products/dataset", UpdateDatasetResource()) +api.add_route("/api/v1/webhook/product", WebhookProductResource()) +api.add_route("/api/v1/images/import", ImageImporterResource()) +api.add_route("/api/v1/images/predict", ImagePredictorResource()) +api.add_route("/api/v1/questions/{barcode}", ProductQuestionsResource()) +api.add_route("/api/v1/questions/random", RandomQuestionsResource()) api = init_sentry(api) diff --git a/robotoff/app/core.py b/robotoff/app/core.py index 2f953444b5..bea9e8dd99 100644 --- a/robotoff/app/core.py +++ b/robotoff/app/core.py @@ -1,28 +1,30 @@ import tempfile +from typing import Iterable, List, Optional, Union +import peewee import requests -from typing import Union, Optional, List, Iterable +from PIL import Image -from robotoff.insights.annotate import (InsightAnnotatorFactory, - AnnotationResult, - ALREADY_ANNOTATED_RESULT, - UNKNOWN_INSIGHT_RESULT) +from robotoff.insights.annotate import ( + ALREADY_ANNOTATED_RESULT, + UNKNOWN_INSIGHT_RESULT, + AnnotationResult, + InsightAnnotatorFactory, +) from robotoff.models import ProductInsight from robotoff.off import get_product from robotoff.utils import get_logger -from PIL import Image - -import peewee - logger = get_logger(__name__) -def get_insights(barcode: Optional[str] = None, - keep_types: List[str] = None, - country: str = None, - brands: List[str] = None, - count=25) -> Iterable[ProductInsight]: +def get_insights( + barcode: Optional[str] = None, + keep_types: List[str] = None, + country: str = None, + brands: List[str] = None, + count=25, +) -> Iterable[ProductInsight]: where_clauses = [ ProductInsight.annotation.is_null(), ] @@ -34,22 +36,23 @@ def get_insights(barcode: Optional[str] = None, where_clauses.append(ProductInsight.type.in_(keep_types)) if country is not None: - where_clauses.append(ProductInsight.countries.contains( - country)) + where_clauses.append(ProductInsight.countries.contains(country)) if brands: - where_clauses.append(ProductInsight.brands.contains_any( - brands)) - - query = (ProductInsight.select() - .where(*where_clauses) - .limit(count) - .order_by(peewee.fn.Random())) + where_clauses.append(ProductInsight.brands.contains_any(brands)) + + query = ( + ProductInsight.select() + .where(*where_clauses) + .limit(count) + .order_by(peewee.fn.Random()) + ) return query.iterator() -def get_random_insight(insight_type: str = None, - country: str = None) -> Optional[ProductInsight]: +def get_random_insight( + insight_type: str = None, country: str = None +) -> Optional[ProductInsight]: attempts = 0 while True: attempts += 1 @@ -61,12 +64,10 @@ def get_random_insight(insight_type: str = None, where_clauses = [ProductInsight.annotation.is_null()] if country is not None: - where_clauses.append(ProductInsight.countries.contains( - country)) + where_clauses.append(ProductInsight.countries.contains(country)) if insight_type is not None: - where_clauses.append(ProductInsight.type == - insight_type) + where_clauses.append(ProductInsight.type == insight_type) query = query.where(*where_clauses).order_by(peewee.fn.Random()) @@ -78,7 +79,7 @@ def get_random_insight(insight_type: str = None, insight = insight_list[0] # We only need to know if the product exists, so fetching barcode # is enough - product = get_product(insight.barcode, ['code']) + product = get_product(insight.barcode, ["code"]) # Product may be None if not found if product: @@ -88,11 +89,11 @@ def 
get_random_insight(insight_type: str = None, logger.info("Product not found, insight deleted") -def save_insight(insight_id: str, annotation: int, update: bool = True) \ - -> AnnotationResult: +def save_insight( + insight_id: str, annotation: int, update: bool = True +) -> AnnotationResult: try: - insight: Union[ProductInsight, None] \ - = ProductInsight.get_by_id(insight_id) + insight: Union[ProductInsight, None] = ProductInsight.get_by_id(insight_id) except ProductInsight.DoesNotExist: insight = None diff --git a/robotoff/cli/annotate.py b/robotoff/cli/annotate.py index 92f8153891..76794eb96b 100644 --- a/robotoff/cli/annotate.py +++ b/robotoff/cli/annotate.py @@ -1,6 +1,7 @@ import json from difflib import SequenceMatcher -from typing import Optional, Dict +from typing import Dict, Optional + import click import requests @@ -32,47 +33,47 @@ def run(insight_type: Optional[str], country: Optional[str]): click.echo("No insight left") -def run_loop(insight_type: Optional[str], - country: Optional[str]) -> None: +def run_loop(insight_type: Optional[str], country: Optional[str]) -> None: insight = get_random_insight(insight_type, country) print_insight(insight) annotation = None while annotation is None: - annotation = click.prompt('Annotation [-1, 0, 1]: ', type=int) + annotation = click.prompt("Annotation [-1, 0, 1]: ", type=int) if annotation not in (0, 1, -1): click.echo("Invalid value: 0, 1 or -1 expected", err=True) annotation = None - response = save_insight(insight['id'], annotation=annotation) + response = save_insight(insight["id"], annotation=annotation) click.echo(json.dumps(response, indent=4) + "\n") -def get_random_insight(insight_type: Optional[str] = None, - country: Optional[str] = None) -> JSONType: +def get_random_insight( + insight_type: Optional[str] = None, country: Optional[str] = None +) -> JSONType: params = {} if insight_type: - params['type'] = insight_type + params["type"] = insight_type if country: - params['country'] = country + params["country"] = country r = http_session.get(RANDOM_INSIGHT_URL, params=params) data = r.json() - if data['status'] == 'no_insights': + if data["status"] == "no_insights": raise NoInsightException() - return data['insight'] + return data["insight"] def save_insight(insight_id: str, annotation: int): params = { - 'insight_id': insight_id, - 'annotation': str(annotation), + "insight_id": insight_id, + "annotation": str(annotation), } r = http_session.post(ANNOTATE_INSIGHT_URL, data=params) @@ -82,9 +83,9 @@ def save_insight(insight_id: str, annotation: int): def print_insight(insight: Dict) -> None: - insight_type = insight.get('type') + insight_type = insight.get("type") - if insight_type == 'ingredient_spellcheck': + if insight_type == "ingredient_spellcheck": print_ingredient_spellcheck_insight(insight) else: @@ -93,27 +94,32 @@ def print_insight(insight: Dict) -> None: def print_generic_insight(insight: JSONType) -> None: for key, value in insight.items(): - click.echo('{}: {}'.format(key, str(value))) + click.echo("{}: {}".format(key, str(value))) - click.echo("url: {}".format("https://fr.openfoodfacts.org/produit/" - "{}".format(insight['barcode']))) + click.echo( + "url: {}".format( + "https://fr.openfoodfacts.org/produit/" "{}".format(insight["barcode"]) + ) + ) - if 'source' in insight: - click.echo("image: {}{}".format(STATIC_IMAGE_DIR_URL, - insight['source'])) + if "source" in insight: + click.echo("image: {}{}".format(STATIC_IMAGE_DIR_URL, insight["source"])) click.echo("") def print_ingredient_spellcheck_insight(insight: 
JSONType) -> None: - for key in ('id', 'type', 'barcode', 'countries'): + for key in ("id", "type", "barcode", "countries"): value = insight.get(key) - click.echo('{}: {}'.format(key, str(value))) + click.echo("{}: {}".format(key, str(value))) - click.echo("url: {}".format("https://fr.openfoodfacts.org/produit/" - "{}".format(insight['barcode']))) + click.echo( + "url: {}".format( + "https://fr.openfoodfacts.org/produit/" "{}".format(insight["barcode"]) + ) + ) - original_snippet = insight['original_snippet'] - corrected_snippet = insight['corrected_snippet'] + original_snippet = insight["original_snippet"] + corrected_snippet = insight["corrected_snippet"] click.echo(generate_colored_diff(original_snippet, corrected_snippet)) click.echo("") @@ -123,14 +129,14 @@ def generate_colored_diff(original: str, correction: str) -> str: diff = "" for opcode, i1, i2, j1, j2 in matcher.get_opcodes(): - if opcode == 'equal': + if opcode == "equal": diff += original[i1:i2] - elif opcode == 'insert': - diff += click.style(correction[j1:j2], fg='black', bg='green') - elif opcode == 'delete': - diff += click.style(original[i1:i2], fg='black', bg='red') - elif opcode == 'replace': - diff += click.style(original[i1:i2], fg='black', bg='red') - diff += click.style(correction[j1:j2], fg='black', bg='green') + elif opcode == "insert": + diff += click.style(correction[j1:j2], fg="black", bg="green") + elif opcode == "delete": + diff += click.style(original[i1:i2], fg="black", bg="red") + elif opcode == "replace": + diff += click.style(original[i1:i2], fg="black", bg="red") + diff += click.style(correction[j1:j2], fg="black", bg="green") return diff diff --git a/robotoff/cli/batch.py b/robotoff/cli/batch.py index 6a70b9c5ec..4d303a1261 100644 --- a/robotoff/cli/batch.py +++ b/robotoff/cli/batch.py @@ -5,22 +5,18 @@ from robotoff.models import ProductInsight -def run(insight_type: str, - dry: bool = True, - json_contains_str: Optional[str] = None): +def run(insight_type: str, dry: bool = True, json_contains_str: Optional[str] = None): if json_contains_str is not None: json_contains = ast.literal_eval(json_contains_str) else: json_contains = None - batch_annotate(insight_type, - dry, - json_contains) + batch_annotate(insight_type, dry, json_contains) -def batch_annotate(insight_type: str, - dry: bool = True, - json_contains: Optional[Dict] = None): +def batch_annotate( + insight_type: str, dry: bool = True, json_contains: Optional[Dict] = None +): annotator = InsightAnnotatorFactory.get(insight_type) i = 0 @@ -28,7 +24,7 @@ def batch_annotate(insight_type: str, query = ProductInsight.select() where_clauses = [ ProductInsight.type == insight_type, - ProductInsight.annotation.is_null() + ProductInsight.annotation.is_null(), ] if json_contains is not None: @@ -38,18 +34,21 @@ def batch_annotate(insight_type: str, if dry: count = query.count() - print("-- dry run --\n" - "{} items matching filter:\n" - " insight type: {}\n" - " filter: {}" - "".format(count, insight_type, json_contains)) + print( + "-- dry run --\n" + "{} items matching filter:\n" + " insight type: {}\n" + " filter: {}" + "".format(count, insight_type, json_contains) + ) else: for insight in query: i += 1 print("Insight %d" % i) - print("Add label {} to https://fr.openfoodfacts.org/produit/{}" - "".format(insight.data, - insight.barcode)) + print( + "Add label {} to https://fr.openfoodfacts.org/produit/{}" + "".format(insight.data, insight.barcode) + ) print(insight.data) annotator.annotate(insight, 1, update=True) diff --git 
a/robotoff/cli/insights.py b/robotoff/cli/insights.py index 9f73c7975c..9f8016d4f0 100644 --- a/robotoff/cli/insights.py +++ b/robotoff/cli/insights.py @@ -5,14 +5,17 @@ import click -from robotoff.insights.ocr import (ocr_iter, OCRResult, - extract_insights, - get_barcode_from_path) +from robotoff.insights.ocr import ( + OCRResult, + extract_insights, + get_barcode_from_path, + ocr_iter, +) def run_from_ocr_archive(input_: str, insight_type: str, output: Optional[str]): if output is not None: - output_f = open(output, 'w') + output_f = open(output, "w") else: output_f = sys.stdout @@ -24,8 +27,9 @@ def run_from_ocr_archive(input_: str, insight_type: str, output: Optional[str]): barcode: Optional[str] = get_barcode_from_path(source) if barcode is None: - click.echo("cannot extract barcode from source " - "{}".format(source), err=True) + click.echo( + "cannot extract barcode from source " "{}".format(source), err=True + ) continue ocr_result: Optional[OCRResult] = OCRResult.from_json(ocr_json) @@ -37,12 +41,12 @@ def run_from_ocr_archive(input_: str, insight_type: str, output: Optional[str]): if insights: item = { - 'insights': insights, - 'barcode': barcode, - 'type': insight_type, + "insights": insights, + "barcode": barcode, + "type": insight_type, } if source: - item['source'] = source + item["source"] = source - output_f.write(json.dumps(item) + '\n') + output_f.write(json.dumps(item) + "\n") diff --git a/robotoff/cli/run.py b/robotoff/cli/run.py index f50d98957e..a722295f7a 100644 --- a/robotoff/cli/run.py +++ b/robotoff/cli/run.py @@ -1,23 +1,30 @@ import subprocess import click + from robotoff import settings def run(service: str): - if service == 'api': - subprocess.run(["gunicorn", "--config", - str(settings.PROJECT_DIR / "gunicorn.conf"), - "robotoff.app.api:api"]) + if service == "api": + subprocess.run( + [ + "gunicorn", + "--config", + str(settings.PROJECT_DIR / "gunicorn.conf"), + "robotoff.app.api:api", + ] + ) - elif service == 'workers': + elif service == "workers": from robotoff.workers import listener listener.run() - elif service == 'scheduler': + elif service == "scheduler": from robotoff import scheduler from robotoff.utils import get_logger + # Defining a root logger get_logger() scheduler.run() diff --git a/robotoff/elasticsearch/category/dump.py b/robotoff/elasticsearch/category/dump.py index 4b220ce845..8b7ab768e9 100644 --- a/robotoff/elasticsearch/category/dump.py +++ b/robotoff/elasticsearch/category/dump.py @@ -1,20 +1,20 @@ import hashlib -from typing import Iterable, Tuple, Dict +from typing import Dict, Iterable, Tuple +from robotoff import settings from robotoff.insights._enum import InsightType from robotoff.taxonomy import TAXONOMY_STORES, Taxonomy from robotoff.utils import get_logger from robotoff.utils.es import get_es_client, perform_export -from robotoff import settings logger = get_logger() SUPPORTED_LANG = { - 'fr', - 'en', - 'es', - 'de', + "fr", + "en", + "es", + "de", } @@ -26,32 +26,33 @@ def category_export(): delete_categories(client) logger.info("Starting export...") category_data = generate_category_data(category_taxonomy) - rows_inserted = perform_export(client, category_data, - settings.ELASTICSEARCH_CATEGORY_INDEX) + rows_inserted = perform_export( + client, category_data, settings.ELASTICSEARCH_CATEGORY_INDEX + ) logger.info("%d rows inserted" % rows_inserted) def generate_category_data(category_taxonomy: Taxonomy) -> Iterable[Tuple[str, Dict]]: for category_node in category_taxonomy.iter_nodes(): - supported_langs = [lang for lang in 
category_node.names - if lang in SUPPORTED_LANG] + supported_langs = [ + lang for lang in category_node.names if lang in SUPPORTED_LANG + ] - data = { - f"{lang}:name": category_node.names[lang] - for lang in supported_langs - } - data['id'] = category_node.id + data = {f"{lang}:name": category_node.names[lang] for lang in supported_langs} + data["id"] = category_node.id - id_ = hashlib.sha256(category_node.id.encode('utf-8')).hexdigest() + id_ = hashlib.sha256(category_node.id.encode("utf-8")).hexdigest() yield id_, data def delete_categories(client): body = {"query": {"match_all": {}}} - client.delete_by_query(body=body, - index=settings.ELASTICSEARCH_CATEGORY_INDEX, - doc_type=settings.ELASTICSEARCH_TYPE) + client.delete_by_query( + body=body, + index=settings.ELASTICSEARCH_CATEGORY_INDEX, + doc_type=settings.ELASTICSEARCH_TYPE, + ) if __name__ == "__main__": diff --git a/robotoff/elasticsearch/category/match.py b/robotoff/elasticsearch/category/match.py index 8e5b9339d0..ad8ba957c5 100644 --- a/robotoff/elasticsearch/category/match.py +++ b/robotoff/elasticsearch/category/match.py @@ -1,15 +1,15 @@ -import json import argparse -from typing import Tuple, Optional +import json +from typing import Optional, Tuple -from robotoff.utils.es import get_es_client from robotoff import settings +from robotoff.utils.es import get_es_client SUPPORTED_LANG = { - 'fr', - 'en', - 'es', - 'de', + "fr", + "en", + "es", + "de", } @@ -19,21 +19,23 @@ def predict_category(client, name: str, lang: str) -> Optional[Tuple[str, float] results = match(client, name, lang) - hits = results['hits']['hits'] + hits = results["hits"]["hits"] if hits: hit = hits[0] - return hit['_source']['id'], hit['_score'] - + return hit["_source"]["id"], hit["_score"] + return None def match(client, query: str, lang: str): body = generate_request(query, lang) - return client.search(index=settings.ELASTICSEARCH_CATEGORY_INDEX, - doc_type=settings.ELASTICSEARCH_TYPE, - body=body, - _source=True) + return client.search( + index=settings.ELASTICSEARCH_CATEGORY_INDEX, + doc_type=settings.ELASTICSEARCH_TYPE, + body=body, + _source=True, + ) def generate_request(query: str, lang: str): @@ -51,7 +53,7 @@ def generate_request(query: str, lang: str): def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("query", help="query to search") - parser.add_argument("--lang", help="language of the query", default='fr') + parser.add_argument("--lang", help="language of the query", default="fr") return parser.parse_args() @@ -59,4 +61,4 @@ def parse_args(): args = parse_args() es_client = get_es_client() results = match(es_client, args.query, args.lang) - print(json.dumps(results['hits'], indent=4)) + print(json.dumps(results["hits"], indent=4)) diff --git a/robotoff/elasticsearch/category/predict.py b/robotoff/elasticsearch/category/predict.py index eb4a97e861..b786eac7ba 100644 --- a/robotoff/elasticsearch/category/predict.py +++ b/robotoff/elasticsearch/category/predict.py @@ -1,12 +1,11 @@ import datetime import operator -from typing import Iterable, Dict, Optional +from typing import Dict, Iterable, Optional +from robotoff.elasticsearch.category.match import predict_category from robotoff.products import ProductDataset from robotoff.utils import get_logger - from robotoff.utils.es import get_es_client -from robotoff.elasticsearch.category.match import predict_category from robotoff.utils.types import JSONType logger = get_logger(__name__) @@ -15,7 +14,7 @@ def predict(client, product: Dict) -> Optional[Dict]: predictions = [] - 
for lang in product.get('languages_codes', []): + for lang in product.get("languages_codes", []): product_name = product.get(f"product_name_{lang}") if not product_name: @@ -32,19 +31,19 @@ def predict(client, product: Dict) -> Optional[Dict]: if predictions: # Sort by descending score - sorted_predictions = sorted(predictions, - key=operator.itemgetter(2), - reverse=True) + sorted_predictions = sorted( + predictions, key=operator.itemgetter(2), reverse=True + ) prediction = sorted_predictions[0] lang, category, product_name, score = prediction return { - 'barcode': product['code'], - 'category': category, - 'matcher_lang': lang, - 'product_name': product_name, - 'model': 'matcher', + "barcode": product["code"], + "category": category, + "matcher_lang": lang, + "product_name": product_name, + "model": "matcher", } return None @@ -63,9 +62,9 @@ def predict_from_iterable(client, products: Iterable[Dict]) -> Iterable[Dict]: yield prediction -def predict_from_dataset(dataset: ProductDataset, - from_datetime: Optional[datetime.datetime] = None) -> \ - Iterable[JSONType]: +def predict_from_dataset( + dataset: ProductDataset, from_datetime: Optional[datetime.datetime] = None +) -> Iterable[JSONType]: """Return an iterable of category insights, using the provided dataset. Args: @@ -73,16 +72,19 @@ def predict_from_dataset(dataset: ProductDataset, from_datetime: datetime threshold: only keep products modified after `from_datetime` """ - product_stream = (dataset.stream() - .filter_nonempty_text_field('code') - .filter_nonempty_text_field('product_name') - .filter_empty_tag_field('categories_tags') - .filter_nonempty_tag_field('countries_tags') - .filter_nonempty_tag_field('languages_codes')) + product_stream = ( + dataset.stream() + .filter_nonempty_text_field("code") + .filter_nonempty_text_field("product_name") + .filter_empty_tag_field("categories_tags") + .filter_nonempty_tag_field("countries_tags") + .filter_nonempty_tag_field("languages_codes") + ) if from_datetime: product_stream = product_stream.filter_by_modified_datetime( - from_t=from_datetime) + from_t=from_datetime + ) product_iter = product_stream.iter() logger.info("Performing prediction on products without categories") diff --git a/robotoff/elasticsearch/product/dump.py b/robotoff/elasticsearch/product/dump.py index 3eae5c0ae8..8a54d31dfa 100644 --- a/robotoff/elasticsearch/product/dump.py +++ b/robotoff/elasticsearch/product/dump.py @@ -1,13 +1,11 @@ import re +from robotoff import settings from robotoff.ingredients import process_ingredients from robotoff.products import ProductDataset -from robotoff import settings from robotoff.utils import get_logger - from robotoff.utils.es import get_es_client, perform_export - logger = get_logger(__name__) @@ -17,18 +15,30 @@ def product_export(): dataset = ProductDataset(settings.JSONL_DATASET_PATH) - product_iter = (dataset.stream() - .filter_by_country_tag('en:france') - .filter_nonempty_text_field('ingredients_text_fr') - .filter_by_state_tag('en:complete') - .iter()) - product_iter = (p for p in product_iter - if 'ingredients-unknown-score-above-0' - not in p.get('quality_tags', [])) - - data = ((product['code'], - {'ingredients_text_fr': normalize_ingredient_list(product['ingredients_text_fr'])}) - for product in product_iter) + product_iter = ( + dataset.stream() + .filter_by_country_tag("en:france") + .filter_nonempty_text_field("ingredients_text_fr") + .filter_by_state_tag("en:complete") + .iter() + ) + product_iter = ( + p + for p in product_iter + if 
"ingredients-unknown-score-above-0" not in p.get("quality_tags", []) + ) + + data = ( + ( + product["code"], + { + "ingredients_text_fr": normalize_ingredient_list( + product["ingredients_text_fr"] + ) + }, + ) + for product in product_iter + ) logger.info("Importing products") @@ -37,7 +47,7 @@ def product_export(): def empty_ingredient(ingredient: str) -> bool: - return not bool(ingredient.strip(' /-.%0123456789')) + return not bool(ingredient.strip(" /-.%0123456789")) def normalize_ingredient_list(ingredient_text: str): @@ -49,8 +59,8 @@ def normalize_ingredient_list(ingredient_text: str): if empty_ingredient(ingredient): continue - ingredient = MULTIPLE_SPACES_RE.sub(' ', ingredient) - ingredient = ingredient.strip(' .') + ingredient = MULTIPLE_SPACES_RE.sub(" ", ingredient) + ingredient = ingredient.strip(" .") normalized.append(ingredient) return normalized diff --git a/robotoff/ingredients.py b/robotoff/ingredients.py index f6cae7288b..2f90b07986 100644 --- a/robotoff/ingredients.py +++ b/robotoff/ingredients.py @@ -1,17 +1,16 @@ +import dataclasses import itertools import json import operator import re - -import dataclasses from dataclasses import dataclass, field -from typing import List, Tuple, Iterable, Dict +from typing import Dict, Iterable, List, Tuple from robotoff import settings from robotoff.products import ProductDataset -from robotoff.utils.es import get_es_client, generate_msearch_body +from robotoff.utils.es import generate_msearch_body, get_es_client -SPLITTER_CHAR = {'(', ')', ',', ';', '[', ']', '-', '{', '}'} +SPLITTER_CHAR = {"(", ")", ",", ";", "[", "]", "-", "{", "}"} # Food additives (EXXX) may be mistaken from one another, because of their edit distance proximity BLACKLIST_RE = re.compile(r"(?:\d+(?:,\d+)?\s*%)|(?:E\d{3})|(?:[_•])") @@ -67,7 +66,7 @@ def normalize_ingredients(ingredient_text: str): if match: start = match.start() end = match.end() - normalized = normalized[:start] + ' ' * (end - start) + normalized[end:] + normalized = normalized[:start] + " " * (end - start) + normalized[end:] else: break @@ -85,14 +84,14 @@ def process_ingredients(ingredient_text: str) -> Ingredients: if char in SPLITTER_CHAR: offsets.append((start_idx, idx)) start_idx = idx + 1 - chars.append(' ') + chars.append(" ") else: chars.append(char) if start_idx != len(normalized): offsets.append((start_idx, len(normalized))) - normalized = ''.join(chars) + normalized = "".join(chars) return Ingredients(ingredient_text, normalized, offsets) @@ -101,22 +100,24 @@ def generate_corrections(client, ingredients_text: str, **kwargs) -> List[Correc ingredients: Ingredients = process_ingredients(ingredients_text) normalized_ingredients: Iterable[str] = ingredients.iter_normalized_ingredients() - for idx, suggestions in enumerate(_suggest_batch(client, normalized_ingredients, **kwargs)): + for idx, suggestions in enumerate( + _suggest_batch(client, normalized_ingredients, **kwargs) + ): offsets = ingredients.offsets[idx] normalized_ingredient = ingredients.get_normalized_ingredient(idx) - options = suggestions['options'] + options = suggestions["options"] if not options: continue option = options[0] original_tokens = analyze(client, normalized_ingredient) - suggestion_tokens = analyze(client, option['text']) + suggestion_tokens = analyze(client, option["text"]) try: - term_corrections = format_corrections(original_tokens, - suggestion_tokens, - offsets[0]) - corrections.append(Correction(term_corrections, option['score'])) + term_corrections = format_corrections( + original_tokens, 
suggestion_tokens, offsets[0] + ) + corrections.append(Correction(term_corrections, option["score"])) except ValueError: print("Mismatch") # Length mismatch exception @@ -126,38 +127,40 @@ def generate_corrections(client, ingredients_text: str, **kwargs) -> List[Correc def generate_corrected_text(corrections: List[TermCorrection], text: str): - sorted_corrections = sorted(corrections, - key=operator.attrgetter('start_offset')) + sorted_corrections = sorted(corrections, key=operator.attrgetter("start_offset")) corrected_fragments = [] last_correction = None for correction in sorted_corrections: if last_correction is None: - corrected_fragments.append(text[:correction.start_offset]) + corrected_fragments.append(text[: correction.start_offset]) else: corrected_fragments.append( - text[last_correction.end_offset:correction.start_offset]) + text[last_correction.end_offset : correction.start_offset] + ) corrected_fragments.append(correction.correction) last_correction = correction if last_correction is not None: - corrected_fragments.append(text[last_correction.end_offset:]) + corrected_fragments.append(text[last_correction.end_offset :]) - return ''.join(corrected_fragments) + return "".join(corrected_fragments) -def format_corrections(original_tokens: List[Dict], - suggestion_tokens: List[Dict], - offset: int=0): +def format_corrections( + original_tokens: List[Dict], suggestion_tokens: List[Dict], offset: int = 0 +): corrections = [] if len(original_tokens) != len(suggestion_tokens): - raise ValueError("The original text and the suggestions must have the same number of tokens") + raise ValueError( + "The original text and the suggestions must have the same number of tokens" + ) for original_token, suggestion_token in zip(original_tokens, suggestion_tokens): - original_token_str = original_token['token'] - suggestion_token_str = suggestion_token['token'] + original_token_str = original_token["token"] + suggestion_token_str = suggestion_token["token"] if original_token_str.lower() != suggestion_token_str: if original_token_str.isupper(): @@ -167,12 +170,16 @@ def format_corrections(original_tokens: List[Dict], else: token_str = suggestion_token_str - token_start = original_token['start_offset'] - token_end = original_token['end_offset'] - corrections.append(TermCorrection(original=original_token_str, - correction=token_str, - start_offset=offset+token_start, - end_offset=offset+token_end)) + token_start = original_token["start_offset"] + token_end = original_token["end_offset"] + corrections.append( + TermCorrection( + original=original_token_str, + correction=token_str, + start_offset=offset + token_start, + end_offset=offset + token_end, + ) + ) return corrections @@ -180,52 +187,51 @@ def format_corrections(original_tokens: List[Dict], def _suggest(client, text): suggester_name = "autocorrect" body = generate_suggest_query(text, name=suggester_name) - response = client.search(index='product', - doc_type='document', - body=body, - _source=False) - return response['suggest'][suggester_name] + response = client.search( + index="product", doc_type="document", body=body, _source=False + ) + return response["suggest"][suggester_name] def analyze(client, ingredient_text: str): - r = client.indices.analyze(index=settings.ELASTICSEARCH_PRODUCT_INDEX, - body={ - 'tokenizer': "standard", - 'text': ingredient_text - }) - return r['tokens'] + r = client.indices.analyze( + index=settings.ELASTICSEARCH_PRODUCT_INDEX, + body={"tokenizer": "standard", "text": ingredient_text}, + ) + return r["tokens"] def 
_suggest_batch(client, texts: Iterable[str], **kwargs) -> List[Dict]: suggester_name = "autocorrect" - queries = (generate_suggest_query(text, name=suggester_name, **kwargs) - for text in texts) + queries = ( + generate_suggest_query(text, name=suggester_name, **kwargs) for text in texts + ) body = generate_msearch_body(settings.ELASTICSEARCH_PRODUCT_INDEX, queries) - response = client.msearch(body=body, - doc_type=settings.ELASTICSEARCH_TYPE) + response = client.msearch(body=body, doc_type=settings.ELASTICSEARCH_TYPE) suggestions = [] - for r in response['responses']: - if r['status'] != 200: - root_cause = response['error']['root_cause'][0] - error_type = root_cause['type'] - error_reason = root_cause['reason'] - print("Elasticsearch error: {} [{}]" - "".format(error_reason, error_type)) + for r in response["responses"]: + if r["status"] != 200: + root_cause = response["error"]["root_cause"][0] + error_type = root_cause["type"] + error_reason = root_cause["reason"] + print("Elasticsearch error: {} [{}]" "".format(error_reason, error_type)) continue - suggestions.append(r['suggest'][suggester_name][0]) + suggestions.append(r["suggest"][suggester_name][0]) return suggestions -def generate_suggest_query(text, - confidence=1, - size=1, - min_word_length=4, - suggest_mode="popular", - name="autocorrect"): +def generate_suggest_query( + text, + confidence=1, + size=1, + min_word_length=4, + suggest_mode="popular", + name="autocorrect", +): return { "suggest": { "text": text, @@ -239,16 +245,12 @@ def generate_suggest_query(text, { "field": "ingredients_text_fr.trigram", "suggest_mode": suggest_mode, - "min_word_length": min_word_length + "min_word_length": min_word_length, } ], - "smoothing": { - "laplace": { - "alpha": 0.5 - } - } + "smoothing": {"laplace": {"alpha": 0.5}}, } - } + }, } } @@ -256,29 +258,29 @@ def generate_suggest_query(text, def generate_insights(client, confidence=1): dataset = ProductDataset(settings.JSONL_DATASET_PATH) - product_iter = (dataset.stream() - .filter_by_country_tag('en:france') - .filter_nonempty_text_field('ingredients_text_fr') - .iter()) + product_iter = ( + dataset.stream() + .filter_by_country_tag("en:france") + .filter_nonempty_text_field("ingredients_text_fr") + .iter() + ) for product in product_iter: - text = product['ingredients_text_fr'] - corrections = generate_corrections(client, - text, - confidence=confidence) + text = product["ingredients_text_fr"] + corrections = generate_corrections(client, text, confidence=confidence) if not corrections: continue - term_corrections = list(itertools.chain - .from_iterable((c.term_corrections - for c in corrections))) + term_corrections = list( + itertools.chain.from_iterable((c.term_corrections for c in corrections)) + ) yield { - 'corrections': [dataclasses.asdict(c) for c in term_corrections], - 'text': text, - 'corrected': generate_corrected_text(term_corrections, text), - 'barcode': product['code'], + "corrections": [dataclasses.asdict(c) for c in term_corrections], + "text": text, + "corrected": generate_corrected_text(term_corrections, text), + "barcode": product["code"], } @@ -295,6 +297,6 @@ def generate_insights(client, confidence=1): # corrections = generate_corrections(client, text, confidence=1) # print(corrections) - with open('insights_10.jsonl', 'w') as f: + with open("insights_10.jsonl", "w") as f: for insight in generate_insights(client, confidence=10): - f.write(json.dumps(insight) + '\n') + f.write(json.dumps(insight) + "\n") diff --git a/robotoff/insights/annotate.py 
b/robotoff/insights/annotate.py index d7fa4cb08a..afc1c2303a 100644 --- a/robotoff/insights/annotate.py +++ b/robotoff/insights/annotate.py @@ -1,16 +1,24 @@ import abc import datetime -from typing import Optional, List - from dataclasses import dataclass from enum import Enum +from typing import List, Optional from robotoff.insights._enum import InsightType from robotoff.insights.normalize import normalize_emb_code -from robotoff.models import ProductInsight, db, ProductIngredient -from robotoff.off import get_product, save_ingredients, update_emb_codes, \ - add_label_tag, add_category, update_quantity, update_expiration_date, \ - add_brand, product_exists, add_store +from robotoff.models import ProductIngredient, ProductInsight, db +from robotoff.off import ( + add_brand, + add_category, + add_label_tag, + add_store, + get_product, + product_exists, + save_ingredients, + update_emb_codes, + update_expiration_date, + update_quantity, +) from robotoff.utils import get_logger logger = get_logger(__name__) @@ -32,32 +40,36 @@ class AnnotationStatus(Enum): SAVED_ANNOTATION_RESULT = AnnotationResult( - status=AnnotationStatus.saved.name, - description="the annotation was saved") + status=AnnotationStatus.saved.name, description="the annotation was saved" +) UPDATED_ANNOTATION_RESULT = AnnotationResult( status=AnnotationStatus.updated.name, - description="the annotation was saved and sent to OFF") + description="the annotation was saved and sent to OFF", +) MISSING_PRODUCT_RESULT = AnnotationResult( status=AnnotationStatus.error_missing_product.name, - description="the product could not be found on OFF") + description="the product could not be found on OFF", +) ALREADY_ANNOTATED_RESULT = AnnotationResult( status=AnnotationStatus.error_already_annotated.name, - description="the insight has already been annotated") + description="the insight has already been annotated", +) UNKNOWN_INSIGHT_RESULT = AnnotationResult( - status=AnnotationStatus.error_unknown_insight.name, - description="unknown insight ID") + status=AnnotationStatus.error_unknown_insight.name, description="unknown insight ID" +) class InsightAnnotator(metaclass=abc.ABCMeta): - def annotate(self, insight: ProductInsight, annotation: int, update=True) \ - -> AnnotationResult: + def annotate( + self, insight: ProductInsight, annotation: int, update=True + ) -> AnnotationResult: insight.annotation = annotation insight.completed_at = datetime.datetime.utcnow() insight.save() if annotation == 1 and update: return self.update_product(insight) - + return SAVED_ANNOTATION_RESULT @abc.abstractmethod @@ -67,18 +79,18 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: class PackagerCodeAnnotator(InsightAnnotator): def update_product(self, insight: ProductInsight) -> AnnotationResult: - emb_code: str = insight.data['text'] + emb_code: str = insight.data["text"] - product = get_product(insight.barcode, ['emb_codes']) + product = get_product(insight.barcode, ["emb_codes"]) if product is None: return MISSING_PRODUCT_RESULT - emb_codes_str: str = product.get('emb_codes', '') + emb_codes_str: str = product.get("emb_codes", "") emb_codes: List[str] = [] if emb_codes_str: - emb_codes = emb_codes_str.split(',') + emb_codes = emb_codes_str.split(",") if self.already_exists(emb_code, emb_codes): return ALREADY_ANNOTATED_RESULT @@ -88,10 +100,8 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: return UPDATED_ANNOTATION_RESULT @staticmethod - def already_exists(new_emb_code: str, - emb_codes: List[str]) -> bool: 
- emb_codes = [normalize_emb_code(emb_code) - for emb_code in emb_codes] + def already_exists(new_emb_code: str, emb_codes: List[str]) -> bool: + emb_codes = [normalize_emb_code(emb_code) for emb_code in emb_codes] normalized_emb_code = normalize_emb_code(new_emb_code) @@ -103,12 +113,12 @@ def already_exists(new_emb_code: str, class LabelAnnotator(InsightAnnotator): def update_product(self, insight: ProductInsight) -> AnnotationResult: - product = get_product(insight.barcode, ['labels_tags']) + product = get_product(insight.barcode, ["labels_tags"]) if product is None: return MISSING_PRODUCT_RESULT - labels_tags: List[str] = product.get('labels_tags') or [] + labels_tags: List[str] = product.get("labels_tags") or [] if insight.value_tag in labels_tags: return ALREADY_ANNOTATED_RESULT @@ -128,14 +138,18 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: try: product_ingredient: ProductIngredient = ( ProductIngredient.select() - .where(ProductIngredient.barcode == barcode) - .get()) + .where(ProductIngredient.barcode == barcode) + .get() + ) except ProductIngredient.DoesNotExist: - logger.warning("Missing product ingredient for product " - "{}".format(barcode)) - return AnnotationResult(status="error_no_matching_ingredient", - description="no ingredient is associated " - "with insight (internal error)") + logger.warning( + "Missing product ingredient for product " "{}".format(barcode) + ) + return AnnotationResult( + status="error_no_matching_ingredient", + description="no ingredient is associated " + "with insight (internal error)", + ) ingredient_str = product_ingredient.ingredients product = get_product(barcode, fields=["ingredients_text"]) @@ -147,18 +161,21 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: expected_ingredients = product.get("ingredients_text") if expected_ingredients != ingredient_str: - logger.warning("ingredients have changed since spellcheck insight " - "creation (product {})".format(barcode)) - return AnnotationResult(status=AnnotationStatus - .error_updated_product.name, - description="the ingredient list has been " - "updated since spellcheck") + logger.warning( + "ingredients have changed since spellcheck insight " + "creation (product {})".format(barcode) + ) + return AnnotationResult( + status=AnnotationStatus.error_updated_product.name, + description="the ingredient list has been " "updated since spellcheck", + ) full_correction = self.generate_full_correction( ingredient_str, - insight.data['start_offset'], - insight.data['end_offset'], - insight.data['correction']) + insight.data["start_offset"], + insight.data["end_offset"], + insight.data["correction"], + ) save_ingredients(barcode, full_correction) self.update_related_insights(insight) @@ -167,54 +184,51 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: return UPDATED_ANNOTATION_RESULT @staticmethod - def generate_full_correction(ingredient_str: str, - start_offset: int, - end_offset: int, - correction: str): - return "{}{}{}".format(ingredient_str[:start_offset], - correction, - ingredient_str[end_offset:]) + def generate_full_correction( + ingredient_str: str, start_offset: int, end_offset: int, correction: str + ): + return "{}{}{}".format( + ingredient_str[:start_offset], correction, ingredient_str[end_offset:] + ) @staticmethod - def generate_snippet(ingredient_str: str, - start_offset: int, - end_offset: int, - correction: str) -> str: + def generate_snippet( + ingredient_str: str, start_offset: int, end_offset: int, correction: 
str + ) -> str: context_len = 15 - return "{}{}{}".format(ingredient_str[start_offset-context_len: - start_offset], - correction, - ingredient_str[end_offset: - end_offset+context_len]) + return "{}{}{}".format( + ingredient_str[start_offset - context_len : start_offset], + correction, + ingredient_str[end_offset : end_offset + context_len], + ) @staticmethod def update_related_insights(insight: ProductInsight): - diff_len = (len(insight.data['correction']) - - len(insight.data['original'])) + diff_len = len(insight.data["correction"]) - len(insight.data["original"]) if diff_len == 0: return with db.atomic(): - for other in (ProductInsight.select() - .where(ProductInsight.barcode == insight.barcode, - ProductInsight.id != insight.id, - ProductInsight.type == - InsightType.ingredient_spellcheck.name)): - if insight.data['start_offset'] <= other.data['start_offset']: - other.data['start_offset'] += diff_len - other.data['end_offset'] += diff_len + for other in ProductInsight.select().where( + ProductInsight.barcode == insight.barcode, + ProductInsight.id != insight.id, + ProductInsight.type == InsightType.ingredient_spellcheck.name, + ): + if insight.data["start_offset"] <= other.data["start_offset"]: + other.data["start_offset"] += diff_len + other.data["end_offset"] += diff_len other.save() class CategoryAnnotator(InsightAnnotator): def update_product(self, insight: ProductInsight) -> AnnotationResult: - product = get_product(insight.barcode, ['categories_tags']) + product = get_product(insight.barcode, ["categories_tags"]) if product is None: return MISSING_PRODUCT_RESULT - categories_tags: List[str] = product.get('categories_tags') or [] + categories_tags: List[str] = product.get("categories_tags") or [] if insight.value_tag in categories_tags: return ALREADY_ANNOTATED_RESULT @@ -227,17 +241,17 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: class ProductWeightAnnotator(InsightAnnotator): def update_product(self, insight: ProductInsight) -> AnnotationResult: - product = get_product(insight.barcode, ['quantity']) + product = get_product(insight.barcode, ["quantity"]) if product is None: return MISSING_PRODUCT_RESULT - quantity: Optional[str] = product.get('quantity') or None + quantity: Optional[str] = product.get("quantity") or None if quantity is not None: return ALREADY_ANNOTATED_RESULT - weight = insight.data['text'] + weight = insight.data["text"] update_quantity(insight.barcode, weight) return UPDATED_ANNOTATION_RESULT @@ -245,14 +259,14 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: class ExpirationDateAnnotator(InsightAnnotator): def update_product(self, insight: ProductInsight) -> AnnotationResult: - expiration_date: str = insight.data['text'] + expiration_date: str = insight.data["text"] - product = get_product(insight.barcode, ['expiration_date']) + product = get_product(insight.barcode, ["expiration_date"]) if product is None: return MISSING_PRODUCT_RESULT - current_expiration_date = product.get('expiration_date') or None + current_expiration_date = product.get("expiration_date") or None if current_expiration_date: return ALREADY_ANNOTATED_RESULT @@ -263,14 +277,14 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: class BrandAnnotator(InsightAnnotator): def update_product(self, insight: ProductInsight) -> AnnotationResult: - brand: str = insight.data['brand'] + brand: str = insight.data["brand"] - product = get_product(insight.barcode, ['brands_tags']) + product = get_product(insight.barcode, 
["brands_tags"]) if product is None: return MISSING_PRODUCT_RESULT - brand_tags: List[str] = product.get('brands_tags') or [] + brand_tags: List[str] = product.get("brands_tags") or [] if brand_tags: # For now, don't annotate if a brand has already been provided @@ -282,15 +296,15 @@ def update_product(self, insight: ProductInsight) -> AnnotationResult: class StoreAnnotator(InsightAnnotator): def update_product(self, insight: ProductInsight) -> AnnotationResult: - store: str = insight.data['store'] + store: str = insight.data["store"] store_tag: str = insight.value_tag - product = get_product(insight.barcode, ['stores_tags']) + product = get_product(insight.barcode, ["stores_tags"]) if product is None: return MISSING_PRODUCT_RESULT - stores_tags: List[str] = product.get('stores_tags') or [] + stores_tags: List[str] = product.get("stores_tags") or [] if store_tag in stores_tags: return ALREADY_ANNOTATED_RESULT diff --git a/robotoff/insights/data.py b/robotoff/insights/data.py index 12a5729002..6a300e4065 100644 --- a/robotoff/insights/data.py +++ b/robotoff/insights/data.py @@ -86,7 +86,7 @@ } BRANDS_BARCODE_RANGE: Dict[str, str] = { - 'boni': "5400141xxxxxx", - 'everyday': "5400141xxxxxx", - 'netto': "325039xxxxxxx", + "boni": "5400141xxxxxx", + "everyday": "5400141xxxxxx", + "netto": "325039xxxxxxx", } diff --git a/robotoff/insights/extraction.py b/robotoff/insights/extraction.py index 0c3378606a..a71686a9c0 100644 --- a/robotoff/insights/extraction.py +++ b/robotoff/insights/extraction.py @@ -1,12 +1,11 @@ +from typing import Dict, List, Optional from urllib.parse import urlparse import requests -from typing import Optional, Dict, List - from PIL import Image -from robotoff.insights._enum import InsightType from robotoff.insights import ocr +from robotoff.insights._enum import InsightType from robotoff.ml.object_detection import ObjectDetectionModelRegistry from robotoff.utils import get_image_from_url, get_logger from robotoff.utils.types import JSONType @@ -14,26 +13,28 @@ logger = get_logger(__name__) -def get_insights_from_image(barcode: str, image_url: str, ocr_url: str) \ - -> Optional[Dict]: +def get_insights_from_image( + barcode: str, image_url: str, ocr_url: str +) -> Optional[Dict]: ocr_insights = extract_ocr_insights(ocr_url) extract_nutriscore = has_nutriscore_insight(ocr_insights) image_ml_insights = extract_image_ml_insights( - image_url, extract_nutriscore=extract_nutriscore) + image_url, extract_nutriscore=extract_nutriscore + ) insight_types = set(ocr_insights.keys()).union(image_ml_insights.keys()) results = {} for insight_type in insight_types: - insights = (ocr_insights.get(insight_type, []) + - image_ml_insights.get(insight_type, [])) + insights = ocr_insights.get(insight_type, []) + image_ml_insights.get( + insight_type, [] + ) - results[insight_type] = generate_insights_dict(insights, - barcode, - insight_type, - image_url) + results[insight_type] = generate_insights_dict( + insights, barcode, insight_type, image_url + ) if not results: return None @@ -42,48 +43,44 @@ def get_insights_from_image(barcode: str, image_url: str, ocr_url: str) \ def has_nutriscore_insight(insights: JSONType) -> bool: - for insight in insights.get('label', []): - if insight['label_tag'] == 'en:nutriscore': + for insight in insights.get("label", []): + if insight["label_tag"] == "en:nutriscore": return True return False -def generate_insights_dict(insights: List[JSONType], - barcode: str, - insight_type: str, - image_url: str): +def generate_insights_dict( + insights: List[JSONType], 
barcode: str, insight_type: str, image_url: str +): image_url_path = urlparse(image_url).path - if image_url_path.startswith('/images/products'): - image_url_path = image_url_path[len("/images/products"):] + if image_url_path.startswith("/images/products"): + image_url_path = image_url_path[len("/images/products") :] return { - 'insights': insights, - 'barcode': barcode, - 'type': insight_type, - 'source': image_url_path, + "insights": insights, + "barcode": barcode, + "type": insight_type, + "source": image_url_path, } -def extract_image_ml_insights(image_url: str, - extract_nutriscore: bool = True) -> JSONType: +def extract_image_ml_insights( + image_url: str, extract_nutriscore: bool = True +) -> JSONType: results: JSONType = {} if extract_nutriscore: image = get_image_from_url(image_url, error_raise=True) - nutriscore_insight = extract_nutriscore_label(image, - manual_threshold=0.5, - automatic_threshold=0.9) + nutriscore_insight = extract_nutriscore_label( + image, manual_threshold=0.5, automatic_threshold=0.9 + ) if not nutriscore_insight: return results - results = { - 'label': [ - nutriscore_insight - ] - } + results = {"label": [nutriscore_insight]} return results @@ -106,13 +103,15 @@ def extract_ocr_insights(ocr_url: str) -> JSONType: results = {} - for insight_type in (InsightType.label.name, - InsightType.packager_code.name, - InsightType.product_weight.name, - InsightType.image_flag.name, - InsightType.expiration_date.name, - InsightType.brand.name, - InsightType.store.name): + for insight_type in ( + InsightType.label.name, + InsightType.packager_code.name, + InsightType.product_weight.name, + InsightType.image_flag.name, + InsightType.expiration_date.name, + InsightType.brand.name, + InsightType.store.name, + ): insights = ocr.extract_insights(ocr_result, insight_type) if insights: @@ -121,10 +120,10 @@ def extract_ocr_insights(ocr_url: str) -> JSONType: return results -def extract_nutriscore_label(image: Image.Image, - manual_threshold: float, - automatic_threshold: float) -> Optional[JSONType]: - model = ObjectDetectionModelRegistry.get('nutriscore') +def extract_nutriscore_label( + image: Image.Image, manual_threshold: float, automatic_threshold: float +) -> Optional[JSONType]: + model = ObjectDetectionModelRegistry.get("nutriscore") raw_result = model.detect_from_image(image, output_image=False) results = raw_result.select(threshold=manual_threshold) @@ -139,13 +138,13 @@ def extract_nutriscore_label(image: Image.Image, score = result.score automatic_processing = score >= automatic_threshold - label_tag = 'en:{}'.format(result.label) + label_tag = "en:{}".format(result.label) return { - 'label_tag': label_tag, - 'notify': True, - 'automatic_processing': automatic_processing, - 'confidence': score, - 'bounding_box': result.bounding_box, - 'model': 'nutriscore', + "label_tag": label_tag, + "notify": True, + "automatic_processing": automatic_processing, + "confidence": score, + "bounding_box": result.bounding_box, + "model": "nutriscore", } diff --git a/robotoff/insights/importer.py b/robotoff/insights/importer.py index 85d02f1838..1c3176b68c 100644 --- a/robotoff/insights/importer.py +++ b/robotoff/insights/importer.py @@ -1,13 +1,13 @@ import abc import datetime import uuid -from typing import Dict, Iterable, List, Set, Optional, Callable +from typing import Callable, Dict, Iterable, List, Optional, Set from robotoff.insights._enum import InsightType from robotoff.insights.data import AUTHORIZED_LABELS, BRANDS_BARCODE_RANGE from robotoff.insights.normalize import 
normalize_emb_code -from robotoff.models import batch_insert, ProductInsight, ProductIngredient -from robotoff.products import ProductStore, Product +from robotoff.models import ProductIngredient, ProductInsight, batch_insert +from robotoff.products import Product, ProductStore from robotoff.taxonomy import TAXONOMY_STORES, Taxonomy, TaxonomyNode from robotoff.utils import get_logger, jsonl_iter, jsonl_iter_fp from robotoff.utils.types import JSONType @@ -41,8 +41,9 @@ def need_validation(insight: JSONType) -> bool: return True @staticmethod - def _deduplicate_insights(data: Iterable[Dict], - key_func: Callable) -> Iterable[Dict]: + def _deduplicate_insights( + data: Iterable[Dict], key_func: Callable + ) -> Iterable[Dict]: seen: Set = set() for item in data: value = key_func(item) @@ -67,41 +68,45 @@ def import_insights(self, data: Iterable[Dict], automatic: bool = False) -> int: inserted = 0 for item in data: - barcode = item['barcode'] - corrections = item['corrections'] - text = item['text'] + barcode = item["barcode"] + corrections = item["corrections"] + text = item["text"] if barcode not in barcode_seen: - product_ingredients.append({ - 'barcode': barcode, - 'ingredients': item['text'], - }) + product_ingredients.append( + { + "barcode": barcode, + "ingredients": item["text"], + } + ) barcode_seen.add(barcode) for correction in corrections: - start_offset = correction['start_offset'] - end_offset = correction['end_offset'] + start_offset = correction["start_offset"] + end_offset = correction["end_offset"] key = (barcode, start_offset, end_offset) if key not in insight_seen: - original_snippet = self.generate_snippet(text, - start_offset, end_offset, - correction['original']) - corrected_snippet = self.generate_snippet(text, - start_offset, end_offset, - correction['correction']) - insights.append({ - 'id': str(uuid.uuid4()), - 'type': InsightType.ingredient_spellcheck.name, - 'barcode': barcode, - 'timestamp': timestamp, - 'automatic_processing': False, - 'data': { - **correction, - 'original_snippet': original_snippet, - 'corrected_snippet': corrected_snippet, - }, - }) + original_snippet = self.generate_snippet( + text, start_offset, end_offset, correction["original"] + ) + corrected_snippet = self.generate_snippet( + text, start_offset, end_offset, correction["correction"] + ) + insights.append( + { + "id": str(uuid.uuid4()), + "type": InsightType.ingredient_spellcheck.name, + "barcode": barcode, + "timestamp": timestamp, + "automatic_processing": False, + "data": { + **correction, + "original_snippet": original_snippet, + "corrected_snippet": corrected_snippet, + }, + } + ) insight_seen.add(key) if len(product_ingredients) >= 50: @@ -117,14 +122,15 @@ def import_insights(self, data: Iterable[Dict], automatic: bool = False) -> int: return inserted @staticmethod - def generate_snippet(ingredient_str: str, - start_offset: int, - end_offset: int, - correction: str) -> str: + def generate_snippet( + ingredient_str: str, start_offset: int, end_offset: int, correction: str + ) -> str: context_len = 15 - return "{}{}{}".format(ingredient_str[start_offset-context_len:start_offset], - correction, - ingredient_str[end_offset:end_offset+context_len]) + return "{}{}{}".format( + ingredient_str[start_offset - context_len : start_offset], + correction, + ingredient_str[end_offset : end_offset + context_len], + ) GroupedByOCRInsights = Dict[str, List] @@ -139,32 +145,34 @@ def import_insights(self, data: Iterable[Dict], automatic: bool = False) -> int: for barcode, insights in 
grouped_by.items(): insights = list(self.deduplicate_insights(insights)) insights = self.sort_by_priority(insights) - inserts += list(self._process_product_insights(barcode, insights, - timestamp, - automatic)) + inserts += list( + self._process_product_insights(barcode, insights, timestamp, automatic) + ) return batch_insert(ProductInsight, inserts, 50) - def _process_product_insights(self, barcode: str, - insights: List[JSONType], - timestamp: datetime.datetime, - automatic: bool) -> \ - Iterable[JSONType]: - countries_tags = getattr(self.product_store[barcode], - 'countries_tags', []) - brands_tags = getattr(self.product_store[barcode], - 'brands_tags', []) + def _process_product_insights( + self, + barcode: str, + insights: List[JSONType], + timestamp: datetime.datetime, + automatic: bool, + ) -> Iterable[JSONType]: + countries_tags = getattr(self.product_store[barcode], "countries_tags", []) + brands_tags = getattr(self.product_store[barcode], "brands_tags", []) for insight in self.process_product_insights(barcode, insights): - insight['id'] = str(uuid.uuid4()) - insight['barcode'] = barcode - insight['timestamp'] = timestamp - insight['type'] = self.get_type() - insight['countries'] = countries_tags - insight['brands'] = brands_tags - - if 'automatic_processing' not in insight: - insight['automatic_processing'] = automatic and not self.need_validation(insight) + insight["id"] = str(uuid.uuid4()) + insight["barcode"] = barcode + insight["timestamp"] = timestamp + insight["type"] = self.get_type() + insight["countries"] = countries_tags + insight["brands"] = brands_tags + + if "automatic_processing" not in insight: + insight[ + "automatic_processing" + ] = automatic and not self.need_validation(insight) yield insight @@ -173,134 +181,130 @@ def group_by_barcode(self, data: Iterable[Dict]) -> GroupedByOCRInsights: insight_type = self.get_type() for item in data: - barcode = item['barcode'] - source = item['source'] + barcode = item["barcode"] + source = item["source"] - if item['type'] != insight_type: - raise ValueError("unexpected insight type: " - "'{}'".format(insight_type)) + if item["type"] != insight_type: + raise ValueError( + "unexpected insight type: " "'{}'".format(insight_type) + ) - insights = item['insights'] + insights = item["insights"] if not insights: continue grouped_by.setdefault(barcode, []) - grouped_by[barcode] += [{ - 'source': source, - 'barcode': barcode, - 'type': insight_type, - 'content': i, - } for i in insights] + grouped_by[barcode] += [ + { + "source": source, + "barcode": barcode, + "type": insight_type, + "content": i, + } + for i in insights + ] return grouped_by @staticmethod def sort_by_priority(insights: List[JSONType]) -> List[JSONType]: - return sorted(insights, - key=lambda insight: insight.get('priority', 1)) + return sorted(insights, key=lambda insight: insight.get("priority", 1)) @abc.abstractmethod - def process_product_insights(self, barcode: str, - insights: List[JSONType]) \ - -> Iterable[JSONType]: + def process_product_insights( + self, barcode: str, insights: List[JSONType] + ) -> Iterable[JSONType]: pass @abc.abstractmethod - def deduplicate_insights(self, data: Iterable[JSONType]) -> \ - Iterable[JSONType]: + def deduplicate_insights(self, data: Iterable[JSONType]) -> Iterable[JSONType]: pass class PackagerCodeInsightImporter(OCRInsightImporter): - def deduplicate_insights(self, - data: Iterable[JSONType]) -> Iterable[JSONType]: - yield from self._deduplicate_insights(data, - lambda x: x['content']['text']) + def 
deduplicate_insights(self, data: Iterable[JSONType]) -> Iterable[JSONType]: + yield from self._deduplicate_insights(data, lambda x: x["content"]["text"]) @staticmethod def get_type() -> str: return InsightType.packager_code.name - def is_valid(self, barcode: str, - emb_code: str, - code_seen: Set[str]) -> bool: + def is_valid(self, barcode: str, emb_code: str, code_seen: Set[str]) -> bool: product: Optional[Product] = self.product_store[barcode] - product_emb_codes_tags = getattr(product, 'emb_codes_tags', []) + product_emb_codes_tags = getattr(product, "emb_codes_tags", []) normalized_emb_code = normalize_emb_code(emb_code) - normalized_emb_codes = [normalize_emb_code(c) - for c in product_emb_codes_tags] + normalized_emb_codes = [normalize_emb_code(c) for c in product_emb_codes_tags] if normalized_emb_code in normalized_emb_codes: return False if emb_code in code_seen: return False - + return True - - def process_product_insights(self, barcode: str, - insights: List[JSONType]) \ - -> Iterable[JSONType]: + + def process_product_insights( + self, barcode: str, insights: List[JSONType] + ) -> Iterable[JSONType]: code_seen: Set[str] = set() - for t in (ProductInsight.select(ProductInsight.data['text'] - .as_json().alias('text')) - .where(ProductInsight.type == - self.get_type(), - ProductInsight.barcode == - barcode)).iterator(): + for t in ( + ProductInsight.select( + ProductInsight.data["text"].as_json().alias("text") + ).where( + ProductInsight.type == self.get_type(), + ProductInsight.barcode == barcode, + ) + ).iterator(): code_seen.add(t.text) for insight in insights: - content = insight['content'] - emb_code = content['text'] - + content = insight["content"] + emb_code = content["text"] + if not self.is_valid(barcode, emb_code, code_seen): continue - source = insight['source'] + source = insight["source"] yield { - 'source_image': source, - 'data': { - 'source': source, - 'matcher_type': content['type'], - 'raw': content['raw'], - 'text': emb_code, - 'notify': content['notify'], - } + "source_image": source, + "data": { + "source": source, + "matcher_type": content["type"], + "raw": content["raw"], + "text": emb_code, + "notify": content["notify"], + }, } code_seen.add(emb_code) @staticmethod def need_validation(insight: JSONType) -> bool: - if insight['type'] != PackagerCodeInsightImporter.get_type(): - raise ValueError("insight must be of type " - "{}".format(PackagerCodeInsightImporter - .get_type())) + if insight["type"] != PackagerCodeInsightImporter.get_type(): + raise ValueError( + "insight must be of type " + "{}".format(PackagerCodeInsightImporter.get_type()) + ) - if insight['data']['matcher_type'] in ('eu_fr', 'eu_de', 'fr_emb'): + if insight["data"]["matcher_type"] in ("eu_fr", "eu_de", "fr_emb"): return False return True class LabelInsightImporter(OCRInsightImporter): - def deduplicate_insights(self, - data: Iterable[JSONType]) -> Iterable[JSONType]: - yield from self._deduplicate_insights( - data, lambda x: x['content']['label_tag']) + def deduplicate_insights(self, data: Iterable[JSONType]) -> Iterable[JSONType]: + yield from self._deduplicate_insights(data, lambda x: x["content"]["label_tag"]) @staticmethod def get_type() -> str: return InsightType.label.name - def is_valid(self, barcode: str, - label_tag: str, - label_seen: Set[str]) -> bool: + def is_valid(self, barcode: str, label_tag: str, label_seen: Set[str]) -> bool: product = self.product_store[barcode] - product_labels_tags = getattr(product, 'labels_tags', []) + product_labels_tags = getattr(product, 
"labels_tags", []) if label_tag in product_labels_tags: return False @@ -310,63 +314,60 @@ def is_valid(self, barcode: str, # Check that the predicted label is not a parent of a # current/already predicted label - label_taxonomy: Taxonomy = TAXONOMY_STORES[ - InsightType.label.name].get() + label_taxonomy: Taxonomy = TAXONOMY_STORES[InsightType.label.name].get() if label_tag in label_taxonomy: label_node: TaxonomyNode = label_taxonomy[label_tag] - to_check_labels = (set(product_labels_tags) - .union(label_seen)) - for other_label_node in (label_taxonomy[to_check_label] - for to_check_label - in to_check_labels): - if (other_label_node is not None and - other_label_node.is_child_of(label_node)): + to_check_labels = set(product_labels_tags).union(label_seen) + for other_label_node in ( + label_taxonomy[to_check_label] for to_check_label in to_check_labels + ): + if other_label_node is not None and other_label_node.is_child_of( + label_node + ): return False - + return True - - def process_product_insights(self, barcode: str, - insights: List[JSONType]) \ - -> Iterable[JSONType]: + + def process_product_insights( + self, barcode: str, insights: List[JSONType] + ) -> Iterable[JSONType]: label_seen: Set[str] = set() - for t in (ProductInsight.select(ProductInsight.value_tag) - .where(ProductInsight.type == - self.get_type(), - ProductInsight.barcode == - barcode)).iterator(): + for t in ( + ProductInsight.select(ProductInsight.value_tag).where( + ProductInsight.type == self.get_type(), + ProductInsight.barcode == barcode, + ) + ).iterator(): label_seen.add(t.value_tag) for insight in insights: - barcode = insight['barcode'] - content = insight['content'] - label_tag = content['label_tag'] + barcode = insight["barcode"] + content = insight["content"] + label_tag = content["label_tag"] if not self.is_valid(barcode, label_tag, label_seen): continue - source = insight['source'] - automatic_processing = content.pop('automatic_processing', None) + source = insight["source"] + automatic_processing = content.pop("automatic_processing", None) insert = { - 'value_tag': label_tag, - 'source_image': source, - 'data': { - 'source': source, - **content - } + "value_tag": label_tag, + "source_image": source, + "data": {"source": source, **content}, } if automatic_processing is not None: - insert['automatic_processing'] = automatic_processing + insert["automatic_processing"] = automatic_processing yield insert label_seen.add(label_tag) @staticmethod def need_validation(insight: JSONType) -> bool: - if insight['data']['label_tag'] in AUTHORIZED_LABELS: + if insight["data"]["label_tag"] in AUTHORIZED_LABELS: return False return True @@ -381,109 +382,110 @@ def import_insights(self, data: Iterable[Dict], automatic: bool = False) -> int: inserts = self.process_product_insights(data, automatic) return batch_insert(ProductInsight, inserts, 50) - def process_product_insights(self, insights: Iterable[JSONType], - automatic: bool) \ - -> Iterable[JSONType]: + def process_product_insights( + self, insights: Iterable[JSONType], automatic: bool + ) -> Iterable[JSONType]: category_seen: Dict[str, Set[str]] = {} - for t in (ProductInsight.select(ProductInsight.value_tag, - ProductInsight.barcode) - .where(ProductInsight.type == - self.get_type())).iterator(): + for t in ( + ProductInsight.select( + ProductInsight.value_tag, ProductInsight.barcode + ).where(ProductInsight.type == self.get_type()) + ).iterator(): category_seen.setdefault(t.barcode, set()) category_seen[t.barcode].add(t.value_tag) timestamp = 
datetime.datetime.utcnow() for insight in insights: - barcode = insight['barcode'] - category = insight['category'] + barcode = insight["barcode"] + category = insight["category"] if not self.is_valid(barcode, category, category_seen): continue - countries_tags = getattr(self.product_store[barcode], - 'countries_tags', []) - brands_tags = getattr(self.product_store[barcode], - 'brands_tags', []) + countries_tags = getattr(self.product_store[barcode], "countries_tags", []) + brands_tags = getattr(self.product_store[barcode], "brands_tags", []) insert = { - 'id': str(uuid.uuid4()), - 'type': self.get_type(), - 'barcode': barcode, - 'countries': countries_tags, - 'brands': brands_tags, - 'timestamp': timestamp, - 'value_tag': category, - 'automatic_processing': False, - 'data': { - 'category': category, - } + "id": str(uuid.uuid4()), + "type": self.get_type(), + "barcode": barcode, + "countries": countries_tags, + "brands": brands_tags, + "timestamp": timestamp, + "value_tag": category, + "automatic_processing": False, + "data": { + "category": category, + }, } - if 'category_depth' in insight: - insert['data']['category_depth'] = insight['category_depth'] + if "category_depth" in insight: + insert["data"]["category_depth"] = insight["category_depth"] - if 'model' in insight: - insert['data']['model'] = insight['model'] + if "model" in insight: + insert["data"]["model"] = insight["model"] - if 'confidence' in insight: - insert['data']['confidence'] = insight['confidence'] + if "confidence" in insight: + insert["data"]["confidence"] = insight["confidence"] - if 'product_name' in insight: - insert['data']['product_name'] = insight['product_name'] + if "product_name" in insight: + insert["data"]["product_name"] = insight["product_name"] - if 'matcher_lang' in insight: - insert['data']['matcher_lang'] = insight['matcher_lang'] + if "matcher_lang" in insight: + insert["data"]["matcher_lang"] = insight["matcher_lang"] yield insert category_seen.setdefault(barcode, set()) category_seen[barcode].add(category) - def is_valid(self, barcode: str, - category: str, - category_seen: Dict[str, Set[str]]): + def is_valid(self, barcode: str, category: str, category_seen: Dict[str, Set[str]]): product = self.product_store[barcode] - product_categories_tags = getattr(product, 'categories_tags', []) + product_categories_tags = getattr(product, "categories_tags", []) if category in product_categories_tags: - logger.debug("The product already belongs to this category, " - "considering the insight as invalid") + logger.debug( + "The product already belongs to this category, " + "considering the insight as invalid" + ) return False if category in category_seen.get(barcode, set()): - logger.debug("An insight already exists for this product and " - "category, considering the insight as invalid") + logger.debug( + "An insight already exists for this product and " + "category, considering the insight as invalid" + ) return False # Check that the predicted category is not a parent of a # current/already predicted category - category_taxonomy: Taxonomy = TAXONOMY_STORES[ - InsightType.category.name].get() + category_taxonomy: Taxonomy = TAXONOMY_STORES[InsightType.category.name].get() if category in category_taxonomy: category_node: TaxonomyNode = category_taxonomy[category] - to_check_categories = (set(product_categories_tags) - .union(category_seen.get(barcode, - set()))) - for other_category_node in (category_taxonomy[to_check_category] - for to_check_category - in to_check_categories): - if (other_category_node is not 
None and - other_category_node.is_child_of(category_node)): + to_check_categories = set(product_categories_tags).union( + category_seen.get(barcode, set()) + ) + for other_category_node in ( + category_taxonomy[to_check_category] + for to_check_category in to_check_categories + ): + if other_category_node is not None and other_category_node.is_child_of( + category_node + ): logger.debug( "The predicted category is a parent of the product " "category or of the predicted category of an insight, " - "considering the insight as invalid") + "considering the insight as invalid" + ) return False return True class ProductWeightImporter(OCRInsightImporter): - def deduplicate_insights(self, - data: Iterable[JSONType]) -> Iterable[JSONType]: - yield from self._deduplicate_insights( - data, lambda x: x['content']['text']) + def deduplicate_insights(self, data: Iterable[JSONType]) -> Iterable[JSONType]: + yield from self._deduplicate_insights(data, lambda x: x["content"]["text"]) @staticmethod def get_type() -> str: @@ -493,8 +495,7 @@ def is_valid(self, barcode: str, weight_value_str: str) -> bool: try: weight_value = float(weight_value_str) except ValueError: - logger.warn("Weight value is not a float: {}" - "".format(weight_value_str)) + logger.warn("Weight value is not a float: {}" "".format(weight_value_str)) return False if weight_value <= 0: @@ -502,8 +503,10 @@ def is_valid(self, barcode: str, weight_value_str: str) -> bool: return False if float(int(weight_value)) != weight_value: - logger.info("Weight value is not an integer ({}), " - "returning non valid".format(weight_value)) + logger.info( + "Weight value is not an integer ({}), " + "returning non valid".format(weight_value) + ) return False product = self.product_store[barcode] @@ -512,8 +515,7 @@ def is_valid(self, barcode: str, weight_value_str: str) -> bool: return True if product.quantity is not None: - logger.debug("Product quantity field is not null, returning " - "non valid") + logger.debug("Product quantity field is not null, returning " "non valid") return False return True @@ -523,49 +525,52 @@ def group_by_subtype(insights: List[JSONType]) -> Dict[str, List[JSONType]]: insights_by_subtype: Dict[str, List[JSONType]] = {} for insight in insights: - matcher_type = insight['content']['matcher_type'] + matcher_type = insight["content"]["matcher_type"] insights_by_subtype.setdefault(matcher_type, []) insights_by_subtype[matcher_type].append(insight) return insights_by_subtype - def process_product_insights(self, barcode: str, - insights: List[JSONType]) \ - -> Iterable[JSONType]: + def process_product_insights( + self, barcode: str, insights: List[JSONType] + ) -> Iterable[JSONType]: if not insights: return insights_by_subtype = self.group_by_subtype(insights) insight = insights[0] - insight_subtype = insight['content']['matcher_type'] - - if (insight_subtype != 'with_mention' and - len(insights_by_subtype[insight_subtype]) > 1): - logger.info("{} distinct product weights found for product " - "{}, aborting import".format(len(insights), - barcode)) + insight_subtype = insight["content"]["matcher_type"] + + if ( + insight_subtype != "with_mention" + and len(insights_by_subtype[insight_subtype]) > 1 + ): + logger.info( + "{} distinct product weights found for product " + "{}, aborting import".format(len(insights), barcode) + ) return - if ProductInsight.select().where(ProductInsight.type == - self.get_type(), - ProductInsight.barcode == - barcode).count(): + if ( + ProductInsight.select() + .where( + ProductInsight.type == 
self.get_type(), + ProductInsight.barcode == barcode, + ) + .count() + ): return - content = insight['content'] + content = insight["content"] - if not self.is_valid(barcode, content['value']): + if not self.is_valid(barcode, content["value"]): return - source = insight['source'] + source = insight["source"] yield { - 'source_image': source, - 'data': { - 'source': source, - 'notify': content['notify'], - **content - } + "source_image": source, + "data": {"source": source, "notify": content["notify"], **content}, } @staticmethod @@ -574,10 +579,8 @@ def need_validation(insight: JSONType) -> bool: class ExpirationDateImporter(OCRInsightImporter): - def deduplicate_insights(self, - data: Iterable[JSONType]) -> Iterable[JSONType]: - yield from self._deduplicate_insights( - data, lambda x: x['content']['text']) + def deduplicate_insights(self, data: Iterable[JSONType]) -> Iterable[JSONType]: + yield from self._deduplicate_insights(data, lambda x: x["content"]["text"]) @staticmethod def get_type() -> str: @@ -590,41 +593,43 @@ def is_valid(self, barcode: str) -> bool: return True if product.expiration_date: - logger.debug("Product expiration date field is not null, returning " - "non valid") + logger.debug( + "Product expiration date field is not null, returning " "non valid" + ) return False return True - def process_product_insights(self, barcode: str, - insights: List[JSONType]) \ - -> Iterable[JSONType]: + def process_product_insights( + self, barcode: str, insights: List[JSONType] + ) -> Iterable[JSONType]: if len(insights) > 1: - logger.info("{} distinct expiration dates found for product " - "{}, aborting import".format(len(insights), - barcode)) + logger.info( + "{} distinct expiration dates found for product " + "{}, aborting import".format(len(insights), barcode) + ) return - if ProductInsight.select().where(ProductInsight.type == - self.get_type(), - ProductInsight.barcode == - barcode).count(): + if ( + ProductInsight.select() + .where( + ProductInsight.type == self.get_type(), + ProductInsight.barcode == barcode, + ) + .count() + ): return for insight in insights: - content = insight['content'] + content = insight["content"] if not self.is_valid(barcode): continue - source = insight['source'] + source = insight["source"] yield { - 'source_image': source, - 'data': { - 'source': source, - 'notify': content['notify'], - **content - } + "source_image": source, + "data": {"source": source, "notify": content["notify"], **content}, } break @@ -634,24 +639,22 @@ def need_validation(insight: JSONType) -> bool: class BrandInsightImporter(OCRInsightImporter): - def deduplicate_insights(self, - data: Iterable[JSONType]) -> Iterable[JSONType]: - yield from self._deduplicate_insights( - data, lambda x: x['content']['brand_tag']) + def deduplicate_insights(self, data: Iterable[JSONType]) -> Iterable[JSONType]: + yield from self._deduplicate_insights(data, lambda x: x["content"]["brand_tag"]) @staticmethod def get_type() -> str: return InsightType.brand.name - def is_valid(self, barcode: str, - brand_tag: str, - brand_seen: Set[str]) -> bool: + def is_valid(self, barcode: str, brand_tag: str, brand_seen: Set[str]) -> bool: if brand_tag in brand_seen: return False if not self.in_barcode_range(brand_tag, barcode): - logger.warn("Barcode {} of brand {} not in barcode " - "range".format(barcode, brand_tag)) + logger.warn( + "Barcode {} of brand {} not in barcode " + "range".format(barcode, brand_tag) + ) return False product = self.product_store[barcode] @@ -665,41 +668,42 @@ def is_valid(self, 
barcode: str, return True - def process_product_insights(self, barcode: str, - insights: List[JSONType]) \ - -> Iterable[JSONType]: + def process_product_insights( + self, barcode: str, insights: List[JSONType] + ) -> Iterable[JSONType]: brand_seen: Set[str] = set() - for t in (ProductInsight.select(ProductInsight.value_tag) - .where(ProductInsight.type == - self.get_type(), - ProductInsight.barcode == - barcode)).iterator(): + for t in ( + ProductInsight.select(ProductInsight.value_tag).where( + ProductInsight.type == self.get_type(), + ProductInsight.barcode == barcode, + ) + ).iterator(): brand_seen.add(t.value_tag) for insight in insights: - barcode = insight['barcode'] - content = insight['content'] - brand_tag = content['brand_tag'] + barcode = insight["barcode"] + content = insight["content"] + brand_tag = content["brand_tag"] if not self.is_valid(barcode, brand_tag, brand_seen): continue - source = insight['source'] + source = insight["source"] insert = { - 'value_tag': brand_tag, - 'source_image': source, - 'data': { - 'source': source, - 'brand_tag': brand_tag, - 'text': content.get('text'), - 'brand': content['brand'], - 'notify': content['notify'], - } + "value_tag": brand_tag, + "source_image": source, + "data": { + "source": source, + "brand_tag": brand_tag, + "text": content.get("text"), + "brand": content["brand"], + "notify": content["notify"], + }, } - if 'automatic_processing' in content: - insert['automatic_processing'] = content['automatic_processing'] + if "automatic_processing" in content: + insert["automatic_processing"] = content["automatic_processing"] yield insert brand_seen.add(brand_tag) @@ -723,7 +727,7 @@ def in_barcode_range(brand_tag: str, barcode: str) -> bool: logger.debug("Barcode range and barcode do not have the same length") return False - barcode_range = barcode_range.replace('x', '') + barcode_range = barcode_range.replace("x", "") if barcode.startswith(barcode_range): return True @@ -733,57 +737,54 @@ def in_barcode_range(brand_tag: str, barcode: str) -> bool: class StoreInsightImporter(OCRInsightImporter): - def deduplicate_insights(self, - data: Iterable[JSONType]) -> Iterable[JSONType]: - yield from self._deduplicate_insights( - data, lambda x: x['content']['store_tag']) + def deduplicate_insights(self, data: Iterable[JSONType]) -> Iterable[JSONType]: + yield from self._deduplicate_insights(data, lambda x: x["content"]["store_tag"]) @staticmethod def get_type() -> str: return InsightType.store.name - def is_valid(self, - store_tag: str, - store_seen: Set[str]) -> bool: + def is_valid(self, store_tag: str, store_seen: Set[str]) -> bool: if store_tag in store_seen: return False return True - def process_product_insights(self, barcode: str, - insights: List[JSONType]) \ - -> Iterable[JSONType]: + def process_product_insights( + self, barcode: str, insights: List[JSONType] + ) -> Iterable[JSONType]: store_seen: Set[str] = set() - for t in (ProductInsight.select(ProductInsight.value_tag) - .where(ProductInsight.type == - self.get_type(), - ProductInsight.barcode == - barcode)).iterator(): + for t in ( + ProductInsight.select(ProductInsight.value_tag).where( + ProductInsight.type == self.get_type(), + ProductInsight.barcode == barcode, + ) + ).iterator(): store_seen.add(t.value_tag) for insight in insights: - content = insight['content'] - store_tag = content['store_tag'] + content = insight["content"] + store_tag = content["store_tag"] if not self.is_valid(store_tag, store_seen): continue - source = insight['source'] + source = insight["source"] 
insert = { - 'value_tag': store_tag, - 'source_image': source, - 'data': { - 'source': source, - 'store_tag': store_tag, - 'text': content['text'], - 'store': content['store'], - 'notify': content['notify'], - } + "value_tag": store_tag, + "source_image": source, + "data": { + "source": source, + "store_tag": store_tag, + "text": content["text"], + "store": content["store"], + "notify": content["notify"], + }, } - if 'automatic_processing' in content: - insert['automatic_processing'] = content['automatic_processing'] + if "automatic_processing" in content: + insert["automatic_processing"] = content["automatic_processing"] yield insert store_seen.add(store_tag) @@ -806,8 +807,9 @@ class InsightImporterFactory: } @classmethod - def create(cls, insight_type: str, - product_store: Optional[ProductStore]) -> InsightImporter: + def create( + cls, insight_type: str, product_store: Optional[ProductStore] + ) -> InsightImporter: if insight_type in cls.importers: return cls.importers[insight_type](product_store) else: diff --git a/robotoff/insights/normalize.py b/robotoff/insights/normalize.py index 037059b195..df98cd1c97 100644 --- a/robotoff/insights/normalize.py +++ b/robotoff/insights/normalize.py @@ -2,15 +2,13 @@ def normalize_emb_code(emb_code: str): - emb_code = (emb_code.strip() - .lower() - .replace(' ', '') - .replace('-', '') - .replace('.', '')) + emb_code = ( + emb_code.strip().lower().replace(" ", "").replace("-", "").replace(".", "") + ) emb_code = strip_accents_ascii(emb_code) - if emb_code.endswith('ce'): - emb_code = emb_code[:-2] + 'ec' + if emb_code.endswith("ce"): + emb_code = emb_code[:-2] + "ec" return emb_code diff --git a/robotoff/insights/ocr/__init__.py b/robotoff/insights/ocr/__init__.py index b7f6e836c4..a32ffb7217 100644 --- a/robotoff/insights/ocr/__init__.py +++ b/robotoff/insights/ocr/__init__.py @@ -1,4 +1,2 @@ -from .core import (ocr_iter, - extract_insights, - get_barcode_from_path) +from .core import extract_insights, get_barcode_from_path, ocr_iter from .dataclass import OCRResult diff --git a/robotoff/insights/ocr/brand.py b/robotoff/insights/ocr/brand.py index 90279d6d24..7b5f9ae299 100644 --- a/robotoff/insights/ocr/brand.py +++ b/robotoff/insights/ocr/brand.py @@ -1,9 +1,9 @@ import re -from typing import List, Dict, Tuple, Set +from typing import Dict, List, Set, Tuple from robotoff import settings -from robotoff.insights.ocr.dataclass import OCRResult, OCRRegex, OCRField -from robotoff.utils import text_file_iter, get_logger +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult +from robotoff.utils import get_logger, text_file_iter logger = get_logger(__name__) @@ -12,8 +12,8 @@ def get_logo_annotation_brands() -> Dict[str, str]: brands: Dict[str, str] = {} for item in text_file_iter(settings.OCR_LOGO_ANNOTATION_BRANDS_DATA_PATH): - if '||' in item: - logo_description, label_tag = item.split('||') + if "||" in item: + logo_description, label_tag = item.split("||") else: logger.warn("'||' separator expected!") continue @@ -27,10 +27,7 @@ def get_logo_annotation_brands() -> Dict[str, str]: def get_brand_tag(brand: str) -> str: - return (brand.lower() - .replace(' & ', '-') - .replace(' ', '-') - .replace("'", '-')) + return brand.lower().replace(" & ", "-").replace(" ", "-").replace("'", "-") def brand_sort_key(item): @@ -47,8 +44,8 @@ def get_sorted_brands() -> List[Tuple[str, str]]: sorted_brands: Dict[str, str] = {} for item in text_file_iter(settings.OCR_BRANDS_DATA_PATH): - if '||' in item: - brand, regex_str = item.split('||') + 
if "||" in item: + brand, regex_str = item.split("||") else: brand = item regex_str = re.escape(item.lower()) @@ -59,13 +56,15 @@ def get_sorted_brands() -> List[Tuple[str, str]]: SORTED_BRANDS = get_sorted_brands() -BRAND_REGEX_STR = "|".join(r"((? List[Dict]: @@ -82,25 +81,29 @@ def find_brands(ocr_result: OCRResult) -> List[Dict]: for idx, match_str in enumerate(groups): if match_str is not None: brand, _ = SORTED_BRANDS[idx] - results.append({ - 'brand': brand, - 'brand_tag': get_brand_tag(brand), - 'text': match_str, - 'notify': brand not in NOTIFY_BRANDS_WHITELIST, - }) + results.append( + { + "brand": brand, + "brand_tag": get_brand_tag(brand), + "text": match_str, + "notify": brand not in NOTIFY_BRANDS_WHITELIST, + } + ) return results for logo_annotation in ocr_result.logo_annotations: if logo_annotation.description in LOGO_ANNOTATION_BRANDS: brand = LOGO_ANNOTATION_BRANDS[logo_annotation.description] - results.append({ - 'brand': brand, - 'brand_tag': get_brand_tag(brand), - 'automatic_processing': False, - 'confidence': logo_annotation.score, - 'model': 'google-cloud-vision', - }) + results.append( + { + "brand": brand, + "brand_tag": get_brand_tag(brand), + "automatic_processing": False, + "confidence": logo_annotation.score, + "model": "google-cloud-vision", + } + ) return results return results diff --git a/robotoff/insights/ocr/core.py b/robotoff/insights/ocr/core.py index 1ca508b74f..24e559724d 100644 --- a/robotoff/insights/ocr/core.py +++ b/robotoff/insights/ocr/core.py @@ -1,9 +1,8 @@ # -*- coding: utf-8 -*- import gzip import json - import pathlib as pathlib -from typing import List, Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import requests @@ -27,7 +26,7 @@ def get_barcode_from_path(path: str) -> Optional[str]: - barcode = '' + barcode = "" for parent in pathlib.Path(path).parents: if parent.name.isdigit(): @@ -39,14 +38,15 @@ def get_barcode_from_path(path: str) -> Optional[str]: def fetch_images_for_ean(ean: str): - url = "https://world.openfoodfacts.org/api/v0/product/" \ - "{}.json?fields=images".format(ean) + url = ( + "https://world.openfoodfacts.org/api/v0/product/" + "{}.json?fields=images".format(ean) + ) images = requests.get(url).json() return images -def get_json_for_image(barcode: str, image_name: str) -> \ - Optional[JSONType]: +def get_json_for_image(barcode: str, image_name: str) -> Optional[JSONType]: url = generate_json_ocr_url(barcode, image_name) r = requests.get(url) @@ -56,8 +56,7 @@ def get_json_for_image(barcode: str, image_name: str) -> \ return r.json() -def extract_insights(ocr_result: OCRResult, - insight_type: str) -> List[Dict]: +def extract_insights(ocr_result: OCRResult, insight_type: str) -> List[Dict]: if insight_type == InsightType.packager_code.name: return find_packager_codes(ocr_result) @@ -76,10 +75,10 @@ def extract_insights(ocr_result: OCRResult, elif insight_type == InsightType.product_weight.name: return find_product_weight(ocr_result) - elif insight_type == 'trace': + elif insight_type == "trace": return find_traces(ocr_result) - elif insight_type == 'nutrient': + elif insight_type == "nutrient": return find_nutrient_values(ocr_result) elif insight_type == InsightType.brand.name: @@ -100,14 +99,12 @@ def get_source(image_name: str, json_path: str = None, barcode: str = None): if not barcode: barcode = get_barcode_from_path(str(json_path)) - return "/{}/{}.jpg" \ - "".format('/'.join(split_barcode(barcode)), - image_name) + return "/{}/{}.jpg" 
"".format("/".join(split_barcode(barcode)), image_name) def ocr_iter(input_str: str) -> Iterable[Tuple[Optional[str], Dict]]: if is_barcode(input_str): - image_data = fetch_images_for_ean(input_str)['product']['images'] + image_data = fetch_images_for_ean(input_str)["product"]["images"] for image_name in image_data.keys(): if image_name.isdigit(): @@ -126,25 +123,24 @@ def ocr_iter(input_str: str) -> Iterable[Tuple[Optional[str], Dict]]: if input_path.is_dir(): for json_path in input_path.glob("**/*.json"): - with open(str(json_path), 'r') as f: - source = get_source(json_path.stem, - json_path=str(json_path)) + with open(str(json_path), "r") as f: + source = get_source(json_path.stem, json_path=str(json_path)) yield source, json.load(f) else: - if '.json' in input_path.suffixes: - with open(str(input_path), 'r') as f: + if ".json" in input_path.suffixes: + with open(str(input_path), "r") as f: yield None, json.load(f) - elif '.jsonl' in input_path.suffixes: - if input_path.suffix == '.gz': + elif ".jsonl" in input_path.suffixes: + if input_path.suffix == ".gz": open_func = gzip.open else: open_func = open - with open_func(input_path, mode='rt') as f: + with open_func(input_path, mode="rt") as f: for line in f: json_data = json.loads(line) - if 'content' in json_data: - source = json_data['source'].replace('//', '/') - yield source, json_data['content'] + if "content" in json_data: + source = json_data["source"].replace("//", "/") + yield source, json_data["content"] diff --git a/robotoff/insights/ocr/dataclass.py b/robotoff/insights/ocr/dataclass.py index 43a30e558b..1b3ba8e7fe 100644 --- a/robotoff/insights/ocr/dataclass.py +++ b/robotoff/insights/ocr/dataclass.py @@ -2,12 +2,11 @@ import operator import re from collections import Counter -from typing import Optional, Callable, Dict, List +from typing import Callable, Dict, List, Optional from robotoff.utils import get_logger from robotoff.utils.types import JSONType - MULTIPLE_SPACES_REGEX = re.compile(r" {2,}") logger = get_logger(__name__) @@ -20,15 +19,17 @@ class OCRField(enum.Enum): class OCRRegex: - __slots__ = ('regex', 'field', 'lowercase', 'processing_func', - 'priority', 'notify') - - def __init__(self, regex, - field: OCRField, - lowercase: bool = False, - processing_func: Optional[Callable] = None, - priority: Optional[int] = None, - notify: bool = False): + __slots__ = ("regex", "field", "lowercase", "processing_func", "priority", "notify") + + def __init__( + self, + regex, + field: OCRField, + lowercase: bool = False, + processing_func: Optional[Callable] = None, + priority: Optional[int] = None, + notify: bool = False, + ): self.regex = regex self.field: OCRField = field self.lowercase: bool = lowercase @@ -46,7 +47,7 @@ class ImageOrientation(enum.Enum): class OrientationResult: - __slots__ = ('count', 'orientation') + __slots__ = ("count", "orientation") def __init__(self, count: Counter): most_common_list = count.most_common(1) @@ -57,22 +58,25 @@ def __init__(self, count: Counter): else: self.orientation = ImageOrientation.unknown - self.count: Dict[str, int] = {key.name: value - for key, value in count.items()} + self.count: Dict[str, int] = {key.name: value for key, value in count.items()} def to_json(self) -> JSONType: return { - 'count': self.count, - 'orientation': self.orientation.name, + "count": self.count, + "orientation": self.orientation.name, } class OCRResult: - __slots__ = ('text_annotations', 'text_annotations_str', - 'text_annotations_str_lower', - 'full_text_annotation', - 'logo_annotations', 
'safe_search_annotation', - 'label_annotations') + __slots__ = ( + "text_annotations", + "text_annotations_str", + "text_annotations_str_lower", + "full_text_annotation", + "logo_annotations", + "safe_search_annotation", + "label_annotations", + ) def __init__(self, data: JSONType): self.text_annotations: List[OCRTextAnnotation] = [] @@ -81,7 +85,7 @@ def __init__(self, data: JSONType): self.label_annotations: List[LabelAnnotation] = [] self.safe_search_annotation: Optional[SafeSearchAnnotation] = None - for text_annotation_data in data.get('textAnnotations', []): + for text_annotation_data in data.get("textAnnotations", []): text_annotation = OCRTextAnnotation(text_annotation_data) self.text_annotations.append(text_annotation) @@ -89,28 +93,26 @@ def __init__(self, data: JSONType): self.text_annotations_str_lower: Optional[str] = None if self.text_annotations: - self.text_annotations_str = '||'.join(t.text - for t in self.text_annotations) - self.text_annotations_str_lower = (self.text_annotations_str - .lower()) + self.text_annotations_str = "||".join(t.text for t in self.text_annotations) + self.text_annotations_str_lower = self.text_annotations_str.lower() - full_text_annotation_data = data.get('fullTextAnnotation') + full_text_annotation_data = data.get("fullTextAnnotation") if full_text_annotation_data: - self.full_text_annotation = OCRFullTextAnnotation( - full_text_annotation_data) + self.full_text_annotation = OCRFullTextAnnotation(full_text_annotation_data) - for logo_annotation_data in data.get('logoAnnotations', []): + for logo_annotation_data in data.get("logoAnnotations", []): logo_annotation = LogoAnnotation(logo_annotation_data) self.logo_annotations.append(logo_annotation) - for label_annotation_data in data.get('labelAnnotations', []): + for label_annotation_data in data.get("labelAnnotations", []): label_annotation = LabelAnnotation(label_annotation_data) self.label_annotations.append(label_annotation) - if 'safeSearchAnnotation' in data: + if "safeSearchAnnotation" in data: self.safe_search_annotation = SafeSearchAnnotation( - data['safeSearchAnnotation']) + data["safeSearchAnnotation"] + ) def get_full_text(self, lowercase: bool = False) -> Optional[str]: if self.full_text_annotation is not None: @@ -164,10 +166,10 @@ def get_text(self, ocr_regex: OCRRegex) -> Optional[str]: else: raise ValueError("invalid field: {}".format(field)) - def get_logo_annotations(self) -> List['LogoAnnotation']: + def get_logo_annotations(self) -> List["LogoAnnotation"]: return self.logo_annotations - def get_label_annotations(self) -> List['LabelAnnotation']: + def get_label_annotations(self) -> List["LabelAnnotation"]: return self.label_annotations def get_safe_search_annotation(self): @@ -180,33 +182,38 @@ def get_orientation(self) -> Optional[OrientationResult]: return None @classmethod - def from_json(cls, data: JSONType) -> Optional['OCRResult']: - responses = data.get('responses', []) + def from_json(cls, data: JSONType) -> Optional["OCRResult"]: + responses = data.get("responses", []) if not responses: return None response = responses[0] - if 'error' in response: + if "error" in response: return None return OCRResult(response) class OCRFullTextAnnotation: - __slots__ = ('text', 'text_lower', - 'pages', 'contiguous_text', 'contiguous_text_lower') + __slots__ = ( + "text", + "text_lower", + "pages", + "contiguous_text", + "contiguous_text_lower", + ) def __init__(self, data: JSONType): - self.text = MULTIPLE_SPACES_REGEX.sub(' ', data['text']) + self.text = MULTIPLE_SPACES_REGEX.sub(" 
", data["text"]) self.text_lower = self.text.lower() - self.contiguous_text = self.text.replace('\n', ' ') - self.contiguous_text = MULTIPLE_SPACES_REGEX.sub(' ', - self.contiguous_text) + self.contiguous_text = self.text.replace("\n", " ") + self.contiguous_text = MULTIPLE_SPACES_REGEX.sub(" ", self.contiguous_text) self.contiguous_text_lower = self.contiguous_text.lower() - self.pages: List[TextAnnotationPage] = [TextAnnotationPage(page) - for page in data['pages']] + self.pages: List[TextAnnotationPage] = [ + TextAnnotationPage(page) for page in data["pages"] + ] def detect_orientation(self) -> OrientationResult: word_orientations: List[ImageOrientation] = [] @@ -220,9 +227,9 @@ def detect_orientation(self) -> OrientationResult: class TextAnnotationPage: def __init__(self, data: JSONType): - self.width = data['width'] - self.height = data['height'] - self.blocks = [Block(d) for d in data['blocks']] + self.width = data["width"] + self.height = data["height"] + self.blocks = [Block(d) for d in data["blocks"]] def detect_words_orientation(self) -> List[ImageOrientation]: word_orientations: List[ImageOrientation] = [] @@ -235,10 +242,9 @@ def detect_words_orientation(self) -> List[ImageOrientation]: class Block: def __init__(self, data: JSONType): - self.type = data['blockType'] - self.paragraphs = [Paragraph(paragraph) - for paragraph in data['paragraphs']] - self.bounding_poly = BoundingPoly(data['boundingBox']) + self.type = data["blockType"] + self.paragraphs = [Paragraph(paragraph) for paragraph in data["paragraphs"]] + self.bounding_poly = BoundingPoly(data["boundingBox"]) def detect_orientation(self) -> ImageOrientation: return self.bounding_poly.detect_orientation() @@ -254,8 +260,8 @@ def detect_words_orientation(self) -> List[ImageOrientation]: class Paragraph: def __init__(self, data: JSONType): - self.words = [Word(word) for word in data['words']] - self.bounding_poly = BoundingPoly(data['boundingBox']) + self.words = [Word(word) for word in data["words"]] + self.bounding_poly = BoundingPoly(data["boundingBox"]) def detect_orientation(self) -> ImageOrientation: return self.bounding_poly.detect_orientation() @@ -265,28 +271,28 @@ def detect_words_orientation(self) -> List[ImageOrientation]: def get_text(self) -> str: """Return the text of the paragraph, by concatenating the words.""" - return ''.join(w.get_text() for w in self.words) + return "".join(w.get_text() for w in self.words) class Word: - __slots__ = ('bounding_poly', 'symbols', 'languages') + __slots__ = ("bounding_poly", "symbols", "languages") def __init__(self, data: JSONType): - self.bounding_poly = BoundingPoly(data['boundingBox']) - self.symbols: List[Symbol] = [Symbol(s) for s in data['symbols']] + self.bounding_poly = BoundingPoly(data["boundingBox"]) + self.symbols: List[Symbol] = [Symbol(s) for s in data["symbols"]] self.languages: Optional[List[DetectedLanguage]] = None - word_property = data.get('property', {}) + word_property = data.get("property", {}) - if 'detectedLanguages' in word_property: + if "detectedLanguages" in word_property: self.languages: List[DetectedLanguage] = [ - DetectedLanguage(l) for l in - data['property']['detectedLanguages']] + DetectedLanguage(l) for l in data["property"]["detectedLanguages"] + ] def get_text(self) -> str: text_list = [] for symbol in self.symbols: - symbol_str = '' + symbol_str = "" if symbol.symbol_break and symbol.symbol_break.is_prefix: symbol_str = symbol.symbol_break.get_value() @@ -298,69 +304,69 @@ def get_text(self) -> str: text_list.append(symbol_str) - 
return ''.join(text_list)
+        return "".join(text_list)
 
     def detect_orientation(self) -> ImageOrientation:
         return self.bounding_poly.detect_orientation()
 
 
 class Symbol:
-    __slots__ = ('bounding_poly', 'text', 'confidence', 'symbol_break')
+    __slots__ = ("bounding_poly", "text", "confidence", "symbol_break")
 
     def __init__(self, data: JSONType):
-        self.bounding_poly = BoundingPoly(data['boundingBox'])
-        self.text = data['text']
-        self.confidence = data.get('confidence', None)
+        self.bounding_poly = BoundingPoly(data["boundingBox"])
+        self.text = data["text"]
+        self.confidence = data.get("confidence", None)
 
         self.symbol_break: Optional[DetectedBreak] = None
-        symbol_property = data.get('property', {})
+        symbol_property = data.get("property", {})
 
-        if 'detectedBreak' in symbol_property:
-            self.symbol_break = DetectedBreak(
-                symbol_property['detectedBreak'])
+        if "detectedBreak" in symbol_property:
+            self.symbol_break = DetectedBreak(symbol_property["detectedBreak"])
 
     def detect_orientation(self) -> ImageOrientation:
         return self.bounding_poly.detect_orientation()
 
 
 class DetectedBreak:
-    __slots__ = ('type', 'is_prefix')
+    __slots__ = ("type", "is_prefix")
 
     def __init__(self, data: JSONType):
-        self.type = data['type']
-        self.is_prefix = data.get('isPrefix', False)
+        self.type = data["type"]
+        self.is_prefix = data.get("isPrefix", False)
 
     def __repr__(self):
         return "<DetectedBreak {}>".format(self.type)
 
     def get_value(self):
-        if self.type in ('UNKNOWN', 'HYPHEN'):
-            return ''
+        if self.type in ("UNKNOWN", "HYPHEN"):
+            return ""
 
-        elif self.type in ('SPACE', 'SURE_SPACE', 'EOL_SURE_SPACE'):
-            return ' '
+        elif self.type in ("SPACE", "SURE_SPACE", "EOL_SURE_SPACE"):
+            return " "
 
-        elif self.type == 'LINE_BREAK':
-            return '\n'
+        elif self.type == "LINE_BREAK":
+            return "\n"
 
         else:
             raise ValueError("unknown type: {}".format(self.type))
 
 
 class DetectedLanguage:
-    __slots__ = ('language', 'confidence')
+    __slots__ = ("language", "confidence")
 
     def __init__(self, data: JSONType):
-        self.language = data['languageCode']
-        self.confidence = data.get('confidence', 0)
+        self.language = data["languageCode"]
+        self.confidence = data.get("confidence", 0)
 
 
 class BoundingPoly:
-    __slots__ = ('vertices', )
+    __slots__ = ("vertices",)
 
     def __init__(self, data: JSONType):
-        self.vertices = [(point.get('x', 0), point.get('y', 0))
-                         for point in data['vertices']]
+        self.vertices = [
+            (point.get("x", 0), point.get("y", 0)) for point in data["vertices"]
+        ]
 
     def detect_orientation(self) -> ImageOrientation:
         """Detect bounding poly orientation (up, down, left, or right). 
@@ -390,12 +396,10 @@ def detect_orientation(self) -> ImageOrientation: - (3, 0) for 90° clockwise rotation (right) It is u """ - indexed_vertices = [(x[0], x[1], i) - for i, x in enumerate(self.vertices)] + indexed_vertices = [(x[0], x[1], i) for i, x in enumerate(self.vertices)] # Sort by ascending y-value and select first two vertices: # get the two topmost vertices - indexed_vertices = sorted(indexed_vertices, - key=operator.itemgetter(1))[:2] + indexed_vertices = sorted(indexed_vertices, key=operator.itemgetter(1))[:2] first_vertex_index = indexed_vertices[0][2] second_vertex_index = indexed_vertices[1][2] @@ -417,52 +421,59 @@ def detect_orientation(self) -> ImageOrientation: return ImageOrientation.right else: - logger.error("Unknown orientation: edge {}, vertices {}" - "".format(first_edge, self.vertices)) + logger.error( + "Unknown orientation: edge {}, vertices {}" + "".format(first_edge, self.vertices) + ) return ImageOrientation.unknown class OCRTextAnnotation: - __slots__ = ('locale', 'text', 'bounding_poly') + __slots__ = ("locale", "text", "bounding_poly") def __init__(self, data: JSONType): - self.locale = data.get('locale') - self.text = data['description'] - self.bounding_poly = BoundingPoly(data['boundingPoly']) + self.locale = data.get("locale") + self.text = data["description"] + self.bounding_poly = BoundingPoly(data["boundingPoly"]) class LogoAnnotation: - __slots__ = ('id', 'description', 'score') + __slots__ = ("id", "description", "score") def __init__(self, data: JSONType): - self.id = data.get('mid') or None - self.score = data['score'] - self.description = data['description'] + self.id = data.get("mid") or None + self.score = data["score"] + self.description = data["description"] class LabelAnnotation: - __slots__ = ('id', 'description', 'score') + __slots__ = ("id", "description", "score") def __init__(self, data: JSONType): - self.id = data.get('mid') or None - self.score = data['score'] - self.description = data['description'] + self.id = data.get("mid") or None + self.score = data["score"] + self.description = data["description"] class SafeSearchAnnotation: - __slots__ = ('adult', 'spoof', 'medical', 'violence', 'racy') + __slots__ = ("adult", "spoof", "medical", "violence", "racy") def __init__(self, data: JSONType): - self.adult: SafeSearchAnnotationLikelihood = \ - SafeSearchAnnotationLikelihood[data['adult']] - self.spoof: SafeSearchAnnotationLikelihood = \ - SafeSearchAnnotationLikelihood[data['spoof']] - self.medical: SafeSearchAnnotationLikelihood = \ - SafeSearchAnnotationLikelihood[data['medical']] - self.violence: SafeSearchAnnotationLikelihood = \ - SafeSearchAnnotationLikelihood[data['violence']] - self.racy: SafeSearchAnnotationLikelihood = \ - SafeSearchAnnotationLikelihood[data['racy']] + self.adult: SafeSearchAnnotationLikelihood = SafeSearchAnnotationLikelihood[ + data["adult"] + ] + self.spoof: SafeSearchAnnotationLikelihood = SafeSearchAnnotationLikelihood[ + data["spoof"] + ] + self.medical: SafeSearchAnnotationLikelihood = SafeSearchAnnotationLikelihood[ + data["medical"] + ] + self.violence: SafeSearchAnnotationLikelihood = SafeSearchAnnotationLikelihood[ + data["violence"] + ] + self.racy: SafeSearchAnnotationLikelihood = SafeSearchAnnotationLikelihood[ + data["racy"] + ] class SafeSearchAnnotationLikelihood(enum.IntEnum): diff --git a/robotoff/insights/ocr/expiration_date.py b/robotoff/insights/ocr/expiration_date.py index fdcdce5bbd..73e69e81c4 100644 --- a/robotoff/insights/ocr/expiration_date.py +++ 
b/robotoff/insights/ocr/expiration_date.py @@ -1,9 +1,9 @@ import datetime import functools import re -from typing import List, Dict, Optional +from typing import Dict, List, Optional -from robotoff.insights.ocr.dataclass import OCRResult, OCRRegex, OCRField +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult def process_full_digits_expiration_date(match, short: bool) -> Optional[datetime.date]: @@ -15,7 +15,9 @@ def process_full_digits_expiration_date(match, short: bool) -> Optional[datetime format_str = "%d/%m/%Y" try: - date = datetime.datetime.strptime("{}/{}/{}".format(day, month, year), format_str).date() + date = datetime.datetime.strptime( + "{}/{}/{}".format(day, month, year), format_str + ).date() except ValueError: return None @@ -23,16 +25,22 @@ def process_full_digits_expiration_date(match, short: bool) -> Optional[datetime EXPIRATION_DATE_REGEX: Dict[str, OCRRegex] = { - 'full_digits_short': OCRRegex(re.compile(r'(? List[Dict]: value = date.strftime("%d/%m/%Y") - results.append({ - "raw": raw, - "text": value, - "type": type_, - "notify": ocr_regex.notify, - }) + results.append( + { + "raw": raw, + "text": value, + "type": type_, + "notify": ocr_regex.notify, + } + ) return results diff --git a/robotoff/insights/ocr/image_flag.py b/robotoff/insights/ocr/image_flag.py index 8ac3efea7b..dafc8f0dd1 100644 --- a/robotoff/insights/ocr/image_flag.py +++ b/robotoff/insights/ocr/image_flag.py @@ -1,4 +1,4 @@ -from typing import List, Dict +from typing import Dict, List from robotoff.insights.ocr.dataclass import OCRResult, SafeSearchAnnotationLikelihood @@ -9,22 +9,27 @@ def flag_image(ocr_result: OCRResult) -> List[Dict]: insights: List[Dict] = [] if safe_search_annotation: - for key in ('adult', 'violence'): - value: SafeSearchAnnotationLikelihood = \ - getattr(safe_search_annotation, key) + for key in ("adult", "violence"): + value: SafeSearchAnnotationLikelihood = getattr(safe_search_annotation, key) if value >= SafeSearchAnnotationLikelihood.VERY_LIKELY: - insights.append({ - 'type': key, - 'likelihood': value.name, - }) + insights.append( + { + "type": key, + "likelihood": value.name, + } + ) for label_annotation in label_annotations: - if (label_annotation.description in ('Face', 'Head', 'Selfie') and - label_annotation.score >= 0.8): - insights.append({ - 'type': label_annotation.description.lower(), - 'likelihood': label_annotation.score - }) + if ( + label_annotation.description in ("Face", "Head", "Selfie") + and label_annotation.score >= 0.8 + ): + insights.append( + { + "type": label_annotation.description.lower(), + "likelihood": label_annotation.score, + } + ) break return insights diff --git a/robotoff/insights/ocr/image_orientation.py b/robotoff/insights/ocr/image_orientation.py index 391bb7ad35..ff6253fbef 100644 --- a/robotoff/insights/ocr/image_orientation.py +++ b/robotoff/insights/ocr/image_orientation.py @@ -1,13 +1,15 @@ -from typing import List, Dict +from typing import Dict, List -from robotoff.insights.ocr.dataclass import OCRResult, ImageOrientation +from robotoff.insights.ocr.dataclass import ImageOrientation, OCRResult def find_image_orientation(ocr_result: OCRResult) -> List[Dict]: orientation_result = ocr_result.get_orientation() - if (orientation_result is None - or orientation_result.orientation == ImageOrientation.up): + if ( + orientation_result is None + or orientation_result.orientation == ImageOrientation.up + ): return [] return [orientation_result.to_json()] diff --git a/robotoff/insights/ocr/label.py 
b/robotoff/insights/ocr/label.py index bf5d9382c7..e5f8317f79 100644 --- a/robotoff/insights/ocr/label.py +++ b/robotoff/insights/ocr/label.py @@ -2,258 +2,326 @@ from typing import Dict, List from robotoff import settings -from robotoff.insights.ocr.dataclass import OCRRegex, OCRField, OCRResult -from robotoff.utils import text_file_iter, get_logger +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult +from robotoff.utils import get_logger, text_file_iter logger = get_logger(__name__) def process_eu_bio_label_code(match) -> str: - return ("en:{}-{}-{}".format(match.group(1), - match.group(2), - match.group(3)) - .lower() - .replace('ö', 'o') - .replace('ø', 'o')) + return ( + "en:{}-{}-{}".format(match.group(1), match.group(2), match.group(3)) + .lower() + .replace("ö", "o") + .replace("ø", "o") + ) EN_ORGANIC_REGEX_STR = [ - r'ingr[ée]dients?\sbiologiques?', - r'ingr[ée]dients?\sbio[\s.,)]', - r'agriculture ue/non ue biologique', - r'agriculture bio(?:logique)?[\s.,)]', - r'production bio(?:logique)?[\s.,)]', + r"ingr[ée]dients?\sbiologiques?", + r"ingr[ée]dients?\sbio[\s.,)]", + r"agriculture ue/non ue biologique", + r"agriculture bio(?:logique)?[\s.,)]", + r"production bio(?:logique)?[\s.,)]", ] LABELS_REGEX = { - 'en:organic': [ - OCRRegex(re.compile(r"|".join([r"(?:{})".format(x) - for x in EN_ORGANIC_REGEX_STR])), - field=OCRField.full_text_contiguous, - lowercase=True), + "en:organic": [ + OCRRegex( + re.compile(r"|".join([r"(?:{})".format(x) for x in EN_ORGANIC_REGEX_STR])), + field=OCRField.full_text_contiguous, + lowercase=True, + ), ], - 'xx-bio-xx': [ + "xx-bio-xx": [ # The negative lookbehind (? Dict[str, str]: labels: Dict[str, str] = {} for item in text_file_iter(settings.OCR_LOGO_ANNOTATION_LABELS_DATA_PATH): - if '||' in item: - logo_description, label_tag = item.split('||') + if "||" in item: + logo_description, label_tag = item.split("||") else: logger.warn("'||' separator expected!") continue @@ -292,21 +360,25 @@ def find_labels(ocr_result: OCRResult) -> List[Dict]: else: label_value = label_tag - results.append({ - 'label_tag': label_value, - 'text': match.group(), - 'notify': ocr_regex.notify, - }) + results.append( + { + "label_tag": label_value, + "text": match.group(), + "notify": ocr_regex.notify, + } + ) for logo_annotation in ocr_result.logo_annotations: if logo_annotation.description in LOGO_ANNOTATION_LABELS: label_tag = LOGO_ANNOTATION_LABELS[logo_annotation.description] - results.append({ - 'label_tag': label_tag, - 'automatic_processing': False, - 'confidence': logo_annotation.score, - 'model': 'google-cloud-vision', - }) + results.append( + { + "label_tag": label_tag, + "automatic_processing": False, + "confidence": logo_annotation.score, + "model": "google-cloud-vision", + } + ) return results diff --git a/robotoff/insights/ocr/nutrient.py b/robotoff/insights/ocr/nutrient.py index 34d45ac075..56cae5e6c9 100644 --- a/robotoff/insights/ocr/nutrient.py +++ b/robotoff/insights/ocr/nutrient.py @@ -1,33 +1,40 @@ import re -from typing import List, Dict +from typing import Dict, List -from robotoff.insights.ocr.dataclass import OCRResult, OCRRegex, OCRField +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult def generate_nutrient_regex(nutrient_names: List[str], units: List[str]): - nutrient_names_str = '|'.join(nutrient_names) - units_str = '|'.join(units) - return re.compile(r"({}) ?(?:[:-] ?)?([0-9]+[,.]?[0-9]*) ?({})".format(nutrient_names_str, - units_str)) + nutrient_names_str = "|".join(nutrient_names) + 
units_str = "|".join(units) + return re.compile( + r"({}) ?(?:[:-] ?)?([0-9]+[,.]?[0-9]*) ?({})".format( + nutrient_names_str, units_str + ) + ) NUTRIENT_VALUES_REGEX = { - 'energy': OCRRegex( + "energy": OCRRegex( generate_nutrient_regex(["[ée]nergie", "energy"], ["kj", "kcal"]), field=OCRField.full_text_contiguous, - lowercase=True), - 'fat': OCRRegex( + lowercase=True, + ), + "fat": OCRRegex( generate_nutrient_regex(["mati[èe]res? grasses?"], ["g"]), field=OCRField.full_text_contiguous, - lowercase=True), - 'glucid': OCRRegex( + lowercase=True, + ), + "glucid": OCRRegex( generate_nutrient_regex(["glucides?", "glucids?"], ["g"]), field=OCRField.full_text_contiguous, - lowercase=True), - 'carbohydrate': OCRRegex( + lowercase=True, + ), + "carbohydrate": OCRRegex( generate_nutrient_regex(["sucres?", "carbohydrates?"], ["g"]), field=OCRField.full_text_contiguous, - lowercase=True), + lowercase=True, + ), } @@ -41,14 +48,16 @@ def find_nutrient_values(ocr_result: OCRResult) -> List[Dict]: continue for match in ocr_regex.regex.finditer(text): - value = match.group(2).replace(',', '.') + value = match.group(2).replace(",", ".") unit = match.group(3) - results.append({ - "raw": match.group(0), - "nutrient": regex_code, - 'value': value, - 'unit': unit, - 'notify': ocr_regex.notify, - }) + results.append( + { + "raw": match.group(0), + "nutrient": regex_code, + "value": value, + "unit": unit, + "notify": ocr_regex.notify, + } + ) return results diff --git a/robotoff/insights/ocr/packager_code.py b/robotoff/insights/ocr/packager_code.py index 05925ab524..1043cc77df 100644 --- a/robotoff/insights/ocr/packager_code.py +++ b/robotoff/insights/ocr/packager_code.py @@ -1,7 +1,7 @@ import re from typing import Dict, List -from robotoff.insights.ocr.dataclass import OCRRegex, OCRField, OCRResult +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult def process_fr_packaging_match(match) -> str: @@ -12,31 +12,39 @@ def process_fr_packaging_match(match) -> str: def process_de_packaging_match(match) -> str: federal_state_tag, company_tag = match.group(1, 2) - return "DE {}-{} EC".format(federal_state_tag, - company_tag).upper() + return "DE {}-{} EC".format(federal_state_tag, company_tag).upper() def process_fr_emb_match(match) -> str: city_code, company_code = match.group(1, 2) - city_code = city_code.replace(' ', '') - company_code = company_code or '' - return "EMB {}{}".format(city_code, - company_code).upper() + city_code = city_code.replace(" ", "") + company_code = company_code or "" + return "EMB {}{}".format(city_code, company_code).upper() PACKAGER_CODE: Dict[str, OCRRegex] = { - "fr_emb": OCRRegex(re.compile(r"emb ?(\d ?\d ?\d ?\d ?\d) ?([a-z])?(?![a-z0-9])"), - field=OCRField.text_annotations, - lowercase=True, - processing_func=process_fr_emb_match), - "eu_fr": OCRRegex(re.compile(r"fr (\d{2,3}|2[ab])[\-\s.](\d{3})[\-\s.](\d{3}) (ce|ec)(?![a-z0-9])"), - field=OCRField.full_text_contiguous, - lowercase=True, - processing_func=process_fr_packaging_match), - "eu_de": OCRRegex(re.compile(r"de (bb|be|bw|by|hb|he|hh|mv|ni|nw|rp|sh|sl|sn|st|th)[\-\s.](\d{1,5})[\-\s.] 
?(eg|ec)(?![a-z0-9])"), - field=OCRField.full_text_contiguous, - lowercase=True, - processing_func=process_de_packaging_match), + "fr_emb": OCRRegex( + re.compile(r"emb ?(\d ?\d ?\d ?\d ?\d) ?([a-z])?(?![a-z0-9])"), + field=OCRField.text_annotations, + lowercase=True, + processing_func=process_fr_emb_match, + ), + "eu_fr": OCRRegex( + re.compile( + r"fr (\d{2,3}|2[ab])[\-\s.](\d{3})[\-\s.](\d{3}) (ce|ec)(?![a-z0-9])" + ), + field=OCRField.full_text_contiguous, + lowercase=True, + processing_func=process_fr_packaging_match, + ), + "eu_de": OCRRegex( + re.compile( + r"de (bb|be|bw|by|hb|he|hh|mv|ni|nw|rp|sh|sl|sn|st|th)[\-\s.](\d{1,5})[\-\s.] ?(eg|ec)(?![a-z0-9])" + ), + field=OCRField.full_text_contiguous, + lowercase=True, + processing_func=process_de_packaging_match, + ), } @@ -52,11 +60,13 @@ def find_packager_codes(ocr_result: OCRResult) -> List[Dict]: for match in ocr_regex.regex.finditer(text): if ocr_regex.processing_func is not None: value = ocr_regex.processing_func(match) - results.append({ - "raw": match.group(0), - "text": value, - "type": regex_code, - "notify": ocr_regex.notify, - }) + results.append( + { + "raw": match.group(0), + "text": value, + "type": regex_code, + "notify": ocr_regex.notify, + } + ) return results diff --git a/robotoff/insights/ocr/product_weight.py b/robotoff/insights/ocr/product_weight.py index 49193ec32b..9ede699f49 100644 --- a/robotoff/insights/ocr/product_weight.py +++ b/robotoff/insights/ocr/product_weight.py @@ -2,7 +2,7 @@ import re from typing import Dict, List -from robotoff.insights.ocr.dataclass import OCRRegex, OCRField, OCRResult +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult def process_product_weight(match, prompt: bool) -> Dict: @@ -17,21 +17,21 @@ def process_product_weight(match, prompt: bool) -> Dict: value = match.group(1) unit = match.group(2) - if unit in ('dle', 'cle', 'mge', 'mle', 'ge', 'kge', 'le'): + if unit in ("dle", "cle", "mge", "mle", "ge", "kge", "le"): # When the e letter often comes after the weight unit, the # space is often not detected unit = unit[:-1] text = "{} {}".format(value, unit) result = { - 'text': text, - 'raw': raw, - 'value': value, - 'unit': unit, + "text": text, + "raw": raw, + "value": value, + "unit": unit, } if prompt_str is not None: - result['prompt'] = prompt_str + result["prompt"] = prompt_str return result @@ -43,42 +43,43 @@ def process_multi_packaging(match) -> Dict: value = match.group(2) unit = match.group(3) - if unit in ('dle', 'cle', 'mge', 'mle', 'ge', 'kge', 'le'): + if unit in ("dle", "cle", "mge", "mle", "ge", "kge", "le"): # When the e letter often comes after the weight unit, the # space is often not detected unit = unit[:-1] text = "{} x {} {}".format(count, value, unit) - result = { - 'text': text, - 'raw': raw, - 'value': value, - 'unit': unit, - 'count': count - } + result = {"text": text, "raw": raw, "value": value, "unit": unit, "count": count} return result PRODUCT_WEIGHT_REGEX: Dict[str, OCRRegex] = { - 'with_mention': OCRRegex( - re.compile(r"(poids|poids net [aà] l'emballage|poids net|poids net égoutté|masse nette|volume net total|net weight|net wt\.?|peso neto|peso liquido|netto[ -]?gewicht)\s?:?\s?([0-9]+[,.]?[0-9]*)\s?(fl oz|dle?|cle?|mge?|mle?|lbs|oz|ge?|kge?|le?)(?![a-z])"), + "with_mention": OCRRegex( + re.compile( + r"(poids|poids net [aà] l'emballage|poids net|poids net égoutté|masse nette|volume net total|net weight|net wt\.?|peso neto|peso liquido|netto[ -]?gewicht)\s?:?\s?([0-9]+[,.]?[0-9]*)\s?(fl 
oz|dle?|cle?|mge?|mle?|lbs|oz|ge?|kge?|le?)(?![a-z])" + ), field=OCRField.full_text_contiguous, lowercase=True, processing_func=functools.partial(process_product_weight, prompt=True), - priority=1), - 'multi_packaging': OCRRegex( - re.compile(r"(\d+)\s?x\s?([0-9]+[,.]?[0-9]*)\s?(fl oz|dle?|cle?|mge?|mle?|lbs|oz|ge?|kge?|le?)(?![a-z])"), + priority=1, + ), + "multi_packaging": OCRRegex( + re.compile( + r"(\d+)\s?x\s?([0-9]+[,.]?[0-9]*)\s?(fl oz|dle?|cle?|mge?|mle?|lbs|oz|ge?|kge?|le?)(?![a-z])" + ), field=OCRField.full_text_contiguous, lowercase=True, processing_func=process_multi_packaging, - priority=2), - 'no_mention': OCRRegex( + priority=2, + ), + "no_mention": OCRRegex( re.compile(r"([0-9]+[,.]?[0-9]*)\s?(dle|cle|mge|mle|ge|kge)(?![a-z])"), field=OCRField.full_text_contiguous, lowercase=True, processing_func=functools.partial(process_product_weight, prompt=False), - priority=3), + priority=3, + ), } @@ -96,9 +97,9 @@ def find_product_weight(ocr_result: OCRResult) -> List[Dict]: continue result = ocr_regex.processing_func(match) - result['matcher_type'] = type_ - result['priority'] = ocr_regex.priority - result['notify'] = ocr_regex.notify + result["matcher_type"] = type_ + result["priority"] = ocr_regex.priority + result["notify"] = ocr_regex.notify results.append(result) return results diff --git a/robotoff/insights/ocr/store.py b/robotoff/insights/ocr/store.py index 8ee176d869..5dd217a659 100644 --- a/robotoff/insights/ocr/store.py +++ b/robotoff/insights/ocr/store.py @@ -1,16 +1,13 @@ import re -from typing import List, Dict, Tuple, Set +from typing import Dict, List, Set, Tuple from robotoff import settings -from robotoff.insights.ocr.dataclass import OCRResult, OCRRegex, OCRField +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult from robotoff.utils import text_file_iter def get_store_tag(store: str) -> str: - return (store.lower() - .replace(' & ', '-') - .replace(' ', '-') - .replace("'", '-')) + return store.lower().replace(" & ", "-").replace(" ", "-").replace("'", "-") def store_sort_key(item): @@ -27,8 +24,8 @@ def get_sorted_stores() -> List[Tuple[str, str]]: sorted_stores: Dict[str, str] = {} for item in text_file_iter(settings.OCR_STORES_DATA_PATH): - if '||' in item: - store, regex_str = item.split('||') + if "||" in item: + store, regex_str = item.split("||") else: store = item regex_str = re.escape(item.lower()) @@ -39,13 +36,15 @@ def get_sorted_stores() -> List[Tuple[str, str]]: SORTED_STORES = get_sorted_stores() -STORE_REGEX_STR = "|".join(r"((? 
List[Dict]: @@ -62,12 +61,14 @@ def find_stores(ocr_result: OCRResult) -> List[Dict]: for idx, match_str in enumerate(groups): if match_str is not None: store, _ = SORTED_STORES[idx] - results.append({ - 'store': store, - 'store_tag': get_store_tag(store), - 'text': match_str, - 'notify': store not in NOTIFY_STORES_WHITELIST, - }) + results.append( + { + "store": store, + "store_tag": get_store_tag(store), + "text": match_str, + "notify": store not in NOTIFY_STORES_WHITELIST, + } + ) break return results diff --git a/robotoff/insights/ocr/trace.py b/robotoff/insights/ocr/trace.py index f97e4406c5..7f2da5daee 100644 --- a/robotoff/insights/ocr/trace.py +++ b/robotoff/insights/ocr/trace.py @@ -1,12 +1,15 @@ import re -from typing import List, Dict +from typing import Dict, List -from robotoff.insights.ocr.dataclass import OCRRegex, OCRField, OCRResult +from robotoff.insights.ocr.dataclass import OCRField, OCRRegex, OCRResult TRACES_REGEX = OCRRegex( - re.compile(r"(?:possibilit[ée] de traces|peut contenir(?: des traces)?|traces? [ée]ventuelles? de)"), + re.compile( + r"(?:possibilit[ée] de traces|peut contenir(?: des traces)?|traces? [ée]ventuelles? de)" + ), field=OCRField.full_text_contiguous, - lowercase=True) + lowercase=True, +) def find_traces(ocr_result: OCRResult) -> List[Dict]: @@ -20,12 +23,12 @@ def find_traces(ocr_result: OCRResult) -> List[Dict]: for match in TRACES_REGEX.regex.finditer(text): raw = match.group() end_idx = match.end() - captured = text[end_idx:end_idx+100] + captured = text[end_idx : end_idx + 100] result = { - 'raw': raw, - 'text': captured, - 'notify': TRACES_REGEX.notify, + "raw": raw, + "text": captured, + "notify": TRACES_REGEX.notify, } results.append(result) diff --git a/robotoff/insights/question.py b/robotoff/insights/question.py index 36883380e2..d8618e984e 100644 --- a/robotoff/insights/question.py +++ b/robotoff/insights/question.py @@ -6,7 +6,7 @@ from robotoff.insights._enum import InsightType from robotoff.models import ProductInsight from robotoff.off import get_product -from robotoff.taxonomy import TaxonomyType, TAXONOMY_STORES, Taxonomy +from robotoff.taxonomy import TAXONOMY_STORES, Taxonomy, TaxonomyType from robotoff.utils import get_logger from robotoff.utils.i18n import TranslationStore from robotoff.utils.types import JSONType @@ -18,8 +18,10 @@ LABEL_IMAGES = { "en:eu-organic": LABEL_IMG_BASE_URL + "en/labels/eu-organic.135x90.svg", - "fr:ab-agriculture-biologique": LABEL_IMG_BASE_URL + "/fr/labels/ab-agriculture-biologique.74x90.svg", - "en:european-vegetarian-union": LABEL_IMG_BASE_URL + "/en/labels/european-vegetarian-union.90x90.svg", + "fr:ab-agriculture-biologique": LABEL_IMG_BASE_URL + + "/fr/labels/ab-agriculture-biologique.74x90.svg", + "en:european-vegetarian-union": LABEL_IMG_BASE_URL + + "/en/labels/european-vegetarian-union.90x90.svg", "en:pgi": LABEL_IMG_BASE_URL + "/en/labels/pgi.90x90.png", } @@ -35,11 +37,14 @@ def get_type(self): class AddBinaryQuestion(Question): - def __init__(self, question: str, - value: str, - insight: ProductInsight, - image_url: Optional[str] = None, - source_image_url: Optional[str] = None): + def __init__( + self, + question: str, + value: str, + insight: ProductInsight, + image_url: Optional[str] = None, + source_image_url: Optional[str] = None, + ): self.question: str = question self.value: str = value self.insight_id: str = str(insight.id) @@ -49,23 +54,23 @@ def __init__(self, question: str, self.source_image_url: Optional[str] = source_image_url def get_type(self): - return 'add-binary' + 
return "add-binary" def serialize(self) -> JSONType: serial = { - 'barcode': self.barcode, - 'type': self.get_type(), - 'value': self.value, - 'question': self.question, - 'insight_id': self.insight_id, - 'insight_type': self.insight_type, + "barcode": self.barcode, + "type": self.get_type(), + "value": self.value, + "question": self.question, + "insight_id": self.insight_id, + "insight_type": self.insight_type, } if self.image_url: - serial['image_url'] = self.image_url + serial["image_url"] = self.image_url if self.source_image_url: - serial['source_image_url'] = self.source_image_url + serial["source_image_url"] = self.source_image_url return serial @@ -88,31 +93,32 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: localized_value: str = taxonomy.get_localized_name(value, lang) localized_question = self.translation_store.gettext(lang, self.question) source_image_url = self.get_source_image_url(insight.barcode) - return AddBinaryQuestion(question=localized_question, - value=localized_value, - insight=insight, - source_image_url=source_image_url) + return AddBinaryQuestion( + question=localized_question, + value=localized_value, + insight=insight, + source_image_url=source_image_url, + ) @staticmethod def get_source_image_url(barcode: str) -> Optional[str]: - product: Optional[JSONType] = get_product(barcode, - fields=['selected_images']) + product: Optional[JSONType] = get_product(barcode, fields=["selected_images"]) if product is None: return None - if 'selected_images' not in product: + if "selected_images" not in product: return None - selected_images = product['selected_images'] + selected_images = product["selected_images"] - if 'front' not in selected_images: + if "front" not in selected_images: return None - front_images = selected_images['front'] + front_images = selected_images["front"] - if 'display' in front_images: - display_images = list(front_images['display'].values()) + if "display" in front_images: + display_images = list(front_images["display"].values()) if display_images: return display_images[0] @@ -124,22 +130,25 @@ class ProductWeightQuestionFormatter(QuestionFormatter): question = "Does this weight match the weight displayed on the product?" def format_question(self, insight: ProductInsight, lang: str) -> Question: - value: str = insight.data['text'] + value: str = insight.data["text"] localized_question = self.translation_store.gettext(lang, self.question) - source_image_url = (settings.OFF_IMAGE_BASE_URL + - get_display_image(insight.source_image)) + source_image_url = settings.OFF_IMAGE_BASE_URL + get_display_image( + insight.source_image + ) - return AddBinaryQuestion(question=localized_question, - value=value, - insight=insight, - source_image_url=source_image_url) + return AddBinaryQuestion( + question=localized_question, + value=value, + insight=insight, + source_image_url=source_image_url, + ) class LabelQuestionFormatter(QuestionFormatter): question = "Does the product have this label?" 
def format_question(self, insight: ProductInsight, lang: str) -> Question: - value: str = insight.data['label_tag'] + value: str = insight.data["label_tag"] image_url = None @@ -149,29 +158,35 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: taxonomy: Taxonomy = TAXONOMY_STORES[TaxonomyType.label.name].get() localized_value: str = taxonomy.get_localized_name(value, lang) localized_question = self.translation_store.gettext(lang, self.question) - source_image_url = (settings.OFF_IMAGE_BASE_URL + - get_display_image(insight.source_image)) + source_image_url = settings.OFF_IMAGE_BASE_URL + get_display_image( + insight.source_image + ) - return AddBinaryQuestion(question=localized_question, - value=localized_value, - insight=insight, - image_url=image_url, - source_image_url=source_image_url) + return AddBinaryQuestion( + question=localized_question, + value=localized_value, + insight=insight, + image_url=image_url, + source_image_url=source_image_url, + ) class BrandQuestionFormatter(QuestionFormatter): question = "Does the product belong to this brand?" def format_question(self, insight: ProductInsight, lang: str) -> Question: - value: str = insight.data['brand'] + value: str = insight.data["brand"] localized_question = self.translation_store.gettext(lang, self.question) - source_image_url = (settings.OFF_IMAGE_BASE_URL + - get_display_image(insight.source_image)) + source_image_url = settings.OFF_IMAGE_BASE_URL + get_display_image( + insight.source_image + ) - return AddBinaryQuestion(question=localized_question, - value=value, - insight=insight, - source_image_url=source_image_url) + return AddBinaryQuestion( + question=localized_question, + value=value, + insight=insight, + source_image_url=source_image_url, + ) def get_display_image(source_image: str) -> str: @@ -180,7 +195,7 @@ def get_display_image(source_image: str) -> str: if not image_path.stem.isdigit(): return source_image - display_name = "{}.400.jpg".format(image_path.name.split('.')[0]) + display_name = "{}.400.jpg".format(image_path.name.split(".")[0]) return str(image_path.parent / display_name) diff --git a/robotoff/insights/validator.py b/robotoff/insights/validator.py index 5bdf0e2034..b73769f4f1 100644 --- a/robotoff/insights/validator.py +++ b/robotoff/insights/validator.py @@ -4,7 +4,7 @@ from robotoff.insights._enum import InsightType from robotoff.models import ProductInsight from robotoff.products import ProductStore -from robotoff.taxonomy import Taxonomy, TAXONOMY_STORES +from robotoff.taxonomy import TAXONOMY_STORES, Taxonomy from robotoff.utils.types import JSONType @@ -29,7 +29,7 @@ def get_type() -> str: def is_valid(self, insight: ProductInsight) -> bool: product = self.product_store[insight.barcode] - product_labels_tags = getattr(product, 'labels_tags', []) + product_labels_tags = getattr(product, "labels_tags", []) label_tag = insight.value_tag if label_tag in product_labels_tags: @@ -37,12 +37,11 @@ def is_valid(self, insight: ProductInsight) -> bool: # Check that the predicted label is not a parent of a # current/already predicted label - label_taxonomy: Taxonomy = TAXONOMY_STORES[ - InsightType.label.name].get() + label_taxonomy: Taxonomy = TAXONOMY_STORES[InsightType.label.name].get() - if (label_tag in label_taxonomy and - label_taxonomy.is_parent_of_any(label_tag, - product_labels_tags)): + if label_tag in label_taxonomy and label_taxonomy.is_parent_of_any( + label_tag, product_labels_tags + ): return False return True @@ -55,7 +54,7 @@ def get_type() -> str: def 
is_valid(self, insight: ProductInsight) -> bool: product = self.product_store[insight.barcode] - product_categories_tags = getattr(product, 'categories_tags', []) + product_categories_tags = getattr(product, "categories_tags", []) category_tag = insight.value_tag if category_tag in product_categories_tags: @@ -63,12 +62,11 @@ def is_valid(self, insight: ProductInsight) -> bool: # Check that the predicted category is not a parent of a # current/already predicted category - category_taxonomy: Taxonomy = TAXONOMY_STORES[ - InsightType.category.name].get() + category_taxonomy: Taxonomy = TAXONOMY_STORES[InsightType.category.name].get() - if (category_tag in category_taxonomy and - category_taxonomy.is_parent_of_any(category_tag, - product_categories_tags)): + if category_tag in category_taxonomy and category_taxonomy.is_parent_of_any( + category_tag, product_categories_tags + ): return False return True @@ -81,9 +79,9 @@ class InsightValidatorFactory: } @classmethod - def create(cls, insight_type: str, - product_store: Optional[ProductStore]) -> \ - Optional[InsightValidator]: + def create( + cls, insight_type: str, product_store: Optional[ProductStore] + ) -> Optional[InsightValidator]: if insight_type in cls.validators: return cls.validators[insight_type](product_store) else: diff --git a/robotoff/ml/category_classifier.py b/robotoff/ml/category_classifier.py index 710cc09636..9d28a4e408 100644 --- a/robotoff/ml/category_classifier.py +++ b/robotoff/ml/category_classifier.py @@ -3,30 +3,37 @@ import os import pathlib import re -from typing import List, Optional, Dict, Set +from typing import Dict, List, Optional, Set import networkx import numpy as np import pandas as pd from sklearn.compose import ColumnTransformer -from sklearn.feature_extraction.text import (TfidfTransformer, CountVectorizer, - strip_accents_ascii) - +from sklearn.externals import joblib +from sklearn.feature_extraction.text import ( + CountVectorizer, + TfidfTransformer, + strip_accents_ascii, +) from sklearn.linear_model import LogisticRegression - from sklearn.model_selection import train_test_split -from sklearn.externals import joblib from sklearn.pipeline import Pipeline - from sklearn_hierarchical_classification.classifier import HierarchicalClassifier from sklearn_hierarchical_classification.constants import ROOT -from sklearn_hierarchical_classification.metrics import h_precision_score, \ - h_recall_score, h_fbeta_score +from sklearn_hierarchical_classification.metrics import ( + h_fbeta_score, + h_precision_score, + h_recall_score, +) from robotoff import settings from robotoff.products import ProductDataset -from robotoff.taxonomy import Taxonomy, TAXONOMY_STORES, TaxonomyType, \ - generate_category_hierarchy +from robotoff.taxonomy import ( + TAXONOMY_STORES, + Taxonomy, + TaxonomyType, + generate_category_hierarchy, +) from robotoff.utils import get_logger from robotoff.utils.types import JSONType @@ -39,75 +46,79 @@ class CategoryClassifier: - TRANSFORMER_PATH = 'transformer.joblib' - CLASSIFIER_PATH = 'clf.joblib' - CATEGORY_TAXONOMY_PATH = 'category_taxonomy.json' + TRANSFORMER_PATH = "transformer.joblib" + CLASSIFIER_PATH = "clf.joblib" + CATEGORY_TAXONOMY_PATH = "category_taxonomy.json" def __init__(self, category_taxonomy: Taxonomy): self.category_taxonomy: Taxonomy = category_taxonomy self.categories_set: Set[str] = set(category_taxonomy.keys()) self.categories: List[str] = sorted(self.categories_set) - self.categories_to_index: Dict[str, int] = {cat: i for (i, cat) in - enumerate(self.categories)} + 
self.categories_to_index: Dict[str, int] = { + cat: i for (i, cat) in enumerate(self.categories) + } self.transformer: Optional[ColumnTransformer] = None self.classifier: Optional[HierarchicalClassifier] = None - def generate_training_df(self, - dataset: ProductDataset) -> pd.DataFrame: - training_dataset_iter = (dataset.stream() - .filter_by_country_tag('en:france') - .filter_nonempty_text_field('product_name') - .filter_nonempty_tag_field('categories_tags')) + def generate_training_df(self, dataset: ProductDataset) -> pd.DataFrame: + training_dataset_iter = ( + dataset.stream() + .filter_by_country_tag("en:france") + .filter_nonempty_text_field("product_name") + .filter_nonempty_tag_field("categories_tags") + ) training_dataset = [] processed = 0 for product in training_dataset_iter: processed += 1 - transformed_product = self.transform_product(product, - add_category=True) + transformed_product = self.transform_product(product, add_category=True) - if 'deepest_category' in transformed_product: + if "deepest_category" in transformed_product: training_dataset.append(transformed_product) - logger.info("{} training samples discarded (category not in " - "taxonomy), {} remaining" - "".format(processed - len(training_dataset), - len(training_dataset))) + logger.info( + "{} training samples discarded (category not in " + "taxonomy), {} remaining" + "".format(processed - len(training_dataset), len(training_dataset)) + ) return pd.DataFrame(training_dataset) - def generate_prediction_df(self, - dataset: ProductDataset) -> pd.DataFrame: - dataset_iter = (dataset.stream() - .filter_by_country_tag('en:france') - .filter_nonempty_text_field('product_name')) + def generate_prediction_df(self, dataset: ProductDataset) -> pd.DataFrame: + dataset_iter = ( + dataset.stream() + .filter_by_country_tag("en:france") + .filter_nonempty_text_field("product_name") + ) return pd.DataFrame((self.transform_product(p) for p in dataset_iter)) - def transform_product(self, product: Dict, - add_category: bool = False) -> Dict: + def transform_product(self, product: Dict, add_category: bool = False) -> Dict: item = { - 'barcode': product['code'], - 'ingredients_tags': product.get('ingredients_tags', []), - 'product_name': product.get('product_name', ''), + "barcode": product["code"], + "ingredients_tags": product.get("ingredients_tags", []), + "product_name": product.get("product_name", ""), } if add_category: - categories_tags: List[str] = product['categories_tags'] + categories_tags: List[str] = product["categories_tags"] - deepest_category: Optional[str] = ( - self.category_taxonomy.find_deepest_item(categories_tags)) + deepest_category: Optional[str] = self.category_taxonomy.find_deepest_item( + categories_tags + ) if deepest_category is not None: - item['deepest_category'] = deepest_category - item['deepest_category_int'] = self.categories_to_index[ - deepest_category] + item["deepest_category"] = deepest_category + item["deepest_category_int"] = self.categories_to_index[ + deepest_category + ] return item def train(self, dataset: ProductDataset): - category_hierarchy = generate_category_hierarchy(self.category_taxonomy, - self.categories_to_index, - ROOT) + category_hierarchy = generate_category_hierarchy( + self.category_taxonomy, self.categories_to_index, ROOT + ) logger.info("Number of categories: {}".format(len(self.categories))) @@ -134,25 +145,28 @@ def generate_insights(self, dataset: ProductDataset) -> List[JSONType]: for i, row in enumerate(df.itertuples()): category = self.categories[y_pred] - 
insights.append({ - 'barcode': row.barcode, - 'category': category, - 'model': 'hierarchical_classifier', - }) + insights.append( + { + "barcode": row.barcode, + "category": category, + "model": "hierarchical_classifier", + } + ) return insights def raise_if_not_loaded(self): if self.classifier is None or self.transformer is None: - raise RuntimeError("The model must be loaded or trained " - "before prediction") + raise RuntimeError( + "The model must be loaded or trained " "before prediction" + ) def predict(self, product: Dict): self.raise_if_not_loaded() transformed = { - 'product_name': product.get('product_name', ''), - 'ingredients_tags': product.get('ingredients_tags', []), + "product_name": product.get("product_name", ""), + "ingredients_tags": product.get("ingredients_tags", []), } df = pd.DataFrame([transformed]) y_pred = self.classifier.predict(self.transformer.transform(df))[0] @@ -160,22 +174,19 @@ def predict(self, product: Dict): def save(self, output_dir: str) -> None: output_dir_path = pathlib.Path(output_dir) - joblib.dump(self.transformer, - str(output_dir_path / self.TRANSFORMER_PATH)) - joblib.dump(self.classifier, - str(output_dir_path / self.CLASSIFIER_PATH)) + joblib.dump(self.transformer, str(output_dir_path / self.TRANSFORMER_PATH)) + joblib.dump(self.classifier, str(output_dir_path / self.CLASSIFIER_PATH)) - with open(str(output_dir_path / self.CATEGORY_TAXONOMY_PATH), 'w') as f: + with open(str(output_dir_path / self.CATEGORY_TAXONOMY_PATH), "w") as f: json.dump(self.category_taxonomy.to_dict(), f) @classmethod - def load(cls, model_dir: str) -> 'CategoryClassifier': + def load(cls, model_dir: str) -> "CategoryClassifier": model_dir_path = pathlib.Path(model_dir) transformer = joblib.load(str(model_dir_path / cls.TRANSFORMER_PATH)) classifier = joblib.load(str(model_dir_path / cls.CLASSIFIER_PATH)) - with open(str(model_dir_path / - cls.CATEGORY_TAXONOMY_PATH), 'r') as f: + with open(str(model_dir_path / cls.CATEGORY_TAXONOMY_PATH), "r") as f: category_taxonomy_data = json.load(f) category_taxonomy = Taxonomy.from_dict(category_taxonomy_data) @@ -186,46 +197,55 @@ def load(cls, model_dir: str) -> 'CategoryClassifier': @classmethod def create_classifier(cls, category_hierarchy): - return HierarchicalClassifier(base_estimator=cls - .create_base_classifier(), - class_hierarchy=category_hierarchy, - prediction_depth='nmlnp', - algorithm='lcpn', - stopping_criteria=0.5) + return HierarchicalClassifier( + base_estimator=cls.create_base_classifier(), + class_hierarchy=category_hierarchy, + prediction_depth="nmlnp", + algorithm="lcpn", + stopping_criteria=0.5, + ) @staticmethod def create_base_classifier(): - return Pipeline([ - ('tfidf', TfidfTransformer()), - ('clf', LogisticRegression())]) + return Pipeline([("tfidf", TfidfTransformer()), ("clf", LogisticRegression())]) @staticmethod def create_transformer(): - return ColumnTransformer([ - ('ingredients_vectorizer', - CountVectorizer(min_df=5, - preprocessor=ingredient_preprocess, - analyzer='word', - token_pattern=r"[a-zA-Z-:]+"), - 'ingredients_tags'), - ('product_name_vectorizer', - CountVectorizer(min_df=5, - preprocessor=preprocess_product_name), - 'product_name'), - ]) + return ColumnTransformer( + [ + ( + "ingredients_vectorizer", + CountVectorizer( + min_df=5, + preprocessor=ingredient_preprocess, + analyzer="word", + token_pattern=r"[a-zA-Z-:]+", + ), + "ingredients_tags", + ), + ( + "product_name_vectorizer", + CountVectorizer(min_df=5, preprocessor=preprocess_product_name), + "product_name", + ), + ] + ) 
def evaluate(self, test_df: pd.DataFrame) -> JSONType: self.raise_if_not_loaded() y_test = test_df.deepest_category_int.values y_pred = self.classifier.predict(self.transformer.transform(test_df)) - return self._evaluate(self.classifier.graph_, - y_test, y_pred, len(self.categories)) + return self._evaluate( + self.classifier.graph_, y_test, y_pred, len(self.categories) + ) @staticmethod - def _evaluate(category_graph: networkx.DiGraph, - y_true: np.ndarray, - y_pred: np.ndarray, - category_count: int) -> JSONType: + def _evaluate( + category_graph: networkx.DiGraph, + y_true: np.ndarray, + y_pred: np.ndarray, + category_count: int, + ) -> JSONType: y_true_matrix = np.zeros((y_true.shape[0], category_count)) y_true_matrix[np.arange(y_true.shape[0]), y_true] = 1 @@ -233,22 +253,24 @@ def _evaluate(category_graph: networkx.DiGraph, y_pred_matrix[np.arange(y_pred.shape[0]), y_pred] = 1 return { - 'h_precision': h_precision_score(y_true_matrix, y_pred_matrix, category_graph), - 'h_recall': h_recall_score(y_true_matrix, y_pred_matrix, category_graph), - 'h_fbeta': h_fbeta_score(y_true_matrix, y_pred_matrix, category_graph), + "h_precision": h_precision_score( + y_true_matrix, y_pred_matrix, category_graph + ), + "h_recall": h_recall_score(y_true_matrix, y_pred_matrix, category_graph), + "h_fbeta": h_fbeta_score(y_true_matrix, y_pred_matrix, category_graph), } def ingredient_preprocess(ingredients_tags: List[str]) -> str: - return ' '.join(ingredients_tags) + return " ".join(ingredients_tags) def preprocess_product_name(text): text = strip_accents_ascii(text) text = text.lower() - text = PUNCTUATION_REGEX.sub(' ', text) - text = DIGIT_REGEX.sub(' ', text) - return MULTIPLE_SPACES_REGEX.sub(' ', text) + text = PUNCTUATION_REGEX.sub(" ", text) + text = DIGIT_REGEX.sub(" ", text) + return MULTIPLE_SPACES_REGEX.sub(" ", text) def train(model_output_dir: pathlib.Path, comment: Optional[str] = None): @@ -259,19 +281,21 @@ def train(model_output_dir: pathlib.Path, comment: Optional[str] = None): category_classifier.save(str(model_output_dir)) test_metrics = category_classifier.evaluate(test_df) - dataset_timestamp = datetime.datetime.fromtimestamp(os.path.getmtime(settings.JSONL_DATASET_PATH)) + dataset_timestamp = datetime.datetime.fromtimestamp( + os.path.getmtime(settings.JSONL_DATASET_PATH) + ) meta = { - 'metrics': { - 'test': test_metrics, + "metrics": { + "test": test_metrics, }, - 'dataset_id': dataset_timestamp.date().isoformat(), - 'training_set_count': len(train_df), - 'test_set_count': len(test_df), + "dataset_id": dataset_timestamp.date().isoformat(), + "training_set_count": len(train_df), + "test_set_count": len(test_df), } if comment: - meta['comment'] = comment + meta["comment"] = comment - with open(str(model_output_dir / 'meta.json'), 'w') as f: + with open(str(model_output_dir / "meta.json"), "w") as f: json.dump(meta, f) diff --git a/robotoff/ml/object_detection/core.py b/robotoff/ml/object_detection/core.py index 6544d042a2..9a3354c00e 100644 --- a/robotoff/ml/object_detection/core.py +++ b/robotoff/ml/object_detection/core.py @@ -1,33 +1,29 @@ -import pathlib -from typing import Optional, List, Tuple, Dict - import dataclasses - -import PIL +import pathlib +from typing import Dict, List, Optional, Tuple import numpy as np +import PIL import tensorflow as tf - from PIL import Image from robotoff import settings - +from robotoff.ml.object_detection.utils import label_map_util from robotoff.ml.object_detection.utils import ops as utils_ops +from robotoff.ml.object_detection.utils 
import visualization_utils as vis_util from robotoff.ml.object_detection.utils.label_map_util import CategoryIndex from robotoff.ml.object_detection.utils.ops import convert_image_to_array -from robotoff.ml.object_detection.utils.string_int_label_map_pb2 import \ - StringIntLabelMap -from robotoff.ml.object_detection.utils import label_map_util -from robotoff.ml.object_detection.utils import visualization_utils as \ - vis_util +from robotoff.ml.object_detection.utils.string_int_label_map_pb2 import ( + StringIntLabelMap, +) from robotoff.utils import get_logger from robotoff.utils.types import JSONType logger = get_logger(__name__) -FROZEN_GRAPH_NAME = 'frozen_inference_graph.pb' -LABEL_MAP_NAME = 'labels.pbtxt' +FROZEN_GRAPH_NAME = "frozen_inference_graph.pb" +LABEL_MAP_NAME = "labels.pbtxt" @dataclasses.dataclass @@ -48,29 +44,31 @@ class ObjectDetectionRawResult: detection_masks: Optional[np.array] = None boxed_image: Optional[PIL.Image.Image] = None - def select(self, threshold: Optional[float] = None) -> \ - List[ObjectDetectionResult]: + def select(self, threshold: Optional[float] = None) -> List[ObjectDetectionResult]: if threshold is None: threshold = 0.5 - box_masks = (self.detection_scores > threshold) + box_masks = self.detection_scores > threshold selected_boxes = self.detection_boxes[box_masks] selected_scores = self.detection_scores[box_masks] selected_classes = self.detection_classes[box_masks] results = [] image_width, image_height = self.image_size - for box, score, label in zip(selected_boxes, selected_scores, - selected_classes): + for box, score, label in zip(selected_boxes, selected_scores, selected_classes): ymin, xmin, ymax, xmax = box - bounding_box = (ymin * image_height, xmin * image_width, - ymax * image_height, xmax * image_width) + bounding_box = ( + ymin * image_height, + xmin * image_width, + ymax * image_height, + xmax * image_width, + ) label_int = int(label) - label_str = self.category_index[label_int]['name'] - result = ObjectDetectionResult(bounding_box=bounding_box, - score=float(score), - label=label_str) + label_str = self.category_index[label_int]["name"] + result = ObjectDetectionResult( + bounding_box=bounding_box, score=float(score), label=label_str + ) results.append(result) return results @@ -80,110 +78,104 @@ def to_json(self, threshold: Optional[float] = None) -> List[JSONType]: class ObjectDetectionModel: - def __init__(self, - graph: tf.Graph, - label_map: StringIntLabelMap): + def __init__(self, graph: tf.Graph, label_map: StringIntLabelMap): self.graph: tf.Graph = graph self.label_map: StringIntLabelMap = label_map self.categories = label_map_util.convert_label_map_to_categories( - label_map, max_num_classes=1000) - self.category_index: CategoryIndex = ( - label_map_util.create_category_index(self.categories)) + label_map, max_num_classes=1000 + ) + self.category_index: CategoryIndex = label_map_util.create_category_index( + self.categories + ) @classmethod def load(cls, graph_path: pathlib.Path, label_path: pathlib.Path): detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() - with tf.gfile.GFile( - str(graph_path), 'rb') as f: + with tf.gfile.GFile(str(graph_path), "rb") as f: serialized_graph = f.read() od_graph_def.ParseFromString(serialized_graph) - tf.import_graph_def(od_graph_def, name='') + tf.import_graph_def(od_graph_def, name="") - label_map = label_map_util.load_labelmap( - str(label_path)) + label_map = label_map_util.load_labelmap(str(label_path)) logger.info("Model loaded") - return 
cls(graph=detection_graph, - label_map=label_map) + return cls(graph=detection_graph, label_map=label_map) - def _run_inference_for_single_image(self, image: np.array) -> \ - ObjectDetectionRawResult: + def _run_inference_for_single_image( + self, image: np.array + ) -> ObjectDetectionRawResult: with tf.Session(graph=self.graph) as sess: # Get handles to input and output tensors ops = self.graph.get_operations() - all_tensor_names = {output.name for op in ops for output in - op.outputs} + all_tensor_names = {output.name for op in ops for output in op.outputs} tensor_dict = {} for key in [ - 'num_detections', 'detection_boxes', 'detection_scores', - 'detection_classes', 'detection_masks' + "num_detections", + "detection_boxes", + "detection_scores", + "detection_classes", + "detection_masks", ]: - tensor_name = key + ':0' + tensor_name = key + ":0" if tensor_name in all_tensor_names: - tensor_dict[ - key] = self.graph.get_tensor_by_name(tensor_name) - if 'detection_masks' in tensor_dict: + tensor_dict[key] = self.graph.get_tensor_by_name(tensor_name) + if "detection_masks" in tensor_dict: # The following processing is only for single image - detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], - [0]) - detection_masks = tf.squeeze(tensor_dict['detection_masks'], - [0]) + detection_boxes = tf.squeeze(tensor_dict["detection_boxes"], [0]) + detection_masks = tf.squeeze(tensor_dict["detection_masks"], [0]) # Reframe is required to translate mask from box coordinates # to image coordinates and fit the image size. - real_num_detection = tf.cast( - tensor_dict['num_detections'][0], - tf.int32) - detection_boxes = tf.slice(detection_boxes, [0, 0], - [real_num_detection, -1]) - detection_masks = tf.slice(detection_masks, [0, 0, 0], - [real_num_detection, -1, -1]) - detection_masks_reframed = utils_ops. 
\ - reframe_box_masks_to_image_masks( - detection_masks, detection_boxes, image.shape[0], - image.shape[1]) + real_num_detection = tf.cast(tensor_dict["num_detections"][0], tf.int32) + detection_boxes = tf.slice( + detection_boxes, [0, 0], [real_num_detection, -1] + ) + detection_masks = tf.slice( + detection_masks, [0, 0, 0], [real_num_detection, -1, -1] + ) + detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks( + detection_masks, detection_boxes, image.shape[0], image.shape[1] + ) detection_masks_reframed = tf.cast( - tf.greater(detection_masks_reframed, 0.5), tf.uint8) + tf.greater(detection_masks_reframed, 0.5), tf.uint8 + ) # Follow the convention by adding back the batch dimension - tensor_dict['detection_masks'] = tf.expand_dims( - detection_masks_reframed, 0) - image_tensor = self.graph.get_tensor_by_name('image_tensor:0') + tensor_dict["detection_masks"] = tf.expand_dims( + detection_masks_reframed, 0 + ) + image_tensor = self.graph.get_tensor_by_name("image_tensor:0") # Run inference - output_dict = sess.run(tensor_dict, - feed_dict={ - image_tensor: np.expand_dims(image, - 0)}) + output_dict = sess.run( + tensor_dict, feed_dict={image_tensor: np.expand_dims(image, 0)} + ) # all outputs are float32 numpy arrays, so convert types as # appropriate - output_dict['num_detections'] = int( - output_dict['num_detections'][0]) - output_dict['detection_classes'] = output_dict[ - 'detection_classes'][0].astype(np.uint8) - output_dict['detection_boxes'] = output_dict['detection_boxes'][ - 0] - output_dict['detection_scores'] = \ - output_dict['detection_scores'][0] - if 'detection_masks' in output_dict: - output_dict['detection_masks'] = \ - output_dict['detection_masks'][ - 0] + output_dict["num_detections"] = int(output_dict["num_detections"][0]) + output_dict["detection_classes"] = output_dict["detection_classes"][ + 0 + ].astype(np.uint8) + output_dict["detection_boxes"] = output_dict["detection_boxes"][0] + output_dict["detection_scores"] = output_dict["detection_scores"][0] + if "detection_masks" in output_dict: + output_dict["detection_masks"] = output_dict["detection_masks"][0] return ObjectDetectionRawResult( image_size=(image.shape[0], image.shape[1]), - num_detections=output_dict['num_detections'], - detection_classes=output_dict['detection_classes'], - detection_boxes=output_dict['detection_boxes'], - detection_scores=output_dict['detection_scores'], - detection_masks=output_dict.get('detection_masks'), - category_index=self.category_index) - - def detect_from_image(self, - image: PIL.Image.Image, - output_image: bool = False) -> ObjectDetectionRawResult: + num_detections=output_dict["num_detections"], + detection_classes=output_dict["detection_classes"], + detection_boxes=output_dict["detection_boxes"], + detection_scores=output_dict["detection_scores"], + detection_masks=output_dict.get("detection_masks"), + category_index=self.category_index, + ) + + def detect_from_image( + self, image: PIL.Image.Image, output_image: bool = False + ) -> ObjectDetectionRawResult: image_array = convert_image_to_array(image) result = self._run_inference_for_single_image(image_array) @@ -196,32 +188,32 @@ def detect_from_image(self, self.category_index, instance_masks=result.detection_masks, use_normalized_coordinates=True, - line_thickness=5) + line_thickness=5, + ) image_with_boxes = Image.fromarray(image_array) result.boxed_image = image_with_boxes return result -def run_model(image_dir: pathlib.Path, - model: ObjectDetectionModel): - for filepath in image_dir.glob('*.jpg'): 
+def run_model(image_dir: pathlib.Path, model: ObjectDetectionModel): + for filepath in image_dir.glob("*.jpg"): boxed_filename = filepath.parent / "{}_box.jpg".format(filepath.stem) - if filepath.stem.endswith('box') or boxed_filename.exists(): + if filepath.stem.endswith("box") or boxed_filename.exists(): continue image: PIL.Image.Image = Image.open(str(filepath)) result = model.detect_from_image(image, output_image=True) - with open(str(boxed_filename), 'wb') as f: + with open(str(boxed_filename), "wb") as f: result.boxed_image.save(f) class ObjectDetectionModelRegistry: models_config = { - 'nutrition-table': settings.MODELS_DIR / 'nutrition-table', - 'nutriscore': settings.MODELS_DIR / 'nutriscore', + "nutrition-table": settings.MODELS_DIR / "nutrition-table", + "nutriscore": settings.MODELS_DIR / "nutriscore", } models: Dict[str, ObjectDetectionModel] = {} @@ -239,8 +231,7 @@ def load(cls, name: str) -> ObjectDetectionModel: model_dir = cls.models_config[name] graph_path = model_dir / FROZEN_GRAPH_NAME label_path = model_dir / LABEL_MAP_NAME - model = ObjectDetectionModel.load(graph_path=graph_path, - label_path=label_path) + model = ObjectDetectionModel.load(graph_path=graph_path, label_path=label_path) cls.models[name] = model return model @@ -257,4 +248,3 @@ def get(cls, name: str): model = cls.models[name] return model - diff --git a/robotoff/ml/object_detection/download.py b/robotoff/ml/object_detection/download.py index 6429192b13..ac47118e78 100644 --- a/robotoff/ml/object_detection/download.py +++ b/robotoff/ml/object_detection/download.py @@ -3,50 +3,48 @@ import requests +from robotoff import settings from robotoff.off import generate_image_url from robotoff.products import ProductDataset -from robotoff import settings from robotoff.utils import get_logger from robotoff.utils.types import JSONType logger = get_logger() -JSONL_SHUF_DATASET_PATH = settings.DATASET_DIR / 'products-shuf.jsonl.gz' +JSONL_SHUF_DATASET_PATH = settings.DATASET_DIR / "products-shuf.jsonl.gz" ds = ProductDataset(JSONL_SHUF_DATASET_PATH) -IMAGE_DATASET_DIR = settings.PROJECT_DIR / 'image_dataset' -NUTRITION_TABLE_IMAGE_DIR = IMAGE_DATASET_DIR / 'nutrition-table-2' +IMAGE_DATASET_DIR = settings.PROJECT_DIR / "image_dataset" +NUTRITION_TABLE_IMAGE_DIR = IMAGE_DATASET_DIR / "nutrition-table-2" def load_seen_set() -> Set[str]: seen_set = set() - with open(IMAGE_DATASET_DIR / 'dataset.txt') as f: + with open(IMAGE_DATASET_DIR / "dataset.txt") as f: for line in f: if line: - line = line.strip('\n') - barcode, _ = line.split('_') + line = line.strip("\n") + barcode, _ = line.split("_") seen_set.add(barcode) return seen_set -def save_image(directory: pathlib.Path, - image_meta: JSONType, - barcode: str, - override: bool = False): - image_name = image_meta['imgid'] +def save_image( + directory: pathlib.Path, image_meta: JSONType, barcode: str, override: bool = False +): + image_name = image_meta["imgid"] image_full_name = "{}_{}.jpg".format(barcode, image_name) image_path = directory / image_full_name if image_path.exists() and not override: return - image_url = generate_image_url(barcode, - image_name) + image_url = generate_image_url(barcode, image_name) logger.info("Downloading image {}".format(image_url)) r = requests.get(image_url) - with open(str(image_path), 'wb') as fd: + with open(str(image_path), "wb") as fd: logger.info("Saving image in {}".format(image_path)) for chunk in r.iter_content(chunk_size=128): fd.write(chunk) @@ -55,11 +53,14 @@ def save_image(directory: pathlib.Path, seen_set = 
load_seen_set() count = 0 -for product in (ds.stream().filter_by_state_tag('en:complete') - .filter_by_country_tag('en:france') - .filter_nonempty_text_field('code') - .filter_nonempty_tag_field('images')): - barcode = product['code'] +for product in ( + ds.stream() + .filter_by_state_tag("en:complete") + .filter_by_country_tag("en:france") + .filter_nonempty_text_field("code") + .filter_nonempty_tag_field("images") +): + barcode = product["code"] if barcode in seen_set: print("Product already seen: {}".format(barcode)) @@ -68,14 +69,14 @@ def save_image(directory: pathlib.Path, has_nutrition = False has_front = False - for image_key, image_meta in product.get('images', {}).items(): - if not has_nutrition and image_key.startswith('nutrition'): + for image_key, image_meta in product.get("images", {}).items(): + if not has_nutrition and image_key.startswith("nutrition"): has_nutrition = True save_image(NUTRITION_TABLE_IMAGE_DIR, image_meta, barcode) count += 1 continue - elif not has_front and image_key.startswith('front'): + elif not has_front and image_key.startswith("front"): has_front = True save_image(NUTRITION_TABLE_IMAGE_DIR, image_meta, barcode) count += 1 diff --git a/robotoff/ml/object_detection/utils/dataset_util.py b/robotoff/ml/object_detection/utils/dataset_util.py index 3731e64360..d690d54aa7 100644 --- a/robotoff/ml/object_detection/utils/dataset_util.py +++ b/robotoff/ml/object_detection/utils/dataset_util.py @@ -53,7 +53,7 @@ def read_examples_list(path): """ with tf.gfile.GFile(path) as fid: lines = fid.readlines() - return [line.strip().split(' ')[0] for line in lines] + return [line.strip().split(" ")[0] for line in lines] def recursive_parse_xml_to_dict(xml): @@ -70,7 +70,7 @@ def recursive_parse_xml_to_dict(xml): result = {} for child in xml: child_result = recursive_parse_xml_to_dict(child) - if child.tag != 'object': + if child.tag != "object": result[child.tag] = child_result[child.tag] else: if child.tag not in result: diff --git a/robotoff/ml/object_detection/utils/label_map_util.py b/robotoff/ml/object_detection/utils/label_map_util.py index 084cbab563..f53403cf35 100644 --- a/robotoff/ml/object_detection/utils/label_map_util.py +++ b/robotoff/ml/object_detection/utils/label_map_util.py @@ -20,6 +20,7 @@ import tensorflow as tf from google.protobuf import text_format + from . 
import string_int_label_map_pb2 @@ -34,11 +35,13 @@ def _validate_label_map(label_map): """ for item in label_map.item: if item.id < 0: - raise ValueError('Label map ids should be >= 0.') - if (item.id == 0 and item.name != 'background' and - item.display_name != 'background'): - raise ValueError( - 'Label map id 0 is reserved for the background label') + raise ValueError("Label map ids should be >= 0.") + if ( + item.id == 0 + and item.name != "background" + and item.display_name != "background" + ): + raise ValueError("Label map id 0 is reserved for the background label") CategoryIndex = Dict[int, Dict] @@ -59,7 +62,7 @@ def create_category_index(categories: List[Dict]) -> CategoryIndex: """ category_index = {} for cat in categories: - category_index[cat['id']] = cat + category_index[cat["id"]] = cat return category_index @@ -75,9 +78,9 @@ def get_max_label_map_index(label_map): return max([item.id for item in label_map.item]) -def convert_label_map_to_categories(label_map, - max_num_classes, - use_display_name=True) -> List[Dict]: +def convert_label_map_to_categories( + label_map, max_num_classes, use_display_name=True +) -> List[Dict]: """Loads label map proto and returns categories list compatible with eval. This function loads a label map and returns a list of dicts, each of which @@ -105,23 +108,27 @@ def convert_label_map_to_categories(label_map, if not label_map: label_id_offset = 1 for class_id in range(max_num_classes): - categories.append({ - 'id': class_id + label_id_offset, - 'name': 'category_{}'.format(class_id + label_id_offset) - }) + categories.append( + { + "id": class_id + label_id_offset, + "name": "category_{}".format(class_id + label_id_offset), + } + ) return categories for item in label_map.item: if not 0 < item.id <= max_num_classes: - logging.info('Ignore item %d since it falls outside of requested ' - 'label range.', item.id) + logging.info( + "Ignore item %d since it falls outside of requested " "label range.", + item.id, + ) continue - if use_display_name and item.HasField('display_name'): + if use_display_name and item.HasField("display_name"): name = item.display_name else: name = item.name if item.id not in list_of_ids_already_added: list_of_ids_already_added.append(item.id) - categories.append({'id': item.id, 'name': name}) + categories.append({"id": item.id, "name": name}) return categories @@ -133,7 +140,7 @@ def load_labelmap(path: str) -> string_int_label_map_pb2.StringIntLabelMap: Returns: a StringIntLabelMapProto """ - with tf.gfile.GFile(path, 'r') as fid: + with tf.gfile.GFile(path, "r") as fid: label_map_string = fid.read() label_map = string_int_label_map_pb2.StringIntLabelMap() try: @@ -183,4 +190,4 @@ def create_category_index_from_labelmap(label_map_path): def create_class_agnostic_category_index(): """Creates a category index with a single `object` class.""" - return {1: {'id': 1, 'name': 'object'}} + return {1: {"id": 1, "name": "object"}} diff --git a/robotoff/ml/object_detection/utils/ops.py b/robotoff/ml/object_detection/utils/ops.py index ffdedfcde7..7891f61035 100644 --- a/robotoff/ml/object_detection/utils/ops.py +++ b/robotoff/ml/object_detection/utils/ops.py @@ -1,11 +1,10 @@ -import tensorflow as tf import numpy as np import PIL +import tensorflow as tf from PIL import Image -def reframe_box_masks_to_image_masks(box_masks, boxes, image_height, - image_width): +def reframe_box_masks_to_image_masks(box_masks, boxes, image_height, image_width): """Transforms the box masks back to full image masks. 
Embeds masks in bounding boxes of larger masks whose shapes correspond to image shape. @@ -37,27 +36,29 @@ def transform_boxes_relative_to_boxes(boxes, reference_boxes): box_masks_expanded = tf.expand_dims(box_masks, axis=3) num_boxes = tf.shape(box_masks_expanded)[0] unit_boxes = tf.concat( - [tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1) + [tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1 + ) reverse_boxes = transform_boxes_relative_to_boxes(unit_boxes, boxes) return tf.image.crop_and_resize( image=box_masks_expanded, boxes=reverse_boxes, box_ind=tf.range(num_boxes), crop_size=[image_height, image_width], - extrapolation_value=0.0) + extrapolation_value=0.0, + ) image_masks = tf.cond( tf.shape(box_masks)[0] > 0, reframe_box_masks_to_image_masks_default, - lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32)) + lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32), + ) return tf.squeeze(image_masks, axis=3) def convert_image_to_array(image: PIL.Image.Image) -> np.array: - if image.mode != 'RGB': - image = image.convert('RGB') + if image.mode != "RGB": + image = image.convert("RGB") (im_width, im_height) = image.size - return np.array(image.getdata()).reshape( - (im_height, im_width, 3)).astype(np.uint8) + return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8) diff --git a/robotoff/ml/object_detection/utils/standard_fields.py b/robotoff/ml/object_detection/utils/standard_fields.py index 6e8a58071d..1ca2db625c 100644 --- a/robotoff/ml/object_detection/utils/standard_fields.py +++ b/robotoff/ml/object_detection/utils/standard_fields.py @@ -69,36 +69,37 @@ class InputDataFields(object): images can be padded with zeros. multiclass_scores: the label score per class for each box. 
""" - image = 'image' - image_additional_channels = 'image_additional_channels' - original_image = 'original_image' - original_image_spatial_shape = 'original_image_spatial_shape' - key = 'key' - source_id = 'source_id' - filename = 'filename' - groundtruth_image_classes = 'groundtruth_image_classes' - groundtruth_image_confidences = 'groundtruth_image_confidences' - groundtruth_boxes = 'groundtruth_boxes' - groundtruth_classes = 'groundtruth_classes' - groundtruth_confidences = 'groundtruth_confidences' - groundtruth_label_types = 'groundtruth_label_types' - groundtruth_is_crowd = 'groundtruth_is_crowd' - groundtruth_area = 'groundtruth_area' - groundtruth_difficult = 'groundtruth_difficult' - groundtruth_group_of = 'groundtruth_group_of' - proposal_boxes = 'proposal_boxes' - proposal_objectness = 'proposal_objectness' - groundtruth_instance_masks = 'groundtruth_instance_masks' - groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' - groundtruth_instance_classes = 'groundtruth_instance_classes' - groundtruth_keypoints = 'groundtruth_keypoints' - groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities' - groundtruth_label_weights = 'groundtruth_label_weights' - groundtruth_weights = 'groundtruth_weights' - num_groundtruth_boxes = 'num_groundtruth_boxes' - is_annotated = 'is_annotated' - true_image_shape = 'true_image_shape' - multiclass_scores = 'multiclass_scores' + + image = "image" + image_additional_channels = "image_additional_channels" + original_image = "original_image" + original_image_spatial_shape = "original_image_spatial_shape" + key = "key" + source_id = "source_id" + filename = "filename" + groundtruth_image_classes = "groundtruth_image_classes" + groundtruth_image_confidences = "groundtruth_image_confidences" + groundtruth_boxes = "groundtruth_boxes" + groundtruth_classes = "groundtruth_classes" + groundtruth_confidences = "groundtruth_confidences" + groundtruth_label_types = "groundtruth_label_types" + groundtruth_is_crowd = "groundtruth_is_crowd" + groundtruth_area = "groundtruth_area" + groundtruth_difficult = "groundtruth_difficult" + groundtruth_group_of = "groundtruth_group_of" + proposal_boxes = "proposal_boxes" + proposal_objectness = "proposal_objectness" + groundtruth_instance_masks = "groundtruth_instance_masks" + groundtruth_instance_boundaries = "groundtruth_instance_boundaries" + groundtruth_instance_classes = "groundtruth_instance_classes" + groundtruth_keypoints = "groundtruth_keypoints" + groundtruth_keypoint_visibilities = "groundtruth_keypoint_visibilities" + groundtruth_label_weights = "groundtruth_label_weights" + groundtruth_weights = "groundtruth_weights" + num_groundtruth_boxes = "num_groundtruth_boxes" + is_annotated = "is_annotated" + true_image_shape = "true_image_shape" + multiclass_scores = "multiclass_scores" class DetectionResultFields(object): @@ -119,17 +120,17 @@ class DetectionResultFields(object): raw_detection_scores: contains class score logits for raw detection boxes. 
""" - source_id = 'source_id' - key = 'key' - detection_boxes = 'detection_boxes' - detection_scores = 'detection_scores' - detection_classes = 'detection_classes' - detection_masks = 'detection_masks' - detection_boundaries = 'detection_boundaries' - detection_keypoints = 'detection_keypoints' - num_detections = 'num_detections' - raw_detection_boxes = 'raw_detection_boxes' - raw_detection_scores = 'raw_detection_scores' + source_id = "source_id" + key = "key" + detection_boxes = "detection_boxes" + detection_scores = "detection_scores" + detection_classes = "detection_classes" + detection_masks = "detection_masks" + detection_boundaries = "detection_boundaries" + detection_keypoints = "detection_keypoints" + num_detections = "num_detections" + raw_detection_boxes = "raw_detection_boxes" + raw_detection_scores = "raw_detection_scores" class BoxListFields(object): @@ -147,17 +148,18 @@ class BoxListFields(object): keypoint_heatmaps: keypoint heatmaps per bounding box. is_crowd: is_crowd annotation per bounding box. """ - boxes = 'boxes' - classes = 'classes' - scores = 'scores' - weights = 'weights' - confidences = 'confidences' - objectness = 'objectness' - masks = 'masks' - boundaries = 'boundaries' - keypoints = 'keypoints' - keypoint_heatmaps = 'keypoint_heatmaps' - is_crowd = 'is_crowd' + + boxes = "boxes" + classes = "classes" + scores = "scores" + weights = "weights" + confidences = "confidences" + objectness = "objectness" + masks = "masks" + boundaries = "boundaries" + keypoints = "keypoints" + keypoint_heatmaps = "keypoint_heatmaps" + is_crowd = "is_crowd" class TfExampleFields(object): @@ -203,37 +205,38 @@ class TfExampleFields(object): detection_bbox_xmax: xmax coordinates of a detection box. detection_score: detection score for the class label and box. 
""" - image_encoded = 'image/encoded' - image_format = 'image/format' # format is reserved keyword - filename = 'image/filename' - channels = 'image/channels' - colorspace = 'image/colorspace' - height = 'image/height' - width = 'image/width' - source_id = 'image/source_id' - image_class_text = 'image/class/text' - image_class_label = 'image/class/label' - object_class_text = 'image/object/class/text' - object_class_label = 'image/object/class/label' - object_bbox_ymin = 'image/object/bbox/ymin' - object_bbox_xmin = 'image/object/bbox/xmin' - object_bbox_ymax = 'image/object/bbox/ymax' - object_bbox_xmax = 'image/object/bbox/xmax' - object_view = 'image/object/view' - object_truncated = 'image/object/truncated' - object_occluded = 'image/object/occluded' - object_difficult = 'image/object/difficult' - object_group_of = 'image/object/group_of' - object_depiction = 'image/object/depiction' - object_is_crowd = 'image/object/is_crowd' - object_segment_area = 'image/object/segment/area' - object_weight = 'image/object/weight' - instance_masks = 'image/segmentation/object' - instance_boundaries = 'image/boundaries/object' - instance_classes = 'image/segmentation/object/class' - detection_class_label = 'image/detection/label' - detection_bbox_ymin = 'image/detection/bbox/ymin' - detection_bbox_xmin = 'image/detection/bbox/xmin' - detection_bbox_ymax = 'image/detection/bbox/ymax' - detection_bbox_xmax = 'image/detection/bbox/xmax' - detection_score = 'image/detection/score' + + image_encoded = "image/encoded" + image_format = "image/format" # format is reserved keyword + filename = "image/filename" + channels = "image/channels" + colorspace = "image/colorspace" + height = "image/height" + width = "image/width" + source_id = "image/source_id" + image_class_text = "image/class/text" + image_class_label = "image/class/label" + object_class_text = "image/object/class/text" + object_class_label = "image/object/class/label" + object_bbox_ymin = "image/object/bbox/ymin" + object_bbox_xmin = "image/object/bbox/xmin" + object_bbox_ymax = "image/object/bbox/ymax" + object_bbox_xmax = "image/object/bbox/xmax" + object_view = "image/object/view" + object_truncated = "image/object/truncated" + object_occluded = "image/object/occluded" + object_difficult = "image/object/difficult" + object_group_of = "image/object/group_of" + object_depiction = "image/object/depiction" + object_is_crowd = "image/object/is_crowd" + object_segment_area = "image/object/segment/area" + object_weight = "image/object/weight" + instance_masks = "image/segmentation/object" + instance_boundaries = "image/boundaries/object" + instance_classes = "image/segmentation/object/class" + detection_class_label = "image/detection/label" + detection_bbox_ymin = "image/detection/bbox/ymin" + detection_bbox_xmin = "image/detection/bbox/xmin" + detection_bbox_ymax = "image/detection/bbox/ymax" + detection_bbox_xmax = "image/detection/bbox/xmax" + detection_score = "image/detection/score" diff --git a/robotoff/ml/object_detection/utils/string_int_label_map_pb2.py b/robotoff/ml/object_detection/utils/string_int_label_map_pb2.py index bd6251d502..48d90f6b87 100644 --- a/robotoff/ml/object_detection/utils/string_int_label_map_pb2.py +++ b/robotoff/ml/object_detection/utils/string_int_label_map_pb2.py @@ -2,121 +2,163 @@ # source: robotoff/ml/object_detection/utils/string_int_label_map.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: 
x.encode("latin1")) from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pb2 from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database -from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - DESCRIPTOR = _descriptor.FileDescriptor( - name='robotoff/ml/object_detection/utils/string_int_label_map.proto', - package='object_detection.protos', - syntax='proto2', - serialized_pb=_b('\n=robotoff/ml/object_detection/utils/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') + name="robotoff/ml/object_detection/utils/string_int_label_map.proto", + package="object_detection.protos", + syntax="proto2", + serialized_pb=_b( + '\n=robotoff/ml/object_detection/utils/string_int_label_map.proto\x12\x17object_detection.protos"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem' + ), ) _sym_db.RegisterFileDescriptor(DESCRIPTOR) - - _STRINGINTLABELMAPITEM = _descriptor.Descriptor( - name='StringIntLabelMapItem', - full_name='object_detection.protos.StringIntLabelMapItem', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, - number=2, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=90, - serialized_end=161, + name="StringIntLabelMapItem", + full_name="object_detection.protos.StringIntLabelMapItem", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="name", + full_name="object_detection.protos.StringIntLabelMapItem.name", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="id", + 
full_name="object_detection.protos.StringIntLabelMapItem.id", + index=1, + number=2, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="display_name", + full_name="object_detection.protos.StringIntLabelMapItem.display_name", + index=2, + number=3, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=90, + serialized_end=161, ) _STRINGINTLABELMAP = _descriptor.Descriptor( - name='StringIntLabelMap', - full_name='object_detection.protos.StringIntLabelMap', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=163, - serialized_end=244, + name="StringIntLabelMap", + full_name="object_detection.protos.StringIntLabelMap", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="item", + full_name="object_detection.protos.StringIntLabelMap.item", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=163, + serialized_end=244, ) -_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM -DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM -DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP - -StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( - DESCRIPTOR = _STRINGINTLABELMAPITEM, - __module__ = 'robotoff.ml.object_detection.utils.string_int_label_map_pb2' - # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) - )) +_STRINGINTLABELMAP.fields_by_name["item"].message_type = _STRINGINTLABELMAPITEM +DESCRIPTOR.message_types_by_name["StringIntLabelMapItem"] = _STRINGINTLABELMAPITEM +DESCRIPTOR.message_types_by_name["StringIntLabelMap"] = _STRINGINTLABELMAP + +StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType( + "StringIntLabelMapItem", + (_message.Message,), + dict( + DESCRIPTOR=_STRINGINTLABELMAPITEM, + __module__="robotoff.ml.object_detection.utils.string_int_label_map_pb2" + # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) + ), +) _sym_db.RegisterMessage(StringIntLabelMapItem) -StringIntLabelMap = 
_reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( - DESCRIPTOR = _STRINGINTLABELMAP, - __module__ = 'robotoff.ml.object_detection.utils.string_int_label_map_pb2' - # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) - )) +StringIntLabelMap = _reflection.GeneratedProtocolMessageType( + "StringIntLabelMap", + (_message.Message,), + dict( + DESCRIPTOR=_STRINGINTLABELMAP, + __module__="robotoff.ml.object_detection.utils.string_int_label_map_pb2" + # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) + ), +) _sym_db.RegisterMessage(StringIntLabelMap) diff --git a/robotoff/ml/object_detection/utils/visualization_utils.py b/robotoff/ml/object_detection/utils/visualization_utils.py index 9ac015a570..aacc057758 100644 --- a/robotoff/ml/object_detection/utils/visualization_utils.py +++ b/robotoff/ml/object_detection/utils/visualization_utils.py @@ -25,7 +25,7 @@ # Set headless-friendly backend. import matplotlib -matplotlib.use('Agg') # pylint: disable=multiple-statements +matplotlib.use("Agg") # pylint: disable=multiple-statements import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top import numpy as np import PIL.Image as Image @@ -35,35 +35,137 @@ import six import tensorflow as tf -from robotoff.ml.object_detection.utils import \ - standard_fields as fields +from robotoff.ml.object_detection.utils import standard_fields as fields _TITLE_LEFT_MARGIN = 10 _TITLE_TOP_MARGIN = 10 STANDARD_COLORS = [ - 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', - 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', - 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', - 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', - 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', - 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', - 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', - 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', - 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', - 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', - 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', - 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', - 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', - 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', - 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', - 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', - 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', - 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', - 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', - 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', - 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', - 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', - 'WhiteSmoke', 'Yellow', 'YellowGreen' + "AliceBlue", + "Chartreuse", + "Aqua", + "Aquamarine", + "Azure", + "Beige", + "Bisque", + "BlanchedAlmond", + "BlueViolet", + "BurlyWood", + "CadetBlue", + "AntiqueWhite", + "Chocolate", + "Coral", + "CornflowerBlue", + "Cornsilk", + "Crimson", + "Cyan", + "DarkCyan", + "DarkGoldenRod", + "DarkGrey", + "DarkKhaki", + "DarkOrange", + 
"DarkOrchid", + "DarkSalmon", + "DarkSeaGreen", + "DarkTurquoise", + "DarkViolet", + "DeepPink", + "DeepSkyBlue", + "DodgerBlue", + "FireBrick", + "FloralWhite", + "ForestGreen", + "Fuchsia", + "Gainsboro", + "GhostWhite", + "Gold", + "GoldenRod", + "Salmon", + "Tan", + "HoneyDew", + "HotPink", + "IndianRed", + "Ivory", + "Khaki", + "Lavender", + "LavenderBlush", + "LawnGreen", + "LemonChiffon", + "LightBlue", + "LightCoral", + "LightCyan", + "LightGoldenRodYellow", + "LightGray", + "LightGrey", + "LightGreen", + "LightPink", + "LightSalmon", + "LightSeaGreen", + "LightSkyBlue", + "LightSlateGray", + "LightSlateGrey", + "LightSteelBlue", + "LightYellow", + "Lime", + "LimeGreen", + "Linen", + "Magenta", + "MediumAquaMarine", + "MediumOrchid", + "MediumPurple", + "MediumSeaGreen", + "MediumSlateBlue", + "MediumSpringGreen", + "MediumTurquoise", + "MediumVioletRed", + "MintCream", + "MistyRose", + "Moccasin", + "NavajoWhite", + "OldLace", + "Olive", + "OliveDrab", + "Orange", + "OrangeRed", + "Orchid", + "PaleGoldenRod", + "PaleGreen", + "PaleTurquoise", + "PaleVioletRed", + "PapayaWhip", + "PeachPuff", + "Peru", + "Pink", + "Plum", + "PowderBlue", + "Purple", + "Red", + "RosyBrown", + "RoyalBlue", + "SaddleBrown", + "Green", + "SandyBrown", + "SeaGreen", + "SeaShell", + "Sienna", + "Silver", + "SkyBlue", + "SlateBlue", + "SlateGray", + "SlateGrey", + "Snow", + "SpringGreen", + "SteelBlue", + "GreenYellow", + "Teal", + "Thistle", + "Tomato", + "Turquoise", + "Violet", + "Wheat", + "White", + "WhiteSmoke", + "Yellow", + "YellowGreen", ] @@ -74,9 +176,9 @@ def save_image_array_as_png(image, output_path): image: a numpy array with shape [height, width, 3]. output_path: path to which image should be written. """ - image_pil = Image.fromarray(np.uint8(image)).convert('RGB') - with tf.gfile.Open(output_path, 'w') as fid: - image_pil.save(fid, 'PNG') + image_pil = Image.fromarray(np.uint8(image)).convert("RGB") + with tf.gfile.Open(output_path, "w") as fid: + image_pil.save(fid, "PNG") def encode_image_array_as_png_str(image): @@ -90,21 +192,23 @@ def encode_image_array_as_png_str(image): """ image_pil = Image.fromarray(np.uint8(image)) output = six.BytesIO() - image_pil.save(output, format='PNG') + image_pil.save(output, format="PNG") png_string = output.getvalue() output.close() return png_string -def draw_bounding_box_on_image_array(image, - ymin, - xmin, - ymax, - xmax, - color='red', - thickness=4, - display_str_list=(), - use_normalized_coordinates=True): +def draw_bounding_box_on_image_array( + image, + ymin, + xmin, + ymax, + xmax, + color="red", + thickness=4, + display_str_list=(), + use_normalized_coordinates=True, +): """Adds a bounding box to an image (numpy array). Bounding box coordinates can be specified in either absolute (pixel) or @@ -124,22 +228,32 @@ def draw_bounding_box_on_image_array(image, ymin, xmin, ymax, xmax as relative to the image. Otherwise treat coordinates as absolute. 
""" - image_pil = Image.fromarray(np.uint8(image)).convert('RGB') - draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, - thickness, display_str_list, - use_normalized_coordinates) + image_pil = Image.fromarray(np.uint8(image)).convert("RGB") + draw_bounding_box_on_image( + image_pil, + ymin, + xmin, + ymax, + xmax, + color, + thickness, + display_str_list, + use_normalized_coordinates, + ) np.copyto(image, np.array(image_pil)) -def draw_bounding_box_on_image(image, - ymin, - xmin, - ymax, - xmax, - color='red', - thickness=4, - display_str_list=(), - use_normalized_coordinates=True): +def draw_bounding_box_on_image( + image, + ymin, + xmin, + ymax, + xmax, + color="red", + thickness=4, + display_str_list=(), + use_normalized_coordinates=True, +): """Adds a bounding box to an image. Bounding box coordinates can be specified in either absolute (pixel) or @@ -167,14 +281,21 @@ def draw_bounding_box_on_image(image, draw = ImageDraw.Draw(image) im_width, im_height = image.size if use_normalized_coordinates: - (left, right, top, bottom) = (xmin * im_width, xmax * im_width, - ymin * im_height, ymax * im_height) + (left, right, top, bottom) = ( + xmin * im_width, + xmax * im_width, + ymin * im_height, + ymax * im_height, + ) else: (left, right, top, bottom) = (xmin, xmax, ymin, ymax) - draw.line([(left, top), (left, bottom), (right, bottom), - (right, top), (left, top)], width=thickness, fill=color) + draw.line( + [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], + width=thickness, + fill=color, + ) try: - font = ImageFont.truetype('arial.ttf', 24) + font = ImageFont.truetype("arial.ttf", 24) except IOError: font = ImageFont.load_default() @@ -194,22 +315,24 @@ def draw_bounding_box_on_image(image, text_width, text_height = font.getsize(display_str) margin = np.ceil(0.05 * text_height) draw.rectangle( - [(left, text_bottom - text_height - 2 * margin), (left + text_width, - text_bottom)], - fill=color) + [ + (left, text_bottom - text_height - 2 * margin), + (left + text_width, text_bottom), + ], + fill=color, + ) draw.text( (left + margin, text_bottom - text_height - margin), display_str, - fill='black', - font=font) + fill="black", + font=font, + ) text_bottom -= text_height - 2 * margin -def draw_bounding_boxes_on_image_array(image, - boxes, - color='red', - thickness=4, - display_str_list_list=()): +def draw_bounding_boxes_on_image_array( + image, boxes, color="red", thickness=4, display_str_list_list=() +): """Draws bounding boxes on image (numpy array). Args: @@ -228,16 +351,15 @@ def draw_bounding_boxes_on_image_array(image, ValueError: if boxes is not a [N, 4] array """ image_pil = Image.fromarray(image) - draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, - display_str_list_list) + draw_bounding_boxes_on_image( + image_pil, boxes, color, thickness, display_str_list_list + ) np.copyto(image, np.array(image_pil)) -def draw_bounding_boxes_on_image(image, - boxes, - color='red', - thickness=4, - display_str_list_list=()): +def draw_bounding_boxes_on_image( + image, boxes, color="red", thickness=4, display_str_list_list=() +): """Draws bounding boxes on image. 
Args: @@ -259,23 +381,32 @@ def draw_bounding_boxes_on_image(image, if not boxes_shape: return if len(boxes_shape) != 2 or boxes_shape[1] != 4: - raise ValueError('Input must be of size [N, 4]') + raise ValueError("Input must be of size [N, 4]") for i in range(boxes_shape[0]): display_str_list = () if display_str_list_list: display_str_list = display_str_list_list[i] - draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2], - boxes[i, 3], color, thickness, - display_str_list) + draw_bounding_box_on_image( + image, + boxes[i, 0], + boxes[i, 1], + boxes[i, 2], + boxes[i, 3], + color, + thickness, + display_str_list, + ) def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs): return visualize_boxes_and_labels_on_image_array( - image, boxes, classes, scores, category_index=category_index, **kwargs) + image, boxes, classes, scores, category_index=category_index, **kwargs + ) -def _visualize_boxes_and_masks(image, boxes, classes, scores, masks, - category_index, **kwargs): +def _visualize_boxes_and_masks( + image, boxes, classes, scores, masks, category_index, **kwargs +): return visualize_boxes_and_labels_on_image_array( image, boxes, @@ -283,11 +414,13 @@ def _visualize_boxes_and_masks(image, boxes, classes, scores, masks, scores, category_index=category_index, instance_masks=masks, - **kwargs) + **kwargs + ) -def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints, - category_index, **kwargs): +def _visualize_boxes_and_keypoints( + image, boxes, classes, scores, keypoints, category_index, **kwargs +): return visualize_boxes_and_labels_on_image_array( image, boxes, @@ -295,12 +428,13 @@ def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints, scores, category_index=category_index, keypoints=keypoints, - **kwargs) + **kwargs + ) def _visualize_boxes_and_masks_and_keypoints( - image, boxes, classes, scores, masks, keypoints, category_index, - **kwargs): + image, boxes, classes, scores, masks, keypoints, category_index, **kwargs +): return visualize_boxes_and_labels_on_image_array( image, boxes, @@ -309,18 +443,21 @@ def _visualize_boxes_and_masks_and_keypoints( category_index=category_index, instance_masks=masks, keypoints=keypoints, - **kwargs) - - -def draw_bounding_boxes_on_image_tensors(images, - boxes, - classes, - scores, - category_index, - instance_masks=None, - keypoints=None, - max_boxes_to_draw=20, - min_score_thresh=0.2): + **kwargs + ) + + +def draw_bounding_boxes_on_image_tensors( + images, + boxes, + classes, + scores, + category_index, + instance_masks=None, + keypoints=None, + max_boxes_to_draw=20, + min_score_thresh=0.2, +): """Draws bounding boxes, masks, and keypoints on batch of image tensors. Args: @@ -342,52 +479,56 @@ def draw_bounding_boxes_on_image_tensors(images, 4D image tensor of type uint8, with boxes drawn on top. 
""" visualization_keyword_args = { - 'use_normalized_coordinates': True, - 'max_boxes_to_draw': max_boxes_to_draw, - 'min_score_thresh': min_score_thresh, - 'agnostic_mode': False, - 'line_thickness': 4 + "use_normalized_coordinates": True, + "max_boxes_to_draw": max_boxes_to_draw, + "min_score_thresh": min_score_thresh, + "agnostic_mode": False, + "line_thickness": 4, } if instance_masks is not None and keypoints is None: visualize_boxes_fn = functools.partial( _visualize_boxes_and_masks, category_index=category_index, - **visualization_keyword_args) + **visualization_keyword_args + ) elems = [images, boxes, classes, scores, instance_masks] elif instance_masks is None and keypoints is not None: visualize_boxes_fn = functools.partial( _visualize_boxes_and_keypoints, category_index=category_index, - **visualization_keyword_args) + **visualization_keyword_args + ) elems = [images, boxes, classes, scores, keypoints] elif instance_masks is not None and keypoints is not None: visualize_boxes_fn = functools.partial( _visualize_boxes_and_masks_and_keypoints, category_index=category_index, - **visualization_keyword_args) + **visualization_keyword_args + ) elems = [images, boxes, classes, scores, instance_masks, keypoints] else: visualize_boxes_fn = functools.partial( _visualize_boxes, category_index=category_index, - **visualization_keyword_args) + **visualization_keyword_args + ) elems = [images, boxes, classes, scores] def draw_boxes(image_and_detections): """Draws boxes on image.""" - image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections, - tf.uint8) + image_with_boxes = tf.py_func( + visualize_boxes_fn, image_and_detections, tf.uint8 + ) return image_with_boxes images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False) return images -def draw_side_by_side_evaluation_image(eval_dict, - category_index, - max_boxes_to_draw=20, - min_score_thresh=0.2): +def draw_side_by_side_evaluation_image( + eval_dict, category_index, max_boxes_to_draw=20, min_score_thresh=0.2 +): """Creates a side-by-side image with detections and groundtruth. 
Bounding boxes (and instance masks, if available) are visualized on both @@ -410,18 +551,21 @@ def draw_side_by_side_evaluation_image(eval_dict, if detection_fields.detection_masks in eval_dict: instance_masks = tf.cast( tf.expand_dims(eval_dict[detection_fields.detection_masks], axis=0), - tf.uint8) + tf.uint8, + ) keypoints = None if detection_fields.detection_keypoints in eval_dict: keypoints = tf.expand_dims( - eval_dict[detection_fields.detection_keypoints], axis=0) + eval_dict[detection_fields.detection_keypoints], axis=0 + ) groundtruth_instance_masks = None if input_data_fields.groundtruth_instance_masks in eval_dict: groundtruth_instance_masks = tf.cast( tf.expand_dims( - eval_dict[input_data_fields.groundtruth_instance_masks], - axis=0), - tf.uint8) + eval_dict[input_data_fields.groundtruth_instance_masks], axis=0 + ), + tf.uint8, + ) images_with_detections = draw_bounding_boxes_on_image_tensors( eval_dict[input_data_fields.original_image], tf.expand_dims(eval_dict[detection_fields.detection_boxes], axis=0), @@ -431,30 +575,30 @@ def draw_side_by_side_evaluation_image(eval_dict, instance_masks=instance_masks, keypoints=keypoints, max_boxes_to_draw=max_boxes_to_draw, - min_score_thresh=min_score_thresh) + min_score_thresh=min_score_thresh, + ) images_with_groundtruth = draw_bounding_boxes_on_image_tensors( eval_dict[input_data_fields.original_image], tf.expand_dims(eval_dict[input_data_fields.groundtruth_boxes], axis=0), - tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], - axis=0), + tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], axis=0), tf.expand_dims( tf.ones_like( - eval_dict[input_data_fields.groundtruth_classes], - dtype=tf.float32), - axis=0), + eval_dict[input_data_fields.groundtruth_classes], dtype=tf.float32 + ), + axis=0, + ), category_index, instance_masks=groundtruth_instance_masks, keypoints=None, max_boxes_to_draw=None, - min_score_thresh=0.0) + min_score_thresh=0.0, + ) return tf.concat([images_with_detections, images_with_groundtruth], axis=2) -def draw_keypoints_on_image_array(image, - keypoints, - color='red', - radius=2, - use_normalized_coordinates=True): +def draw_keypoints_on_image_array( + image, keypoints, color="red", radius=2, use_normalized_coordinates=True +): """Draws keypoints on an image (numpy array). Args: @@ -465,17 +609,16 @@ def draw_keypoints_on_image_array(image, use_normalized_coordinates: if True (default), treat keypoint values as relative to the image. Otherwise treat them as absolute. """ - image_pil = Image.fromarray(np.uint8(image)).convert('RGB') - draw_keypoints_on_image(image_pil, keypoints, color, radius, - use_normalized_coordinates) + image_pil = Image.fromarray(np.uint8(image)).convert("RGB") + draw_keypoints_on_image( + image_pil, keypoints, color, radius, use_normalized_coordinates + ) np.copyto(image, np.array(image_pil)) -def draw_keypoints_on_image(image, - keypoints, - color='red', - radius=2, - use_normalized_coordinates=True): +def draw_keypoints_on_image( + image, keypoints, color="red", radius=2, use_normalized_coordinates=True +): """Draws keypoints on an image. 
Args: @@ -494,12 +637,17 @@ def draw_keypoints_on_image(image, keypoints_x = tuple([im_width * x for x in keypoints_x]) keypoints_y = tuple([im_height * y for y in keypoints_y]) for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): - draw.ellipse([(keypoint_x - radius, keypoint_y - radius), - (keypoint_x + radius, keypoint_y + radius)], - outline=color, fill=color) + draw.ellipse( + [ + (keypoint_x - radius, keypoint_y - radius), + (keypoint_x + radius, keypoint_y + radius), + ], + outline=color, + fill=color, + ) -def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): +def draw_mask_on_image_array(image, mask, color="red", alpha=0.4): """Draws mask on an image. Args: @@ -513,42 +661,46 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): ValueError: On incorrect data type for image or masks. """ if image.dtype != np.uint8: - raise ValueError('`image` not of type np.uint8') + raise ValueError("`image` not of type np.uint8") if mask.dtype != np.uint8: - raise ValueError('`mask` not of type np.uint8') + raise ValueError("`mask` not of type np.uint8") if np.any(np.logical_and(mask != 1, mask != 0)): - raise ValueError('`mask` elements should be in [0, 1]') + raise ValueError("`mask` elements should be in [0, 1]") if image.shape[:2] != mask.shape: - raise ValueError('The image has spatial dimensions %s but the mask has ' - 'dimensions %s' % (image.shape[:2], mask.shape)) + raise ValueError( + "The image has spatial dimensions %s but the mask has " + "dimensions %s" % (image.shape[:2], mask.shape) + ) rgb = ImageColor.getrgb(color) pil_image = Image.fromarray(image) - solid_color = np.expand_dims( - np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) - pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') - pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert('L') + solid_color = np.expand_dims(np.ones_like(mask), axis=2) * np.reshape( + list(rgb), [1, 1, 3] + ) + pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert("RGBA") + pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert("L") pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) - np.copyto(image, np.array(pil_image.convert('RGB'))) + np.copyto(image, np.array(pil_image.convert("RGB"))) def visualize_boxes_and_labels_on_image_array( - image, - boxes, - classes, - scores, - category_index, - instance_masks=None, - instance_boundaries=None, - keypoints=None, - use_normalized_coordinates=False, - max_boxes_to_draw=20, - min_score_thresh=.5, - agnostic_mode=False, - line_thickness=4, - groundtruth_box_visualization_color='black', - skip_scores=False, - skip_labels=False): + image, + boxes, + classes, + scores, + category_index, + instance_masks=None, + instance_boundaries=None, + keypoints=None, + use_normalized_coordinates=False, + max_boxes_to_draw=20, + min_score_thresh=0.5, + agnostic_mode=False, + line_thickness=4, + groundtruth_box_visualization_color="black", + skip_scores=False, + skip_labels=False, +): """Overlay labeled boxes on an image with formatted scores and label names. 
This function groups boxes that correspond to the same location @@ -610,42 +762,37 @@ def visualize_boxes_and_labels_on_image_array( if scores is None: box_to_color_map[box] = groundtruth_box_visualization_color else: - display_str = '' + display_str = "" if not skip_labels: if not agnostic_mode: if classes[i] in category_index.keys(): - class_name = category_index[classes[i]]['name'] + class_name = category_index[classes[i]]["name"] else: - class_name = 'N/A' + class_name = "N/A" display_str = str(class_name) if not skip_scores: if not display_str: - display_str = '{}%'.format(int(100 * scores[i])) + display_str = "{}%".format(int(100 * scores[i])) else: - display_str = '{}: {}%'.format(display_str, - int(100 * scores[i])) + display_str = "{}: {}%".format( + display_str, int(100 * scores[i]) + ) box_to_display_str_map[box].append(display_str) if agnostic_mode: - box_to_color_map[box] = 'DarkOrange' + box_to_color_map[box] = "DarkOrange" else: box_to_color_map[box] = STANDARD_COLORS[ - classes[i] % len(STANDARD_COLORS)] + classes[i] % len(STANDARD_COLORS) + ] # Draw all boxes onto image. for box, color in box_to_color_map.items(): ymin, xmin, ymax, xmax = box if instance_masks is not None: - draw_mask_on_image_array( - image, - box_to_instance_masks_map[box], - color=color - ) + draw_mask_on_image_array(image, box_to_instance_masks_map[box], color=color) if instance_boundaries is not None: draw_mask_on_image_array( - image, - box_to_instance_boundaries_map[box], - color='red', - alpha=1.0 + image, box_to_instance_boundaries_map[box], color="red", alpha=1.0 ) draw_bounding_box_on_image_array( image, @@ -656,14 +803,16 @@ def visualize_boxes_and_labels_on_image_array( color=color, thickness=line_thickness, display_str_list=box_to_display_str_map[box], - use_normalized_coordinates=use_normalized_coordinates) + use_normalized_coordinates=use_normalized_coordinates, + ) if keypoints is not None: draw_keypoints_on_image_array( image, box_to_keypoints_map[box], color=color, radius=line_thickness / 2, - use_normalized_coordinates=use_normalized_coordinates) + use_normalized_coordinates=use_normalized_coordinates, + ) return image @@ -685,17 +834,18 @@ def cdf_plot(values): sorted_values = np.sort(normalized_values) cumulative_values = np.cumsum(sorted_values) fraction_of_examples = ( - np.arange(cumulative_values.size, dtype=np.float32) - / cumulative_values.size) + np.arange(cumulative_values.size, dtype=np.float32) / cumulative_values.size + ) fig = plt.figure(frameon=False) - ax = fig.add_subplot('111') + ax = fig.add_subplot("111") ax.plot(fraction_of_examples, cumulative_values) - ax.set_ylabel('cumulative normalized values') - ax.set_xlabel('fraction of examples') + ax.set_ylabel("cumulative normalized values") + ax.set_xlabel("fraction of examples") fig.canvas.draw() width, height = fig.get_size_inches() * fig.get_dpi() - image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( - 1, int(height), int(width), 3) + image = np.fromstring(fig.canvas.tostring_rgb(), dtype="uint8").reshape( + 1, int(height), int(width), 3 + ) return image cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) @@ -716,16 +866,16 @@ def add_hist_image_summary(values, bins, name): def hist_plot(values, bins): """Numpy function to plot hist.""" fig = plt.figure(frameon=False) - ax = fig.add_subplot('111') + ax = fig.add_subplot("111") y, x = np.histogram(values, bins=bins) ax.plot(x[:-1], y) - ax.set_ylabel('count') - ax.set_xlabel('value') + ax.set_ylabel("count") + ax.set_xlabel("value") 
fig.canvas.draw() width, height = fig.get_size_inches() * fig.get_dpi() - image = np.fromstring( - fig.canvas.tostring_rgb(), dtype='uint8').reshape( - 1, int(height), int(width), 3) + image = np.fromstring(fig.canvas.tostring_rgb(), dtype="uint8").reshape( + 1, int(height), int(width), 3 + ) return image hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) diff --git a/robotoff/models.py b/robotoff/models.py index 7f66eb146b..dc2dd6cc6e 100644 --- a/robotoff/models.py +++ b/robotoff/models.py @@ -1,16 +1,18 @@ -from typing import Iterable, Dict +from typing import Dict, Iterable import peewee -from playhouse.postgres_ext import (PostgresqlExtDatabase, - BinaryJSONField) +from playhouse.postgres_ext import BinaryJSONField, PostgresqlExtDatabase from robotoff import settings from robotoff.utils.types import JSONType -db = PostgresqlExtDatabase(settings.DB_NAME, - user=settings.DB_USER, - password=settings.DB_PASSWORD, - host=settings.DB_HOST, port=5432) +db = PostgresqlExtDatabase( + settings.DB_NAME, + user=settings.DB_USER, + password=settings.DB_PASSWORD, + host=settings.DB_HOST, + port=5432, +) def batch_insert(model_cls, data: Iterable[Dict], batch_size=100) -> int: @@ -54,10 +56,10 @@ class ProductInsight(BaseModel): def serialize(self) -> JSONType: return { - 'id': str(self.id), - 'type': self.type, - 'barcode': self.barcode, - 'countries': self.countries, + "id": str(self.id), + "type": self.type, + "barcode": self.barcode, + "countries": self.countries, **self.data, } diff --git a/robotoff/off.py b/robotoff/off.py index 336db3e18d..b10109d5e6 100644 --- a/robotoff/off.py +++ b/robotoff/off.py @@ -1,5 +1,5 @@ import re -from typing import List, Dict, Optional +from typing import Dict, List, Optional import requests @@ -12,8 +12,8 @@ DRY_POST_URL = "https://world.openfoodfacts.net/cgi/product_jqm2.pl" AUTH = ("roboto-app", settings.OFF_PASSWORD) AUTH_DICT = { - 'user_id': AUTH[0], - 'password': AUTH[1], + "user_id": AUTH[0], + "password": AUTH[1], } API_URL = "https://world.openfoodfacts.org/api/v0" @@ -25,7 +25,7 @@ BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$") USER_AGENT_HEADERS = { - 'User-Agent': settings.ROBOTOFF_USER_AGENT, + "User-Agent": settings.ROBOTOFF_USER_AGENT, } @@ -43,27 +43,27 @@ def split_barcode(barcode: str) -> List[str]: def generate_json_ocr_url(barcode: str, image_name: str) -> str: splitted_barcode = split_barcode(barcode) - path = "/{}/{}.json".format('/'.join(splitted_barcode), image_name) + path = "/{}/{}.json".format("/".join(splitted_barcode), image_name) return settings.OFF_IMAGE_BASE_URL + path def generate_image_url(barcode: str, image_name: str) -> str: splitted_barcode = split_barcode(barcode) - path = "/{}/{}.jpg".format('/'.join(splitted_barcode), image_name) + path = "/{}/{}.jpg".format("/".join(splitted_barcode), image_name) return settings.OFF_IMAGE_BASE_URL + path def product_exists(barcode: str) -> bool: - return get_product(barcode, ['code']) is not None + return get_product(barcode, ["code"]) is not None def is_valid_image(barcode: str, image_id: str) -> bool: - product = get_product(barcode, fields=['images']) + product = get_product(barcode, fields=["images"]) if product is None: return False - images = product.get('images', {}) + images = product.get("images", {}) return image_id in images @@ -76,7 +76,7 @@ def get_product(barcode: str, fields: List[str] = None) -> Optional[Dict]: # requests escape comma in URLs, as expected, but openfoodfacts server # does not recognize escaped commas. 
# See https://github.com/openfoodfacts/openfoodfacts-server/issues/1607 - url += '?fields={}'.format(','.join(fields)) + url += "?fields={}".format(",".join(fields)) r = http_session.get(url, headers=USER_AGENT_HEADERS) @@ -85,41 +85,37 @@ def get_product(barcode: str, fields: List[str] = None) -> Optional[Dict]: data = r.json() - if data['status_verbose'] != 'product found': + if data["status_verbose"] != "product found": return None - return data['product'] + return data["product"] def add_category(barcode: str, category: str, dry=False): params = { - 'code': barcode, - 'add_categories': category, - 'comment': "Adding category '{}' " - "(automated edit)".format(category), - **AUTH_DICT + "code": barcode, + "add_categories": category, + "comment": "Adding category '{}' " "(automated edit)".format(category), + **AUTH_DICT, } update_product(params, dry=dry) def update_quantity(barcode: str, quantity: str, dry=False): params = { - 'code': barcode, - 'quantity': quantity, - 'comment': "Updating quantity to '{}' " - "(automated edit)".format(quantity), + "code": barcode, + "quantity": quantity, + "comment": "Updating quantity to '{}' " "(automated edit)".format(quantity), **AUTH_DICT, } update_product(params, dry=dry) -def save_ingredients(barcode: str, ingredient_text: str, - lang: str = None, dry=False): - ingredient_key = ('ingredients_text' if lang is None - else f'ingredients_{lang}_text') +def save_ingredients(barcode: str, ingredient_text: str, lang: str = None, dry=False): + ingredient_key = "ingredients_text" if lang is None else f"ingredients_{lang}_text" params = { - 'code': barcode, - 'comment': "Ingredient spellcheck correction (automated edit)", + "code": barcode, + "comment": "Ingredient spellcheck correction (automated edit)", ingredient_key: ingredient_text, **AUTH_DICT, } @@ -127,12 +123,12 @@ def save_ingredients(barcode: str, ingredient_text: str, def update_emb_codes(barcode: str, emb_codes: List[str], dry=False): - emb_codes_str = ','.join(emb_codes) + emb_codes_str = ",".join(emb_codes) params = { - 'code': barcode, - 'emb_codes': emb_codes_str, - 'comment': "Adding packager code (automated edit)", + "code": barcode, + "emb_codes": emb_codes_str, + "comment": "Adding packager code (automated edit)", **AUTH_DICT, } update_product(params, dry=dry) @@ -140,10 +136,10 @@ def update_emb_codes(barcode: str, emb_codes: List[str], dry=False): def update_expiration_date(barcode: str, expiration_date: str, dry=False): params = { - 'code': barcode, - 'expiration_date': expiration_date, - 'comment': "Adding expiration date '{}' " - "(automated edit)".format(expiration_date), + "code": barcode, + "expiration_date": expiration_date, + "comment": "Adding expiration date '{}' " + "(automated edit)".format(expiration_date), **AUTH_DICT, } update_product(params, dry=dry) @@ -151,9 +147,9 @@ def update_expiration_date(barcode: str, expiration_date: str, dry=False): def add_label_tag(barcode: str, label_tag: str, dry=False): params = { - 'code': barcode, - 'add_labels': label_tag, - 'comment': "Adding label tag '{}' (automated edit)".format(label_tag), + "code": barcode, + "add_labels": label_tag, + "comment": "Adding label tag '{}' (automated edit)".format(label_tag), **AUTH_DICT, } update_product(params, dry=dry) @@ -161,9 +157,9 @@ def add_label_tag(barcode: str, label_tag: str, dry=False): def add_brand(barcode: str, brand: str, dry=False): params = { - 'code': barcode, - 'add_brands': brand, - 'comment': "Adding brand '{}' (automated edit)".format(brand), + "code": barcode, + 
"add_brands": brand, + "comment": "Adding brand '{}' (automated edit)".format(brand), **AUTH_DICT, } update_product(params, dry=dry) @@ -171,9 +167,9 @@ def add_brand(barcode: str, brand: str, dry=False): def add_store(barcode: str, store: str, dry=False): params = { - 'code': barcode, - 'add_stores': store, - 'comment': "Adding store '{}' (automated edit)".format(store), + "code": barcode, + "add_stores": store, + "comment": "Adding store '{}' (automated edit)".format(store), **AUTH_DICT, } update_product(params, dry=dry) @@ -181,19 +177,16 @@ def add_store(barcode: str, store: str, dry=False): def update_product(params: Dict, dry=False): if dry: - r = http_session.get(DRY_POST_URL, params=params, - auth=('off', 'off'), - headers=USER_AGENT_HEADERS) + r = http_session.get( + DRY_POST_URL, params=params, auth=("off", "off"), headers=USER_AGENT_HEADERS + ) else: - r = http_session.get(POST_URL, params=params, - headers=USER_AGENT_HEADERS) + r = http_session.get(POST_URL, params=params, headers=USER_AGENT_HEADERS) r.raise_for_status() json = r.json() - status = json.get('status_verbose') + status = json.get("status_verbose") if status != "fields saved": - logger.warn( - "Unexpected status during product update: {}".format( - status)) + logger.warn("Unexpected status during product update: {}".format(status)) diff --git a/robotoff/products.py b/robotoff/products.py index f57787def7..7c1a0e59fe 100644 --- a/robotoff/products.py +++ b/robotoff/products.py @@ -5,54 +5,57 @@ import pathlib import shutil import tempfile -from typing import List, Iterable, Dict, Optional, Iterator +from typing import Dict, Iterable, Iterator, List, Optional import requests -from robotoff.utils import jsonl_iter, gzip_jsonl_iter, get_logger from robotoff import settings +from robotoff.utils import get_logger, gzip_jsonl_iter, jsonl_iter from robotoff.utils.cache import CachedStore from robotoff.utils.types import JSONType logger = get_logger(__name__) -def minify_product_dataset(dataset_path: pathlib.Path, - output_path: pathlib.Path): - if dataset_path.suffix == '.gz': +def minify_product_dataset(dataset_path: pathlib.Path, output_path: pathlib.Path): + if dataset_path.suffix == ".gz": jsonl_iter_func = gzip_jsonl_iter else: jsonl_iter_func = jsonl_iter - with gzip.open(output_path, 'wt') as output_: + with gzip.open(output_path, "wt") as output_: for item in jsonl_iter_func(dataset_path): available_fields = Product.get_fields() - minified_item = dict(((field, value) - for (field, value) in item.items() - if field in available_fields)) - output_.write(json.dumps(minified_item) + '\n') + minified_item = dict( + ( + (field, value) + for (field, value) in item.items() + if field in available_fields + ) + ) + output_.write(json.dumps(minified_item) + "\n") def get_product_dataset_etag() -> Optional[str]: if not settings.JSONL_DATASET_ETAG_PATH.is_file(): return None - with open(settings.JSONL_DATASET_ETAG_PATH, 'r') as f: + with open(settings.JSONL_DATASET_ETAG_PATH, "r") as f: return f.readline() def save_product_dataset_etag(etag: str): - with open(settings.JSONL_DATASET_ETAG_PATH, 'w') as f: + with open(settings.JSONL_DATASET_ETAG_PATH, "w") as f: return f.write(etag) def fetch_dataset(): with tempfile.TemporaryDirectory() as tmp_dir: output_dir = pathlib.Path(tmp_dir) - output_path = output_dir / 'products.jsonl.gz' + output_path = output_dir / "products.jsonl.gz" etag = download_dataset(output_path) - minify_path = output_dir / 'products-min.jsonl.gz' + minify_path = output_dir / "products-min.jsonl.gz" 
logger.info("Minifying product JSONL") minify_product_dataset(output_path, minify_path) @@ -70,7 +73,7 @@ def has_dataset_changed() -> bool: if etag is not None: r = requests.head(settings.JSONL_DATASET_URL) - current_etag = r.headers.get('ETag', '').strip("'\"") + current_etag = r.headers.get("ETag", "").strip("'\"") if current_etag == etag: logger.info("Dataset ETag has not changed") @@ -80,14 +83,13 @@ def has_dataset_changed() -> bool: def download_dataset(output_path: os.PathLike) -> str: - r = requests.get(settings.JSONL_DATASET_URL, - stream=True) - current_etag = r.headers.get('ETag', '').strip("'\"") + r = requests.get(settings.JSONL_DATASET_URL, stream=True) + current_etag = r.headers.get("ETag", "").strip("'\"") logger.info("Dataset has changed, downloading file") logger.debug("Saving temporary file in {}".format(output_path)) - with open(output_path, 'wb') as f: + with open(output_path, "wb") as f: shutil.copyfileobj(r.raw, f) return current_etag @@ -100,53 +102,69 @@ def __init__(self, iterator: Iterable[JSONType]): def __iter__(self) -> Iterator[JSONType]: yield from self.iterator - def filter_by_country_tag(self, country_tag: str) -> 'ProductStream': - filtered = (product for product in self.iterator - if country_tag in (product.get('countries_tags') or [])) + def filter_by_country_tag(self, country_tag: str) -> "ProductStream": + filtered = ( + product + for product in self.iterator + if country_tag in (product.get("countries_tags") or []) + ) return ProductStream(filtered) - def filter_by_state_tag(self, state_tag: str) -> 'ProductStream': - filtered = (product for product in self.iterator - if state_tag in (product.get('states_tags') or [])) + def filter_by_state_tag(self, state_tag: str) -> "ProductStream": + filtered = ( + product + for product in self.iterator + if state_tag in (product.get("states_tags") or []) + ) return ProductStream(filtered) - def filter_nonempty_text_field(self, field: str) -> 'ProductStream': - filtered = (product for product in self.iterator - if (product.get(field) or "") != "") + def filter_nonempty_text_field(self, field: str) -> "ProductStream": + filtered = ( + product for product in self.iterator if (product.get(field) or "") != "" + ) return ProductStream(filtered) - def filter_empty_text_field(self, field: str) -> 'ProductStream': - filtered = (product for product in self.iterator - if not (product.get(field) or "") != "") + def filter_empty_text_field(self, field: str) -> "ProductStream": + filtered = ( + product for product in self.iterator if not (product.get(field) or "") != "" + ) return ProductStream(filtered) - def filter_nonempty_tag_field(self, field: str) -> 'ProductStream': - filtered = (product for product in self.iterator - if (product.get(field) or [])) + def filter_nonempty_tag_field(self, field: str) -> "ProductStream": + filtered = (product for product in self.iterator if (product.get(field) or [])) return ProductStream(filtered) - def filter_empty_tag_field(self, field: str) -> 'ProductStream': - filtered = (product for product in self.iterator - if not (product.get(field) or [])) + def filter_empty_tag_field(self, field: str) -> "ProductStream": + filtered = ( + product for product in self.iterator if not (product.get(field) or []) + ) return ProductStream(filtered) - def filter_by_modified_datetime(self, - from_t: Optional[datetime.datetime] = None, - to_t: Optional[datetime.datetime] = None): + def filter_by_modified_datetime( + self, + from_t: Optional[datetime.datetime] = None, + to_t: Optional[datetime.datetime] 
= None, + ): if from_t is None and to_t is None: raise ValueError("one of `from_t` or `to_t` must be provided") if from_t: from_timestamp = from_t.timestamp() - filtered = (product for product in self.iterator - if 'last_modified_t' in product and - product['last_modified_t'] >= from_timestamp) + filtered = ( + product + for product in self.iterator + if "last_modified_t" in product + and product["last_modified_t"] >= from_timestamp + ) elif to_t: to_timestamp = to_t.timestamp() - filtered = (product for product in self.iterator - if 'last_modified_t' in product and - product['last_modified_t'] <= to_timestamp) + filtered = ( + product + for product in self.iterator + if "last_modified_t" in product + and product["last_modified_t"] <= to_timestamp + ) return ProductStream(filtered) @@ -160,7 +178,7 @@ def take(self, count: int): def iter(self) -> Iterable[JSONType]: return iter(self) - def iter_product(self) -> Iterable['Product']: + def iter_product(self) -> Iterable["Product"]: for item in self: yield Product(item) @@ -189,33 +207,42 @@ def load(cls): class Product: """Product class.""" - __slots__ = ('barcode', 'countries_tags', 'categories_tags', - 'emb_codes_tags', 'labels_tags', 'quantity', 'expiration_date', - 'brands_tags', 'stores_tags') + + __slots__ = ( + "barcode", + "countries_tags", + "categories_tags", + "emb_codes_tags", + "labels_tags", + "quantity", + "expiration_date", + "brands_tags", + "stores_tags", + ) def __init__(self, product: JSONType): - self.barcode = product.get('code') - self.countries_tags = product.get('countries_tags') or [] - self.categories_tags = product.get('categories_tags') or [] - self.emb_codes_tags = product.get('emb_codes_tags') or [] - self.labels_tags = product.get('labels_tags') or [] - self.quantity = product.get('quantity') or None - self.expiration_date = product.get('expiration_date') or None - self.brands_tags = product.get('brands_tags') or [] - self.stores_tags = product.get('stores_tags') or [] + self.barcode = product.get("code") + self.countries_tags = product.get("countries_tags") or [] + self.categories_tags = product.get("categories_tags") or [] + self.emb_codes_tags = product.get("emb_codes_tags") or [] + self.labels_tags = product.get("labels_tags") or [] + self.quantity = product.get("quantity") or None + self.expiration_date = product.get("expiration_date") or None + self.brands_tags = product.get("brands_tags") or [] + self.stores_tags = product.get("stores_tags") or [] @staticmethod def get_fields(): return { - 'code', - 'countries_tags', - 'categories_tags', - 'emb_codes_tags', - 'labels_tags', - 'quantity', - 'expiration_date', - 'brands_tags', - 'stores_tags', + "code", + "countries_tags", + "categories_tags", + "emb_codes_tags", + "labels_tags", + "quantity", + "expiration_date", + "brands_tags", + "stores_tags", } @@ -223,7 +250,7 @@ class ProductStore: def __init__(self): self.store: Dict[str, Product] = {} - def load(self, path: str, reset: bool=True): + def load(self, path: str, reset: bool = True): logger.info("Loading product store") ds = ProductDataset(path) stream = ds.stream() diff --git a/robotoff/scheduler.py b/robotoff/scheduler.py index 0875c87925..e5f08b911f 100644 --- a/robotoff/scheduler.py +++ b/robotoff/scheduler.py @@ -2,25 +2,32 @@ import os from typing import Dict, Optional +import sentry_sdk from apscheduler.events import EVENT_JOB_ERROR -from apscheduler.jobstores.memory import MemoryJobStore from apscheduler.executors.pool import ThreadPoolExecutor +from apscheduler.jobstores.memory import 
MemoryJobStore from apscheduler.schedulers.blocking import BlockingScheduler +from sentry_sdk import capture_exception -from robotoff import slack, settings +from robotoff import settings, slack from robotoff.elasticsearch.category.predict import predict_from_dataset -from robotoff.insights.annotate import InsightAnnotatorFactory, UPDATED_ANNOTATION_RESULT +from robotoff.insights.annotate import ( + UPDATED_ANNOTATION_RESULT, + InsightAnnotatorFactory, +) from robotoff.insights.importer import CategoryImporter -from robotoff.insights.validator import InsightValidator, \ - InsightValidatorFactory +from robotoff.insights.validator import InsightValidator, InsightValidatorFactory from robotoff.models import ProductInsight, db -from robotoff.products import has_dataset_changed, fetch_dataset, \ - CACHED_PRODUCT_STORE, Product, ProductStore, ProductDataset +from robotoff.products import ( + CACHED_PRODUCT_STORE, + Product, + ProductDataset, + ProductStore, + fetch_dataset, + has_dataset_changed, +) from robotoff.utils import get_logger -import sentry_sdk -from sentry_sdk import capture_exception - if settings.SENTRY_DSN: sentry_sdk.init(settings.SENTRY_DSN) @@ -31,18 +38,27 @@ def process_insights(): processed = 0 with db: - for insight in (ProductInsight.select() - .where(ProductInsight.annotation.is_null(), - ProductInsight.process_after.is_null(False), - ProductInsight.process_after <= datetime.datetime.utcnow()) - .iterator()): + for insight in ( + ProductInsight.select() + .where( + ProductInsight.annotation.is_null(), + ProductInsight.process_after.is_null(False), + ProductInsight.process_after <= datetime.datetime.utcnow(), + ) + .iterator() + ): annotator = InsightAnnotatorFactory.get(insight.type) - logger.info("Annotating insight {} (product: {})".format(insight.id, insight.barcode)) + logger.info( + "Annotating insight {} (product: {})".format( + insight.id, insight.barcode + ) + ) annotation_result = annotator.annotate(insight, 1, update=True) processed += 1 - if (annotation_result == UPDATED_ANNOTATION_RESULT and - insight.data.get('notify', False)): + if annotation_result == UPDATED_ANNOTATION_RESULT and insight.data.get( + "notify", False + ): slack.notify_automatic_processing(insight) logger.info("{} insights processed".format(processed)) @@ -54,47 +70,57 @@ def refresh_insights(): product_store = CACHED_PRODUCT_STORE.get() datetime_threshold = datetime.datetime.utcnow().replace( - hour=0, minute=0, second=0, microsecond=0) + hour=0, minute=0, second=0, microsecond=0 + ) dataset_datetime = datetime.datetime.fromtimestamp( - os.path.getmtime(settings.JSONL_MIN_DATASET_PATH)) + os.path.getmtime(settings.JSONL_MIN_DATASET_PATH) + ) if dataset_datetime.date() != datetime_threshold.date(): - logger.warn("Dataset version is not up to date, aborting insight " - "removal job") + logger.warn( + "Dataset version is not up to date, aborting insight " "removal job" + ) return validators: Dict[str, InsightValidator] = {} with db: with db.atomic(): - for insight in (ProductInsight.select() - .where(ProductInsight.annotation.is_null(), - ProductInsight.timestamp <= datetime_threshold) - .iterator()): + for insight in ( + ProductInsight.select() + .where( + ProductInsight.annotation.is_null(), + ProductInsight.timestamp <= datetime_threshold, + ) + .iterator() + ): product: Product = product_store[insight.barcode] if product is None: # Product has been deleted from OFF - logger.info("Product with barcode {} deleted" - "".format(insight.barcode)) + logger.info( + "Product with barcode {} 
deleted" "".format(insight.barcode) + ) deleted += 1 insight.delete_instance() else: if insight.type not in validators: validators[insight.type] = InsightValidatorFactory.create( - insight.type, product_store) + insight.type, product_store + ) validator = validators[insight.type] insight_deleted = delete_invalid_insight(insight, validator) if insight_deleted: deleted += 1 - logger.info("invalid insight {} (type: {}), deleting..." - "".format(insight.id, insight.type)) + logger.info( + "invalid insight {} (type: {}), deleting..." + "".format(insight.id, insight.type) + ) continue - insight_updated = update_insight_attributes(product, - insight) + insight_updated = update_insight_attributes(product, insight) if insight_updated: updated += 1 @@ -103,20 +129,23 @@ def refresh_insights(): logger.info("{} insights updated".format(updated)) -def update_insight_attributes(product: Product, insight: ProductInsight) \ - -> bool: +def update_insight_attributes(product: Product, insight: ProductInsight) -> bool: to_update = False if insight.brands != product.brands_tags: - logger.info("Updating brand {} -> {} ({})".format( - insight.brands, product.brands_tags, - product.barcode)) + logger.info( + "Updating brand {} -> {} ({})".format( + insight.brands, product.brands_tags, product.barcode + ) + ) to_update = True insight.brands = product.brands_tags if insight.countries != product.countries_tags: - logger.info("Updating countries {} -> {} ({})".format( - insight.countries, product.countries_tags, - product.barcode)) + logger.info( + "Updating countries {} -> {} ({})".format( + insight.countries, product.countries_tags, product.barcode + ) + ) to_update = True insight.countries = product.countries_tags @@ -126,8 +155,9 @@ def update_insight_attributes(product: Product, insight: ProductInsight) \ return to_update -def delete_invalid_insight(insight: ProductInsight, - validator: Optional[InsightValidator]) -> bool: +def delete_invalid_insight( + insight: ProductInsight, validator: Optional[InsightValidator] +) -> bool: if validator is None: return False @@ -142,14 +172,22 @@ def mark_insights(): marked = 0 with db: with db.atomic(): - for insight in (ProductInsight.select() - .where(ProductInsight.automatic_processing == True, - ProductInsight.process_after.is_null(), - ProductInsight.annotation.is_null()) - .iterator()): - logger.info("Marking insight {} as processable automatically " - "(product: {})".format(insight.id, insight.barcode)) - insight.process_after = datetime.datetime.utcnow() + datetime.timedelta(minutes=10) + for insight in ( + ProductInsight.select() + .where( + ProductInsight.automatic_processing == True, + ProductInsight.process_after.is_null(), + ProductInsight.annotation.is_null(), + ) + .iterator() + ): + logger.info( + "Marking insight {} as processable automatically " + "(product: {})".format(insight.id, insight.barcode) + ) + insight.process_after = datetime.datetime.utcnow() + datetime.timedelta( + minutes=10 + ) insight.save() marked += 1 @@ -171,7 +209,8 @@ def generate_insights(): importer = CategoryImporter(product_store) datetime_threshold = datetime.datetime.utcnow().replace( - hour=0, minute=0, second=0, microsecond=0) - datetime.timedelta(days=1) + hour=0, minute=0, second=0, microsecond=0 + ) - datetime.timedelta(days=1) dataset = ProductDataset(settings.JSONL_DATASET_PATH) category_insights_iter = predict_from_dataset(dataset, datetime_threshold) @@ -188,15 +227,16 @@ def run(): scheduler = BlockingScheduler() scheduler.add_executor(ThreadPoolExecutor(20)) 
scheduler.add_jobstore(MemoryJobStore()) - scheduler.add_job(process_insights, 'interval', minutes=2, max_instances=1, - jitter=20) - scheduler.add_job(mark_insights, 'interval', minutes=2, max_instances=1, - jitter=20) - scheduler.add_job(download_product_dataset, 'cron', day='*', hour='3', - max_instances=1) - scheduler.add_job(refresh_insights, 'cron', day='*', hour='4', - max_instances=1) - scheduler.add_job(generate_insights, 'cron', day='*', hour='4', minute=15, - max_instances=1) + scheduler.add_job( + process_insights, "interval", minutes=2, max_instances=1, jitter=20 + ) + scheduler.add_job(mark_insights, "interval", minutes=2, max_instances=1, jitter=20) + scheduler.add_job( + download_product_dataset, "cron", day="*", hour="3", max_instances=1 + ) + scheduler.add_job(refresh_insights, "cron", day="*", hour="4", max_instances=1) + scheduler.add_job( + generate_insights, "cron", day="*", hour="4", minute=15, max_instances=1 + ) scheduler.add_listener(exception_listener, EVENT_JOB_ERROR) scheduler.start() diff --git a/robotoff/settings.py b/robotoff/settings.py index 938dda26df..ef1bfadfbf 100644 --- a/robotoff/settings.py +++ b/robotoff/settings.py @@ -1,36 +1,42 @@ -from pathlib import Path import os +from pathlib import Path from typing import Tuple PROJECT_DIR = Path(__file__).parent.parent -DATA_DIR = PROJECT_DIR / 'data' -DATASET_DIR = PROJECT_DIR / 'datasets' -I18N_DIR = PROJECT_DIR / 'i18n' -DATASET_PATH = DATASET_DIR / 'en.openfoodfacts.org.products.csv' -JSONL_DATASET_PATH = DATASET_DIR / 'products.jsonl.gz' -JSONL_DATASET_ETAG_PATH = DATASET_DIR / 'products-etag.txt' -JSONL_MIN_DATASET_PATH = DATASET_DIR / 'products-min.jsonl.gz' -JSONL_DATASET_URL = "https://static.openfoodfacts.org/data/openfoodfacts-products.jsonl.gz" +DATA_DIR = PROJECT_DIR / "data" +DATASET_DIR = PROJECT_DIR / "datasets" +I18N_DIR = PROJECT_DIR / "i18n" +DATASET_PATH = DATASET_DIR / "en.openfoodfacts.org.products.csv" +JSONL_DATASET_PATH = DATASET_DIR / "products.jsonl.gz" +JSONL_DATASET_ETAG_PATH = DATASET_DIR / "products-etag.txt" +JSONL_MIN_DATASET_PATH = DATASET_DIR / "products-min.jsonl.gz" +JSONL_DATASET_URL = ( + "https://static.openfoodfacts.org/data/openfoodfacts-products.jsonl.gz" +) -TAXONOMY_CATEGORY_URL = "https://static.openfoodfacts.org/data/taxonomies/categories.json" -TAXONOMY_INGREDIENT_URL = "https://static.openfoodfacts.org/data/taxonomies/ingredients.json" +TAXONOMY_CATEGORY_URL = ( + "https://static.openfoodfacts.org/data/taxonomies/categories.json" +) +TAXONOMY_INGREDIENT_URL = ( + "https://static.openfoodfacts.org/data/taxonomies/ingredients.json" +) TAXONOMY_LABEL_URL = "https://static.openfoodfacts.org/data/taxonomies/labels.json" OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products" OFF_BASE_WEBSITE_URL = "https://world.openfoodfacts.org" OFF_PASSWORD = os.environ.get("OFF_PASSWORD", "") OFF_SERVER_DOMAIN = "api.openfoodfacts.org" -TAXONOMY_DIR = DATA_DIR / 'taxonomies' -TAXONOMY_CATEGORY_PATH = TAXONOMY_DIR / 'categories.json' -TAXONOMY_INGREDIENT_PATH = TAXONOMY_DIR / 'ingredients.json' -TAXONOMY_LABEL_PATH = TAXONOMY_DIR / 'labels.json' +TAXONOMY_DIR = DATA_DIR / "taxonomies" +TAXONOMY_CATEGORY_PATH = TAXONOMY_DIR / "categories.json" +TAXONOMY_INGREDIENT_PATH = TAXONOMY_DIR / "ingredients.json" +TAXONOMY_LABEL_PATH = TAXONOMY_DIR / "labels.json" DB_NAME = os.environ.get("DB_NAME", "postgres") DB_USER = os.environ.get("DB_USER", "postgres") DB_PASSWORD = os.environ.get("DB_PASSWORD", "postgres") DB_HOST = os.environ.get("DB_HOST", "localhost") 
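For context on the settings block above: robotoff/settings.py takes every deployment-specific value from the environment with a hard-coded fallback, so the service is configured without code changes. A minimal sketch of consuming the same DB_* variables to assemble a Postgres connection string — build_db_dsn is a hypothetical helper written for illustration, not a function from this patch:

import os

# Hypothetical helper: reads the same DB_* environment variables that
# robotoff/settings.py reads, and assembles a libpq-style DSN string.
def build_db_dsn() -> str:
    name = os.environ.get("DB_NAME", "postgres")
    user = os.environ.get("DB_USER", "postgres")
    password = os.environ.get("DB_PASSWORD", "postgres")
    host = os.environ.get("DB_HOST", "localhost")
    return "postgresql://{}:{}@{}/{}".format(user, password, host, name)

# With no environment overrides this prints the defaults:
# postgresql://postgres:postgres@localhost/postgres
print(build_db_dsn())
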
-IPC_AUTHKEY = os.environ.get("IPC_AUTHKEY", "IPC").encode('utf-8') +IPC_AUTHKEY = os.environ.get("IPC_AUTHKEY", "IPC").encode("utf-8") IPC_HOST = os.environ.get("IPC_HOST", "localhost") IPC_PORT = int(os.environ.get("IPC_PORT", 6650)) IPC_ADDRESS: Tuple[str, int] = (IPC_HOST, IPC_PORT) @@ -39,10 +45,10 @@ ELASTICSEARCH_HOSTS = os.environ.get("ELASTICSEARCH_HOSTS", "localhost:9200").split(",") ELASTICSEARCH_TYPE = "document" -ELASTICSEARCH_CATEGORY_INDEX = 'category' -ELASTICSEARCH_PRODUCT_INDEX = 'product' +ELASTICSEARCH_CATEGORY_INDEX = "category" +ELASTICSEARCH_PRODUCT_INDEX = "product" -SLACK_TOKEN = os.environ.get('SLACK_TOKEN', "") +SLACK_TOKEN = os.environ.get("SLACK_TOKEN", "") SLACK_OFF_TEST_CHANNEL = "CGLCKGVHS" SLACK_OFF_ROBOTOFF_ALERT_CHANNEL = "CGKPALRCG" SLACK_OFF_ROBOTOFF_USER_ALERT_CHANNEL = "CGWSXDGSF" @@ -51,15 +57,15 @@ SENTRY_DSN = os.environ.get("SENTRY_DSN") -OCR_DATA_DIR = DATA_DIR / 'ocr' -OCR_BRANDS_DATA_PATH = OCR_DATA_DIR / 'regex_brands.txt' -OCR_BRANDS_NOTIFY_WHITELIST_DATA_PATH = OCR_DATA_DIR / 'notify_whitelist_brands.txt' -OCR_LOGO_ANNOTATION_BRANDS_DATA_PATH = OCR_DATA_DIR / 'logo_annotation_brands.txt' -OCR_STORES_DATA_PATH = OCR_DATA_DIR / 'regex_stores.txt' -OCR_STORES_NOTIFY_WHITELIST_DATA_PATH = OCR_DATA_DIR / 'notify_whitelist_stores.txt' -OCR_LOGO_ANNOTATION_LABELS_DATA_PATH = OCR_DATA_DIR / 'logo_annotation_labels.txt' +OCR_DATA_DIR = DATA_DIR / "ocr" +OCR_BRANDS_DATA_PATH = OCR_DATA_DIR / "regex_brands.txt" +OCR_BRANDS_NOTIFY_WHITELIST_DATA_PATH = OCR_DATA_DIR / "notify_whitelist_brands.txt" +OCR_LOGO_ANNOTATION_BRANDS_DATA_PATH = OCR_DATA_DIR / "logo_annotation_brands.txt" +OCR_STORES_DATA_PATH = OCR_DATA_DIR / "regex_stores.txt" +OCR_STORES_NOTIFY_WHITELIST_DATA_PATH = OCR_DATA_DIR / "notify_whitelist_stores.txt" +OCR_LOGO_ANNOTATION_LABELS_DATA_PATH = OCR_DATA_DIR / "logo_annotation_labels.txt" ROBOTOFF_USER_AGENT = "Robotoff Live Analysis" # Models and ML -MODELS_DIR = PROJECT_DIR / 'models' +MODELS_DIR = PROJECT_DIR / "models" diff --git a/robotoff/slack.py b/robotoff/slack.py index 450c3d2b26..e256dc0828 100644 --- a/robotoff/slack.py +++ b/robotoff/slack.py @@ -13,12 +13,12 @@ BASE_URL = "https://slack.com/api" POST_MESSAGE_URL = BASE_URL + "/chat.postMessage" NUTRISCORE_LABELS = { - 'en:nutriscore', - 'en:nutriscore-a', - 'en:nutriscore-b', - 'en:nutriscore-c', - 'en:nutriscore-d', - 'en:nutriscore-e', + "en:nutriscore", + "en:nutriscore-a", + "en:nutriscore-b", + "en:nutriscore-c", + "en:nutriscore-d", + "en:nutriscore-e", } logger = get_logger(__name__) @@ -29,118 +29,115 @@ class SlackException(Exception): def notify_image_flag(insights: List[JSONType], source: str, barcode: str): - flags = ", ".join(["{} (score: {})".format(i['type'], i['likelihood']) - for i in insights]) - url = "{}/{}".format(settings.OFF_IMAGE_BASE_URL, - source) - edit_url = "{}/cgi/product.pl?type=edit&code={}" \ - "".format(settings.OFF_BASE_WEBSITE_URL, barcode) - text = ("Image flagged as {}: {}\nedit: {}".format( - flags, url, edit_url)) + flags = ", ".join( + ["{} (score: {})".format(i["type"], i["likelihood"]) for i in insights] + ) + url = "{}/{}".format(settings.OFF_IMAGE_BASE_URL, source) + edit_url = "{}/cgi/product.pl?type=edit&code={}" "".format( + settings.OFF_BASE_WEBSITE_URL, barcode + ) + text = "Image flagged as {}: {}\nedit: {}".format(flags, url, edit_url) post_message(text, settings.SLACK_OFF_ROBOTOFF_IMAGE_ALERT_CHANNEL) def notify_automatic_processing(insight: ProductInsight): - product_url = 
"{}/product/{}".format(settings.OFF_BASE_WEBSITE_URL, - insight.barcode) + product_url = "{}/product/{}".format(settings.OFF_BASE_WEBSITE_URL, insight.barcode) source_image = insight.source_image if source_image: image_url = "https://static.openfoodfacts.org/images/products" + source_image - metadata_text = "(<{}|product>, <{}|source image>)".format(product_url, image_url) + metadata_text = "(<{}|product>, <{}|source image>)".format( + product_url, image_url + ) else: metadata_text = "(<{}|product>)".format(product_url) if insight.type == InsightType.label.name: - text = ("The `{}` label was automatically added to product {}" - "".format(insight.value_tag, - insight.barcode)) + text = "The `{}` label was automatically added to product {}" "".format( + insight.value_tag, insight.barcode + ) elif insight.type == InsightType.product_weight.name: - text = ("The weight `{}` (match: `{}`) was automatically added to " - "product {}" - "".format(insight.data['text'], - insight.data['raw'], - insight.barcode)) + text = ( + "The weight `{}` (match: `{}`) was automatically added to " + "product {}" + "".format(insight.data["text"], insight.data["raw"], insight.barcode) + ) elif insight.type == InsightType.packager_code.name: - text = ("The `{}` packager code was automatically added to " - "product {}".format(insight.data['text'], - insight.barcode)) + text = "The `{}` packager code was automatically added to " "product {}".format( + insight.data["text"], insight.barcode + ) elif insight.type == InsightType.expiration_date.name: - text = ("The expiration date `{}` (match: `{}`) was automatically added to " - "product {}".format(insight.data['text'], - insight.data['raw'], - insight.barcode)) + text = ( + "The expiration date `{}` (match: `{}`) was automatically added to " + "product {}".format( + insight.data["text"], insight.data["raw"], insight.barcode + ) + ) elif insight.type == InsightType.brand.name: - text = ("The `{}` brand was automatically added to " - "product {}".format(insight.data['brand'], - insight.barcode)) + text = "The `{}` brand was automatically added to " "product {}".format( + insight.data["brand"], insight.barcode + ) elif insight.type == InsightType.store.name: - text = ("The `{}` store was automatically added to " - "product {}".format(insight.data['store'], - insight.barcode)) + text = "The `{}` store was automatically added to " "product {}".format( + insight.data["store"], insight.barcode + ) else: return text += " " + metadata_text slack_kwargs = { - 'unfurl_links': False, - 'unfurl_media': False, + "unfurl_links": False, + "unfurl_media": False, } if insight.value_tag in NUTRISCORE_LABELS: - post_message(text, settings.SLACK_OFF_NUTRISCORE_ALERT_CHANNEL, - **slack_kwargs) + post_message(text, settings.SLACK_OFF_NUTRISCORE_ALERT_CHANNEL, **slack_kwargs) return - post_message(text, settings.SLACK_OFF_ROBOTOFF_ALERT_CHANNEL, - **slack_kwargs) + post_message(text, settings.SLACK_OFF_ROBOTOFF_ALERT_CHANNEL, **slack_kwargs) def get_base_params() -> JSONType: return { - 'username': "robotoff-bot", - 'token': settings.SLACK_TOKEN, - 'icon_url': "https://s3-us-west-2.amazonaws.com/slack-files2/" - "bot_icons/2019-03-01/565595869687_48.png", + "username": "robotoff-bot", + "token": settings.SLACK_TOKEN, + "icon_url": "https://s3-us-west-2.amazonaws.com/slack-files2/" + "bot_icons/2019-03-01/565595869687_48.png", } def raise_if_slack_token_undefined(): if settings.SLACK_TOKEN is None: - raise ValueError("The bot Slack token must be passed in the SLACK_" - "TOKEN environment 
variable") + raise ValueError( + "The bot Slack token must be passed in the SLACK_" + "TOKEN environment variable" + ) -def post_message(text: str, - channel: str, - attachments: Optional[List[JSONType]] = None, - **kwargs): +def post_message( + text: str, channel: str, attachments: Optional[List[JSONType]] = None, **kwargs +): try: _post_message(text, channel, attachments, **kwargs) except Exception as e: - logger.error("An exception occurred when sending a Slack " - "notification", exc_info=e) + logger.error( + "An exception occurred when sending a Slack " "notification", exc_info=e + ) -def _post_message(text: str, - channel: str, - attachments: Optional[List[JSONType]] = None, - **kwargs): +def _post_message( + text: str, channel: str, attachments: Optional[List[JSONType]] = None, **kwargs +): raise_if_slack_token_undefined() - params: JSONType = { - **get_base_params(), - 'channel': channel, - 'text': text, - **kwargs - } + params: JSONType = {**get_base_params(), "channel": channel, "text": text, **kwargs} if attachments: - params['attachments'] = attachments + params["attachments"] = attachments r = http_session.post(POST_MESSAGE_URL, data=params) response_json = get_slack_json(r) @@ -151,12 +148,13 @@ def get_slack_json(response: requests.Response) -> JSONType: json_data = response.json() if not response.ok: - raise SlackException("Non-200 status code from Slack: " - "{}, response: {}" - "".format(response.status_code, - json_data)) + raise SlackException( + "Non-200 status code from Slack: " + "{}, response: {}" + "".format(response.status_code, json_data) + ) - if not json_data.get('ok', False): + if not json_data.get("ok", False): raise SlackException("Non-ok response: {}".format(json_data)) return json_data diff --git a/robotoff/taxonomy.py b/robotoff/taxonomy.py index 5af0bd47b2..f1b6ed3c69 100644 --- a/robotoff/taxonomy.py +++ b/robotoff/taxonomy.py @@ -2,7 +2,7 @@ import functools import json from enum import Enum -from typing import List, Dict, Iterable, Optional, Set +from typing import Dict, Iterable, List, Optional, Set import requests @@ -18,16 +18,15 @@ class TaxonomyType(Enum): class TaxonomyNode: - __slots__ = ('id', 'names', 'parents', 'children') + __slots__ = ("id", "names", "parents", "children") - def __init__(self, identifier: str, - names: List[Dict[str, str]]): + def __init__(self, identifier: str, names: List[Dict[str, str]]): self.id: str = identifier self.names: Dict[str, str] = names - self.parents: List['TaxonomyNode'] = [] - self.children: List['TaxonomyNode'] = [] + self.parents: List["TaxonomyNode"] = [] + self.children: List["TaxonomyNode"] = [] - def is_child_of(self, item: 'TaxonomyNode'): + def is_child_of(self, item: "TaxonomyNode"): if not self.parents: return False @@ -48,17 +47,14 @@ def get_localized_name(self, lang: str) -> str: return self.id - def add_parents(self, parents: Iterable['TaxonomyNode']): + def add_parents(self, parents: Iterable["TaxonomyNode"]): for parent in parents: if parent not in self.parents: self.parents.append(parent) parent.children.append(self) def to_dict(self) -> JSONType: - return { - 'name': self.names, - 'parents': [p.id for p in self.parents] - } + return {"name": self.names, "parents": [p.id for p in self.parents]} def __repr__(self): return "" % self.id @@ -108,9 +104,7 @@ def find_deepest_item(self, keys: List[str]) -> Optional[str]: return [key for key in keys if key not in excluded][0] - def is_parent_of_any(self, - item: str, - candidates: Iterable[str]) -> bool: + def is_parent_of_any(self, item: str, 
candidates: Iterable[str]) -> bool: node: TaxonomyNode = self[item] if node is None: @@ -145,33 +139,31 @@ def to_dict(self) -> JSONType: return export @classmethod - def from_dict(cls, data: JSONType) -> 'Taxonomy': + def from_dict(cls, data: JSONType) -> "Taxonomy": taxonomy = Taxonomy() for key, key_data in data.items(): if key not in taxonomy: - node = TaxonomyNode(identifier=key, - names=key_data.get('name', {})) + node = TaxonomyNode(identifier=key, names=key_data.get("name", {})) taxonomy.add(key, node) for key, key_data in data.items(): node = taxonomy[key] - parents = [taxonomy[ref] - for ref in key_data.get('parents', [])] + parents = [taxonomy[ref] for ref in key_data.get("parents", [])] node.add_parents(parents) return taxonomy @classmethod def from_json(cls, file_path: str): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) return cls.from_dict(data) -def generate_category_hierarchy(taxonomy: Taxonomy, - category_to_index: Dict[str, int], - root: int): +def generate_category_hierarchy( + taxonomy: Taxonomy, category_to_index: Dict[str, int], root: int +): categories_hierarchy = collections.defaultdict(set) for node in taxonomy.iter_nodes(): @@ -180,23 +172,26 @@ def generate_category_hierarchy(taxonomy: Taxonomy, if not node.parents: categories_hierarchy[root].add(category_index) - children_indexes = set([category_to_index[c.id] - for c in node.children - if c.id in category_to_index]) + children_indexes = set( + [ + category_to_index[c.id] + for c in node.children + if c.id in category_to_index + ] + ) - categories_hierarchy[category_index] = \ - categories_hierarchy[category_index].union(children_indexes) + categories_hierarchy[category_index] = categories_hierarchy[ + category_index + ].union(children_indexes) categories_hierarchy_list = {} for category in categories_hierarchy.keys(): - categories_hierarchy_list[category] = \ - list(categories_hierarchy[category]) + categories_hierarchy_list[category] = list(categories_hierarchy[category]) return categories_hierarchy_list -def fetch_taxonomy(url: str, fallback_path: str, offline=False) \ - -> Optional[Taxonomy]: +def fetch_taxonomy(url: str, fallback_path: str, offline=False) -> Optional[Taxonomy]: if offline: return Taxonomy.from_json(fallback_path) @@ -213,19 +208,25 @@ def fetch_taxonomy(url: str, fallback_path: str, offline=False) \ TAXONOMY_STORES: Dict[str, CachedStore] = { - TaxonomyType.category.name: - CachedStore(functools.partial(fetch_taxonomy, - url=settings.TAXONOMY_CATEGORY_URL, - fallback_path= - settings.TAXONOMY_CATEGORY_PATH)), - TaxonomyType.ingredient.name: - CachedStore(functools.partial(fetch_taxonomy, - url=settings.TAXONOMY_INGREDIENT_URL, - fallback_path= - settings.TAXONOMY_INGREDIENT_PATH)), - TaxonomyType.label.name: - CachedStore(functools.partial(fetch_taxonomy, - url=settings.TAXONOMY_LABEL_URL, - fallback_path= - settings.TAXONOMY_LABEL_PATH)) + TaxonomyType.category.name: CachedStore( + functools.partial( + fetch_taxonomy, + url=settings.TAXONOMY_CATEGORY_URL, + fallback_path=settings.TAXONOMY_CATEGORY_PATH, + ) + ), + TaxonomyType.ingredient.name: CachedStore( + functools.partial( + fetch_taxonomy, + url=settings.TAXONOMY_INGREDIENT_URL, + fallback_path=settings.TAXONOMY_INGREDIENT_PATH, + ) + ), + TaxonomyType.label.name: CachedStore( + functools.partial( + fetch_taxonomy, + url=settings.TAXONOMY_LABEL_URL, + fallback_path=settings.TAXONOMY_LABEL_PATH, + ) + ), } diff --git a/robotoff/utils/__init__.py b/robotoff/utils/__init__.py index 
e89d174310..f19046ebaa 100644 --- a/robotoff/utils/__init__.py +++ b/robotoff/utils/__init__.py @@ -3,12 +3,11 @@ import logging import os import pathlib +import sys import tempfile +from typing import Dict, Iterable, Optional, Union import requests -import sys -from typing import Union, Iterable, Dict, Optional - from PIL import Image @@ -23,59 +22,62 @@ def get_logger(name=None, level: str = "INFO"): def configure_root_logger(logger, level: str = "INFO"): - log_level = os.environ.get('LOG_LEVEL', "INFO").upper() + log_level = os.environ.get("LOG_LEVEL", "INFO").upper() - if log_level not in ( - "DEBUG", "INFO", "WARNING", "ERROR", "FATAL", "CRITICAL"): - print("Unknown log level: {}, fallback " - "to INFO".format(log_level), file=sys.stderr) + if log_level not in ("DEBUG", "INFO", "WARNING", "ERROR", "FATAL", "CRITICAL"): + print( + "Unknown log level: {}, fallback " "to INFO".format(log_level), + file=sys.stderr, + ) log_level = level logger.setLevel(log_level) handler = logging.StreamHandler() - formatter = logging.Formatter('%(asctime)s :: %(processName)s :: ' - '%(threadName)s :: %(levelname)s :: ' - '%(message)s') + formatter = logging.Formatter( + "%(asctime)s :: %(processName)s :: " + "%(threadName)s :: %(levelname)s :: " + "%(message)s" + ) handler.setFormatter(formatter) handler.setLevel(log_level) logger.addHandler(handler) def jsonl_iter(jsonl_path: Union[str, pathlib.Path]) -> Iterable[Dict]: - with open(jsonl_path, 'r') as f: + with open(jsonl_path, "r") as f: yield from jsonl_iter_fp(f) def gzip_jsonl_iter(jsonl_path: Union[str, pathlib.Path]) -> Iterable[Dict]: - with gzip.open(jsonl_path, 'rt') as f: + with gzip.open(jsonl_path, "rt") as f: yield from jsonl_iter_fp(f) def jsonl_iter_fp(fp) -> Iterable[Dict]: for line in fp: - line = line.strip('\n') + line = line.strip("\n") if line: yield json.loads(line) -def dump_jsonl(filepath: Union[str, pathlib.Path], - json_iter: Iterable[Dict]): - with open(str(filepath), 'w') as f: +def dump_jsonl(filepath: Union[str, pathlib.Path], json_iter: Iterable[Dict]): + with open(str(filepath), "w") as f: for item in json_iter: f.write(json.dumps(item) + "\n") def text_file_iter(filepath: Union[str, pathlib.Path]) -> Iterable[str]: - with open(str(filepath), 'r') as f: + with open(str(filepath), "r") as f: for item in f: - item = item.strip('\n') + item = item.strip("\n") if item: yield item -def get_image_from_url(image_url: str, - error_raise: bool = False) -> Optional[Image.Image]: +def get_image_from_url( + image_url: str, error_raise: bool = False +) -> Optional[Image.Image]: r = requests.get(image_url) if error_raise: diff --git a/robotoff/utils/cache.py b/robotoff/utils/cache.py index ece8dad411..5f3bbacc0d 100644 --- a/robotoff/utils/cache.py +++ b/robotoff/utils/cache.py @@ -1,6 +1,6 @@ import abc import datetime -from typing import Optional, Callable +from typing import Callable, Optional from robotoff.utils import get_logger @@ -8,23 +8,24 @@ class CachedStore(metaclass=abc.ABCMeta): - def __init__(self, - fetch_func: Callable, - expiration_timedelta: Optional[datetime.timedelta] = None): + def __init__( + self, + fetch_func: Callable, + expiration_timedelta: Optional[datetime.timedelta] = None, + ): self.store = None self.expires_after: Optional[datetime.datetime] = None self.fetch_func: Callable = fetch_func - self.expiration_timedelta = (expiration_timedelta or - datetime.timedelta(minutes=30)) + self.expiration_timedelta = expiration_timedelta or datetime.timedelta( + minutes=30 + ) def get(self, **kwargs): - if 
(self.store is None or - datetime.datetime.utcnow() >= self.expires_after): + if self.store is None or datetime.datetime.utcnow() >= self.expires_after: if self.store is not None: logger.info("ProductStore expired, reloading...") - self.expires_after = (datetime.datetime.utcnow() + - self.expiration_timedelta) + self.expires_after = datetime.datetime.utcnow() + self.expiration_timedelta self.store = self.fetch_func(**kwargs) return self.store diff --git a/robotoff/utils/es.py b/robotoff/utils/es.py index 34c00b7315..76ce25afab 100644 --- a/robotoff/utils/es.py +++ b/robotoff/utils/es.py @@ -1,8 +1,7 @@ import json -from typing import Iterable, Dict, Tuple +from typing import Dict, Iterable, Tuple import elasticsearch - from robotoff import settings @@ -10,24 +9,14 @@ def get_es_client(): return elasticsearch.Elasticsearch(settings.ELASTICSEARCH_HOSTS) -def perform_export(client, - data: Iterable[Tuple[str, Dict]], - index: str, - batch_size=100): +def perform_export( + client, data: Iterable[Tuple[str, Dict]], index: str, batch_size=100 +): batch = [] rows_inserted = 0 for id_, item in data: - batch.append( - ( - { - 'index': { - '_id': id_ - } - }, - item - ) - ) + batch.append(({"index": {"_id": id_}}, item)) if len(batch) >= batch_size: insert_batch(client, batch, index) @@ -44,19 +33,16 @@ def perform_export(client, def insert_batch(client, batch: Iterable[Tuple[Dict, Dict]], index: str): body = "" for action, source in batch: - body += "{}\n{}\n".format(json.dumps(action), - json.dumps(source)) + body += "{}\n{}\n".format(json.dumps(action), json.dumps(source)) - client.bulk(body=body, - index=index, - doc_type=settings.ELASTICSEARCH_TYPE) + client.bulk(body=body, index=index, doc_type=settings.ELASTICSEARCH_TYPE) def generate_msearch_body(index: str, queries: Iterable[Dict]): lines = [] for query in queries: - lines.append(json.dumps({'index': index})) + lines.append(json.dumps({"index": index})) lines.append(json.dumps(query)) - return '\n'.join(lines) + return "\n".join(lines) diff --git a/robotoff/utils/i18n.py b/robotoff/utils/i18n.py index a3096a842c..0e1800fc2c 100644 --- a/robotoff/utils/i18n.py +++ b/robotoff/utils/i18n.py @@ -1,15 +1,15 @@ import gettext -from typing import Set, Dict, Optional +from typing import Dict, Optional, Set from robotoff import settings class TranslationStore: SUPPORTED_LANGUAGES: Set[str] = { - 'fr', - 'es', - 'it', - 'de', + "fr", + "es", + "it", + "de", } def __init__(self): @@ -17,9 +17,9 @@ def __init__(self): def load(self): for lang in self.SUPPORTED_LANGUAGES: - t = gettext.translation('robotoff', - str(settings.I18N_DIR), - languages=[lang]) + t = gettext.translation( + "robotoff", str(settings.I18N_DIR), languages=[lang] + ) if t is not None: self.translations[lang] = t diff --git a/robotoff/utils/text.py b/robotoff/utils/text.py index 31cfe69106..48056d2931 100644 --- a/robotoff/utils/text.py +++ b/robotoff/utils/text.py @@ -17,5 +17,5 @@ def strip_accents_ascii(s): strip_accents_unicode Remove accentuated char for any unicode symbol. 
""" - nkfd_form = unicodedata.normalize('NFKD', s) - return nkfd_form.encode('ASCII', 'ignore').decode('ASCII') + nkfd_form = unicodedata.normalize("NFKD", s) + return nkfd_form.encode("ASCII", "ignore").decode("ASCII") diff --git a/robotoff/workers/client.py b/robotoff/workers/client.py index 90730aad10..72106be472 100644 --- a/robotoff/workers/client.py +++ b/robotoff/workers/client.py @@ -10,14 +10,15 @@ def send_ipc_event(event_type: str, meta: Dict = None): meta = meta or {} - logger.info("Connecting listener server on {}:{}" - "".format(*settings.IPC_ADDRESS)) - with Client(settings.IPC_ADDRESS, - authkey=settings.IPC_AUTHKEY, - family='AF_INET') as conn: + logger.info("Connecting listener server on {}:{}" "".format(*settings.IPC_ADDRESS)) + with Client( + settings.IPC_ADDRESS, authkey=settings.IPC_AUTHKEY, family="AF_INET" + ) as conn: logger.info("Sending event through IPC") - conn.send({ - 'type': event_type, - 'meta': meta, - }) + conn.send( + { + "type": event_type, + "meta": meta, + } + ) logger.info("IPC event sent") diff --git a/robotoff/workers/listener.py b/robotoff/workers/listener.py index be7a2251ad..04b9ab46e0 100644 --- a/robotoff/workers/listener.py +++ b/robotoff/workers/listener.py @@ -2,13 +2,13 @@ from multiprocessing.pool import Pool from typing import Dict +import sentry_sdk +from sentry_sdk import capture_exception + from robotoff import settings from robotoff.utils import get_logger from robotoff.workers.tasks import run_task -import sentry_sdk -from sentry_sdk import capture_exception - if settings.SENTRY_DSN: sentry_sdk.init(settings.SENTRY_DSN) @@ -19,22 +19,21 @@ def run(): pool: Pool = Pool(settings.WORKER_COUNT) - logger.info("Starting listener server on {}:{}" - "".format(*settings.IPC_ADDRESS)) + logger.info("Starting listener server on {}:{}" "".format(*settings.IPC_ADDRESS)) logger.info("Starting listener server") - with Listener(settings.IPC_ADDRESS, - authkey=settings.IPC_AUTHKEY, - family='AF_INET') as listener: + with Listener( + settings.IPC_ADDRESS, authkey=settings.IPC_AUTHKEY, family="AF_INET" + ) as listener: while True: try: logger.info("Waiting for a connection...") with listener.accept() as conn: event = conn.recv() - event_type: str = event['type'] + event_type: str = event["type"] logger.info("New '{}' event received".format(event_type)) - event_kwargs: Dict = event.get('meta', {}) + event_kwargs: Dict = event.get("meta", {}) logger.info("Sending task to pool...") pool.apply_async(run_task, (event_type, event_kwargs)) diff --git a/robotoff/workers/tasks.py b/robotoff/workers/tasks.py index cf273939c5..96c4f0ba07 100644 --- a/robotoff/workers/tasks.py +++ b/robotoff/workers/tasks.py @@ -1,18 +1,17 @@ import json import logging import multiprocessing -from typing import List, Dict, Callable +from typing import Callable, Dict, List from robotoff.elasticsearch.category.predict import predict_from_product from robotoff.insights._enum import InsightType -from robotoff.insights.importer import InsightImporterFactory, InsightImporter from robotoff.insights.extraction import get_insights_from_image -from robotoff.models import db, ProductInsight +from robotoff.insights.importer import InsightImporter, InsightImporterFactory +from robotoff.models import ProductInsight, db from robotoff.off import get_product -from robotoff.products import (has_dataset_changed, fetch_dataset, - CACHED_PRODUCT_STORE) +from robotoff.products import CACHED_PRODUCT_STORE, fetch_dataset, has_dataset_changed from robotoff.slack import notify_image_flag -from 
robotoff.utils import get_logger, configure_root_logger +from robotoff.utils import configure_root_logger, get_logger from robotoff.utils.types import JSONType logger = get_logger(__name__) @@ -39,21 +38,23 @@ def download_product_dataset(): fetch_dataset() -def import_insights(insight_type: str, - items: List[str]): +def import_insights(insight_type: str, items: List[str]): product_store = CACHED_PRODUCT_STORE.get() - importer: InsightImporter = InsightImporterFactory.create(insight_type, - product_store) + importer: InsightImporter = InsightImporterFactory.create( + insight_type, product_store + ) with db.atomic(): - imported = importer.import_insights((json.loads(l) for l in items), - automatic=False) + imported = importer.import_insights( + (json.loads(l) for l in items), automatic=False + ) logger.info("Import finished, {} insights imported".format(imported)) def import_image(barcode: str, image_url: str, ocr_url: str): - logger.info("Detect insights for product {}, " - "image {}".format(barcode, image_url)) + logger.info( + "Detect insights for product {}, " "image {}".format(barcode, image_url) + ) product_store = CACHED_PRODUCT_STORE.get() insights_all = get_insights_from_image(barcode, image_url, ocr_url) @@ -62,14 +63,15 @@ def import_image(barcode: str, image_url: str, ocr_url: str): for insight_type, insights in insights_all.items(): if insight_type == InsightType.image_flag.name: - notify_image_flag(insights['insights'], - insights['source'], - insights['barcode']) + notify_image_flag( + insights["insights"], insights["source"], insights["barcode"] + ) continue logger.info("Extracting {}".format(insight_type)) - importer: InsightImporter = InsightImporterFactory.create(insight_type, - product_store) + importer: InsightImporter = InsightImporterFactory.create( + insight_type, product_store + ) with db.atomic(): imported = importer.import_insights([insights], automatic=True) @@ -77,11 +79,13 @@ def import_image(barcode: str, image_url: str, ocr_url: str): def delete_product_insights(barcode: str): - logger.info("Product {} deleted, deleting associated " - "insights...".format(barcode)) + logger.info( + "Product {} deleted, deleting associated " "insights...".format(barcode) + ) with db.atomic(): - deleted = (ProductInsight.delete() - .where(ProductInsight.barcode == barcode).execute()) + deleted = ( + ProductInsight.delete().where(ProductInsight.barcode == barcode).execute() + ) logger.info("{} insights deleted".format(deleted)) @@ -98,9 +102,8 @@ def updated_product_update_insights(barcode: str): logger.info("Product {} updated".format(barcode)) -def updated_product_add_category_insight(barcode: str, - product: JSONType) -> bool: - if product.get('categories_tags', []): +def updated_product_add_category_insight(barcode: str, product: JSONType) -> bool: + if product.get("categories_tags", []): return False insight = predict_from_product(product) @@ -109,8 +112,7 @@ def updated_product_add_category_insight(barcode: str, return False product_store = CACHED_PRODUCT_STORE.get() - importer = InsightImporterFactory.create(InsightType.category.name, - product_store) + importer = InsightImporterFactory.create(InsightType.category.name, product_store) imported = importer.import_insights([insight], automatic=False) @@ -121,9 +123,9 @@ def updated_product_add_category_insight(barcode: str, EVENT_MAPPING: Dict[str, Callable] = { - 'import_insights': import_insights, - 'import_image': import_image, - 'download_dataset': download_product_dataset, - 'product_deleted': delete_product_insights, 
- 'product_updated': updated_product_update_insights, + "import_insights": import_insights, + "import_image": import_image, + "download_dataset": download_product_dataset, + "product_deleted": delete_product_insights, + "product_updated": updated_product_update_insights, } diff --git a/tests/insights/ocr/test_brand.py b/tests/insights/ocr/test_brand.py index 15dd54a3ec..780464a292 100644 --- a/tests/insights/ocr/test_brand.py +++ b/tests/insights/ocr/test_brand.py @@ -3,22 +3,25 @@ from robotoff.insights.ocr.brand import BRAND_REGEX -@pytest.mark.parametrize('input_str,is_match', [ - ("other string", False), - ("carre", False), - ("carrefour", True), - ("monoprix p'tit prix", True), - ("marks & spencer", True), - ("nestlé", True), - ("nestle", True), - ("carrefour gaby", True), - ("carrefour baby", True), - ("dr. oetker", True), - ("dr oetker", True), - ("m-budget", True), - ("la belle iloise", True), - ("la belle-îloise", True), -]) +@pytest.mark.parametrize( + "input_str,is_match", + [ + ("other string", False), + ("carre", False), + ("carrefour", True), + ("monoprix p'tit prix", True), + ("marks & spencer", True), + ("nestlé", True), + ("nestle", True), + ("carrefour gaby", True), + ("carrefour baby", True), + ("dr. oetker", True), + ("dr oetker", True), + ("m-budget", True), + ("la belle iloise", True), + ("la belle-îloise", True), + ], +) def test_brand_regex(input_str: str, is_match: bool): regex = BRAND_REGEX.regex assert (regex.match(input_str) is not None) == is_match diff --git a/tests/insights/ocr/test_image_orientation.py b/tests/insights/ocr/test_image_orientation.py index 5f7be60acf..dea5557538 100644 --- a/tests/insights/ocr/test_image_orientation.py +++ b/tests/insights/ocr/test_image_orientation.py @@ -4,28 +4,38 @@ def generate_bounding_poly(*items): - vertices = [{'x': item[0], 'y': item[1]} - for item in items] - data = { - 'vertices': vertices - } + vertices = [{"x": item[0], "y": item[1]} for item in items] + data = {"vertices": vertices} return BoundingPoly(data) class TestBoundingPoly: - @pytest.mark.parametrize('bounding_poly,orientation', [ - (generate_bounding_poly((66, 458), (60, 348), (94, 346), (100, 456)), - ImageOrientation.left), - (generate_bounding_poly((66, 458), (60, 340), (94, 346), (100, 456)), - ImageOrientation.left), - (generate_bounding_poly((1106, 414), (1178, 421), (1175, 446), - (1103, 439)), - ImageOrientation.up), - (generate_bounding_poly((1106, 421), (1178, 414), (1175, 446), - (1103, 439)), - ImageOrientation.up), - ]) - def test_detect_orientation(self, - bounding_poly: BoundingPoly, - orientation: ImageOrientation): + @pytest.mark.parametrize( + "bounding_poly,orientation", + [ + ( + generate_bounding_poly((66, 458), (60, 348), (94, 346), (100, 456)), + ImageOrientation.left, + ), + ( + generate_bounding_poly((66, 458), (60, 340), (94, 346), (100, 456)), + ImageOrientation.left, + ), + ( + generate_bounding_poly( + (1106, 414), (1178, 421), (1175, 446), (1103, 439) + ), + ImageOrientation.up, + ), + ( + generate_bounding_poly( + (1106, 421), (1178, 414), (1175, 446), (1103, 439) + ), + ImageOrientation.up, + ), + ], + ) + def test_detect_orientation( + self, bounding_poly: BoundingPoly, orientation: ImageOrientation + ): assert bounding_poly.detect_orientation() == orientation diff --git a/tests/insights/ocr/test_product_weight.py b/tests/insights/ocr/test_product_weight.py index 73b23da308..831dfe45b1 100644 --- a/tests/insights/ocr/test_product_weight.py +++ b/tests/insights/ocr/test_product_weight.py @@ -4,17 +4,20 @@ from 
robotoff.insights.ocr.product_weight import PRODUCT_WEIGHT_REGEX -@pytest.mark.parametrize('input_str,is_match', [ - ("poids net à l'emballage: 500g", True), - ("poids 2kg", True), - ("poids 2kgv", False), - ("net wt. 1.4 fl oz", True), - ("other string", False), - ("1.4 g", False), - ("2 l", False), -]) +@pytest.mark.parametrize( + "input_str,is_match", + [ + ("poids net à l'emballage: 500g", True), + ("poids 2kg", True), + ("poids 2kgv", False), + ("net wt. 1.4 fl oz", True), + ("other string", False), + ("1.4 g", False), + ("2 l", False), + ], +) def test_product_weight_with_mention_regex(input_str: str, is_match: bool): - with_mention_ocr_regex: OCRRegex = PRODUCT_WEIGHT_REGEX['with_mention'] + with_mention_ocr_regex: OCRRegex = PRODUCT_WEIGHT_REGEX["with_mention"] with_mention_regex = with_mention_ocr_regex.regex assert (with_mention_regex.match(input_str) is not None) == is_match diff --git a/tests/insights/test_annotate.py b/tests/insights/test_annotate.py index fdb7e6f76d..1d1f47fb02 100644 --- a/tests/insights/test_annotate.py +++ b/tests/insights/test_annotate.py @@ -1,14 +1,28 @@ - import pytest from robotoff.insights.annotate import IngredientSpellcheckAnnotator class TestIngredientSpellCheckAnnotator: - @pytest.mark.parametrize('ingredient_str,start_offset,end_offset,correction,expected', [ - ("fqrine de blé complet", 0, 6, "farine", "farine de blé complet"), - ("farine de blé complet, paudre à lever", 23, 29, "poudre", "farine de blé complet, poudre à lever"), - ]) - def test_generate_full_correction(self, ingredient_str, start_offset, end_offset, correction, expected): - assert IngredientSpellcheckAnnotator.generate_full_correction(ingredient_str, - start_offset, end_offset, correction) == expected + @pytest.mark.parametrize( + "ingredient_str,start_offset,end_offset,correction,expected", + [ + ("fqrine de blé complet", 0, 6, "farine", "farine de blé complet"), + ( + "farine de blé complet, paudre à lever", + 23, + 29, + "poudre", + "farine de blé complet, poudre à lever", + ), + ], + ) + def test_generate_full_correction( + self, ingredient_str, start_offset, end_offset, correction, expected + ): + assert ( + IngredientSpellcheckAnnotator.generate_full_correction( + ingredient_str, start_offset, end_offset, correction + ) + == expected + ) diff --git a/tests/insights/test_importer.py b/tests/insights/test_importer.py index 0cbc02a468..a588a59c17 100644 --- a/tests/insights/test_importer.py +++ b/tests/insights/test_importer.py @@ -4,13 +4,15 @@ class TestBrandInsightImporter: - @pytest.mark.parametrize('barcode,brand_tag,is_valid', [ - ("5400141651306", "boni", True), - ("5400142968395", "boni", False), - ("3406790524499", "boni", False), - ("025252", "boni", False), - ("3406790524499", "unknown-brand", True), - - ]) + @pytest.mark.parametrize( + "barcode,brand_tag,is_valid", + [ + ("5400141651306", "boni", True), + ("5400142968395", "boni", False), + ("3406790524499", "boni", False), + ("025252", "boni", False), + ("3406790524499", "unknown-brand", True), + ], + ) def test_generate_full_correction(self, barcode, brand_tag, is_valid): assert BrandInsightImporter.in_barcode_range(brand_tag, barcode) is is_valid diff --git a/tests/insights/test_question.py b/tests/insights/test_question.py index 7714c90896..74218875cc 100644 --- a/tests/insights/test_question.py +++ b/tests/insights/test_question.py @@ -3,12 +3,14 @@ from robotoff.insights.question import get_display_image -@pytest.mark.parametrize('source_image,output', [ - ('/366/194/903/0038/1.jpg', 
'/366/194/903/0038/1.400.jpg'), - ('/366/194/903/0038/20.jpg', '/366/194/903/0038/20.400.jpg'), - ('/366/194/903/0038/20.400.jpg', '/366/194/903/0038/20.400.jpg'), - ('/366/194/903/0038/20test.jpg', '/366/194/903/0038/20test.jpg'), -]) +@pytest.mark.parametrize( + "source_image,output", + [ + ("/366/194/903/0038/1.jpg", "/366/194/903/0038/1.400.jpg"), + ("/366/194/903/0038/20.jpg", "/366/194/903/0038/20.400.jpg"), + ("/366/194/903/0038/20.400.jpg", "/366/194/903/0038/20.400.jpg"), + ("/366/194/903/0038/20test.jpg", "/366/194/903/0038/20test.jpg"), + ], +) def test_get_display_image(source_image: str, output: str): assert get_display_image(source_image) == output - diff --git a/tests/test_ingredients.py b/tests/test_ingredients.py index 7f17686374..6596ac79d8 100644 --- a/tests/test_ingredients.py +++ b/tests/test_ingredients.py @@ -1,22 +1,41 @@ import pytest -from robotoff.ingredients import process_ingredients, normalize_ingredients, Ingredients +from robotoff.ingredients import Ingredients, normalize_ingredients, process_ingredients -@pytest.mark.parametrize('text,normalized', [ - ("farine de blé 10,5%, huile de colza 8%, soja 0,15%", "farine de blé , huile de colza , soja "), - ("Eau, céréales 15,2 % (épeautre 7 %, riz 6 %, °avoine_), pâte", "Eau, céréales (épeautre , riz , °avoine ), pâte"), -]) +@pytest.mark.parametrize( + "text,normalized", + [ + ( + "farine de blé 10,5%, huile de colza 8%, soja 0,15%", + "farine de blé , huile de colza , soja ", + ), + ( + "Eau, céréales 15,2 % (épeautre 7 %, riz 6 %, °avoine_), pâte", + "Eau, céréales (épeautre , riz , °avoine ), pâte", + ), + ], +) def test_normalize_ingredients(text, normalized): assert normalized == normalize_ingredients(text) def test_process_ingredients(): text = "Eau, oeufs frais, farine de blé 19%, huile de colza, lactose et protéines de lait, sel, extrait d'épices" - normalized = "Eau oeufs frais farine de blé huile de colza lactose et protéines de lait sel " \ - "extrait d'épices" + normalized = ( + "Eau oeufs frais farine de blé huile de colza lactose et protéines de lait sel " + "extrait d'épices" + ) ingredients = process_ingredients(text) assert isinstance(ingredients, Ingredients) assert ingredients.text == text assert ingredients.normalized == normalized - assert ingredients.offsets == [(0, 3), (4, 16), (17, 35), (36, 51), (52, 81), (82, 86), (87, 104)] + assert ingredients.offsets == [ + (0, 3), + (4, 16), + (17, 35), + (36, 51), + (52, 81), + (82, 86), + (87, 104), + ] diff --git a/tests/test_taxonomy.py b/tests/test_taxonomy.py index 88faa17ea8..7c51075d5c 100644 --- a/tests/test_taxonomy.py +++ b/tests/test_taxonomy.py @@ -5,23 +5,34 @@ from robotoff import settings from robotoff.taxonomy import Taxonomy - label_taxonomy = Taxonomy.from_json(settings.TAXONOMY_LABEL_PATH) class TestTaxonomy: - @pytest.mark.parametrize('taxonomy,item,candidates,output', [ - (label_taxonomy, 'en:organic', {'en:fr-bio-01'}, True), - (label_taxonomy, 'en:fr-bio-01', {'en:organic'}, False), - (label_taxonomy, 'en:fr-bio-01', [], False), - (label_taxonomy, 'en:organic', {'en:gluten-free'}, False), - (label_taxonomy, 'en:organic', - {'en:gluten-free', 'en:no-additives', 'en:vegan'}, False), - (label_taxonomy, 'en:organic', - {'en:gluten-free', 'en:no-additives', 'en:fr-bio-16'}, True), - ]) - def test_is_child_of_any(self, taxonomy: Taxonomy, item: str, - candidates: List, output: bool): + @pytest.mark.parametrize( + "taxonomy,item,candidates,output", + [ + (label_taxonomy, "en:organic", {"en:fr-bio-01"}, True), + (label_taxonomy, 
"en:fr-bio-01", {"en:organic"}, False), + (label_taxonomy, "en:fr-bio-01", [], False), + (label_taxonomy, "en:organic", {"en:gluten-free"}, False), + ( + label_taxonomy, + "en:organic", + {"en:gluten-free", "en:no-additives", "en:vegan"}, + False, + ), + ( + label_taxonomy, + "en:organic", + {"en:gluten-free", "en:no-additives", "en:fr-bio-16"}, + True, + ), + ], + ) + def test_is_child_of_any( + self, taxonomy: Taxonomy, item: str, candidates: List, output: bool + ): assert taxonomy.is_parent_of_any(item, candidates) is output def test_is_child_of_any_unknwon_item(self):