diff --git a/robotoff/prediction/object_detection/core.py b/robotoff/prediction/object_detection/core.py index 25fb2d772f..f84751d877 100644 --- a/robotoff/prediction/object_detection/core.py +++ b/robotoff/prediction/object_detection/core.py @@ -11,6 +11,7 @@ from robotoff.triton import get_triton_inference_stub from robotoff.types import JSONType, ObjectDetectionModel from robotoff.utils import get_logger, text_file_iter +from robotoff.utils.image import convert_image_to_array logger = get_logger(__name__) @@ -70,15 +71,6 @@ def to_json(self, threshold: Optional[float] = None) -> list[JSONType]: return [dataclasses.asdict(r) for r in self.select(threshold)] -def convert_image_to_array(image: Image.Image) -> np.ndarray: - if image.mode != "RGB": - image = image.convert("RGB") - - (im_width, im_height) = image.size - - return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8) - - def add_boxes_and_labels(image_array: np.ndarray, raw_result: ObjectDetectionRawResult): vis_util.visualize_boxes_and_labels_on_image_array( image_array, diff --git a/robotoff/utils/__init__.py b/robotoff/utils/__init__.py index e966bfddf7..499f903107 100644 --- a/robotoff/utils/__init__.py +++ b/robotoff/utils/__init__.py @@ -1,61 +1,16 @@ import gzip -import logging -import os import pathlib -import sys -from io import BytesIO -from typing import Any, Callable, Iterable, Optional, Union -from urllib.parse import urlparse +from typing import Any, Callable, Iterable, Union import orjson -import PIL import requests -from PIL import Image from requests.adapters import HTTPAdapter -from requests.exceptions import ConnectionError as RequestConnectionError -from requests.exceptions import SSLError, Timeout from robotoff import settings from robotoff.types import JSONType - -def get_logger(name=None, level: Optional[int] = None): - logger = logging.getLogger(name) - - if level is None: - log_level = os.environ.get("LOG_LEVEL", "INFO").upper() - level = logging.getLevelName(log_level) - - if not isinstance(level, int): - print( - "Unknown log level: {}, fallback to INFO".format(log_level), - file=sys.stderr, - ) - level = 20 - - logger.setLevel(level) - - if name is None: - configure_root_logger(logger, level) - - return logger - - -def configure_root_logger(logger, level: int = 20): - logger.setLevel(level) - handler = logging.StreamHandler() - formatter = logging.Formatter( - "%(asctime)s :: %(processName)s :: " - "%(threadName)s :: %(levelname)s :: " - "%(message)s" - ) - handler.setFormatter(formatter) - handler.setLevel(level) - logger.addHandler(handler) - - for name in ("redis_lock", "spacy"): - logging.getLogger(name).setLevel(logging.WARNING) - +from .image import ImageLoadingException, get_image_from_url # noqa: F401 +from .logger import get_logger logger = get_logger(__name__) @@ -164,73 +119,6 @@ def dump_text(filepath: Union[str, pathlib.Path], text_iter: Iterable[str]): f.write(item + "\n") -class ImageLoadingException(Exception): - """Exception raised by `get_image_from_url`` when image cannot be fetched - from URL or if loading failed. - """ - - pass - - -def get_image_from_url( - image_url: str, - error_raise: bool = True, - session: Optional[requests.Session] = None, -) -> Optional[Image.Image]: - """Fetch an image from `image_url` and load it. - - :param image_url: URL of the image to load - :param error_raise: if True, raises a `ImageLoadingException` if an error - occured, defaults to False. If False, None is returned if an error occurs. - :param session: requests Session to use, by default no session is used. - :raises ImageLoadingException: _description_ - :return: the Pillow Image or None. - """ - auth = ( - settings._off_net_auth - if urlparse(image_url).netloc.endswith("openfoodfacts.net") - else None - ) - try: - if session: - r = session.get(image_url, auth=auth) - else: - r = requests.get(image_url, auth=auth) - except (RequestConnectionError, SSLError, Timeout) as e: - error_message = "Cannot download image %s" - if error_raise: - raise ImageLoadingException(error_message % image_url) from e - logger.info(error_message, image_url, exc_info=e) - return None - - if not r.ok: - error_message = "Cannot download image %s: HTTP %s" - error_args = (image_url, r.status_code) - if error_raise: - raise ImageLoadingException(error_message % error_args) - logger.log( - logging.INFO if r.status_code < 500 else logging.WARNING, - error_message, - *error_args, - ) - return None - - try: - return Image.open(BytesIO(r.content)) - except PIL.UnidentifiedImageError: - error_message = f"Cannot identify image {image_url}" - if error_raise: - raise ImageLoadingException(error_message) - logger.info(error_message) - except PIL.Image.DecompressionBombError: - error_message = f"Decompression bomb error for image {image_url}" - if error_raise: - raise ImageLoadingException(error_message) - logger.info(error_message) - - return None - - http_session = requests.Session() USER_AGENT_HEADERS = { "User-Agent": settings.ROBOTOFF_USER_AGENT, diff --git a/robotoff/utils/image.py b/robotoff/utils/image.py new file mode 100644 index 0000000000..c0ac961eb0 --- /dev/null +++ b/robotoff/utils/image.py @@ -0,0 +1,100 @@ +import logging +from io import BytesIO +from typing import Optional +from urllib.parse import urlparse + +import numpy as np +import PIL +import requests +from PIL import Image +from requests.exceptions import ConnectionError as RequestConnectionError +from requests.exceptions import SSLError, Timeout + +from robotoff import settings + +from .logger import get_logger + +logger = get_logger(__name__) + + +def convert_image_to_array(image: Image.Image) -> np.ndarray: + """Convert a PIL Image into a numpy array. + + The image is converted to RGB if needed before generating the array. + + :param image: the input image + :return: the generated numpy array of shape (width, height, 3) + """ + if image.mode != "RGB": + image = image.convert("RGB") + + (im_width, im_height) = image.size + + return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8) + + +class ImageLoadingException(Exception): + """Exception raised by `get_image_from_url`` when image cannot be fetched + from URL or if loading failed. + """ + + pass + + +def get_image_from_url( + image_url: str, + error_raise: bool = True, + session: Optional[requests.Session] = None, +) -> Optional[Image.Image]: + """Fetch an image from `image_url` and load it. + + :param image_url: URL of the image to load + :param error_raise: if True, raises a `ImageLoadingException` if an error + occured, defaults to False. If False, None is returned if an error occurs. + :param session: requests Session to use, by default no session is used. + :raises ImageLoadingException: _description_ + :return: the Pillow Image or None. + """ + auth = ( + settings._off_net_auth + if urlparse(image_url).netloc.endswith("openfoodfacts.net") + else None + ) + try: + if session: + r = session.get(image_url, auth=auth) + else: + r = requests.get(image_url, auth=auth) + except (RequestConnectionError, SSLError, Timeout) as e: + error_message = "Cannot download image %s" + if error_raise: + raise ImageLoadingException(error_message % image_url) from e + logger.info(error_message, image_url, exc_info=e) + return None + + if not r.ok: + error_message = "Cannot download image %s: HTTP %s" + error_args = (image_url, r.status_code) + if error_raise: + raise ImageLoadingException(error_message % error_args) + logger.log( + logging.INFO if r.status_code < 500 else logging.WARNING, + error_message, + *error_args, + ) + return None + + try: + return Image.open(BytesIO(r.content)) + except PIL.UnidentifiedImageError: + error_message = f"Cannot identify image {image_url}" + if error_raise: + raise ImageLoadingException(error_message) + logger.info(error_message) + except PIL.Image.DecompressionBombError: + error_message = f"Decompression bomb error for image {image_url}" + if error_raise: + raise ImageLoadingException(error_message) + logger.info(error_message) + + return None diff --git a/robotoff/utils/logger.py b/robotoff/utils/logger.py new file mode 100644 index 0000000000..8dcea0e1d3 --- /dev/null +++ b/robotoff/utils/logger.py @@ -0,0 +1,42 @@ +import logging +import os +import sys +from typing import Optional + + +def get_logger(name=None, level: Optional[int] = None): + logger = logging.getLogger(name) + + if level is None: + log_level = os.environ.get("LOG_LEVEL", "INFO").upper() + level = logging.getLevelName(log_level) + + if not isinstance(level, int): + print( + "Unknown log level: {}, fallback to INFO".format(log_level), + file=sys.stderr, + ) + level = 20 + + logger.setLevel(level) + + if name is None: + configure_root_logger(logger, level) + + return logger + + +def configure_root_logger(logger, level: int = 20): + logger.setLevel(level) + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s :: %(processName)s :: " + "%(threadName)s :: %(levelname)s :: " + "%(message)s" + ) + handler.setFormatter(formatter) + handler.setLevel(level) + logger.addHandler(handler) + + for name in ("redis_lock", "spacy"): + logging.getLogger(name).setLevel(logging.WARNING) diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index 034146bb0a..73217ead14 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -3,7 +3,6 @@ from typing import Optional import elasticsearch -import numpy from elasticsearch.helpers import BulkIndexError from PIL import Image @@ -45,6 +44,7 @@ ServerType, ) from robotoff.utils import get_image_from_url, get_logger, http_session +from robotoff.utils.image import convert_image_to_array from robotoff.workers.queues import enqueue_job, get_high_queue logger = get_logger(__name__) @@ -267,7 +267,9 @@ def run_upc_detection(product_id: ProductIdentifier, image_url: str) -> None: logger.info("Error while downloading image %s", image_url) return - area, prediction_class, polygon = find_image_is_upc(numpy.array(image)) + area, prediction_class, polygon = find_image_is_upc( + convert_image_to_array(image) + ) ImagePrediction.create( image=image_model, type="upc_image", diff --git a/tests/unit/utils/test_module.py b/tests/unit/utils/test_module.py index 0411baaa69..0305756df8 100644 --- a/tests/unit/utils/test_module.py +++ b/tests/unit/utils/test_module.py @@ -80,7 +80,8 @@ def test_get_image_from_url_decompression_bomb(mocker): session_mock = mocker.Mock() response_mock = mocker.Mock() mocker.patch( - "robotoff.utils.Image", **{"open.side_effect": Image.DecompressionBombError()} + "robotoff.utils.image.Image", + **{"open.side_effect": Image.DecompressionBombError()} ) response_mock.content = generate_image() response_mock.ok = True