diff --git a/robotoff/insights/importer.py b/robotoff/insights/importer.py index 29bd98d9e2..da40f8659a 100644 --- a/robotoff/insights/importer.py +++ b/robotoff/insights/importer.py @@ -1002,26 +1002,26 @@ def is_in_barcode_range(barcode: str, tag: str) -> bool: @staticmethod def is_prediction_valid(item: Prediction | ProductInsight) -> bool: - """Return True if the Prediction or ProductInsight is valid: - - - we check for 'taxonomy' predictor whether the brand is excluded - - we check that the brand is compatible with the barcode - range + """Return True if the Prediction or ProductInsight is valid. + For 'taxonomy' and 'curated-list' predictors, we check that the brand + is not in the blacklist and that it is compatible with the barcode + range. :param item: a Prediction or a ProductInsight + :return: True if the item is valid """ - if item.predictor == "universal-logo-detector" and "username" in item.data: - # Don't perform barcode range check and for logos detected - # using universal-logo-detector model and annotated manually - return True + if item.predictor in ("taxonomy", "curated-list"): + brand_blacklist = get_brand_blacklist() + if item.value_tag in brand_blacklist: + return False - brand_blacklist = get_brand_blacklist() - if item.predictor == "taxonomy" and item.value_tag in brand_blacklist: - return False + return BrandInsightImporter.is_in_barcode_range( + item.barcode, item.value_tag # type: ignore + ) - return BrandInsightImporter.is_in_barcode_range( - item.barcode, item.value_tag # type: ignore - ) + # Don't perform the blacklist and barcode range checks for other predictors + # (universal-logo-detector, google-cloud-vision) + return True @classmethod def generate_candidates( diff --git a/tests/unit/insights/test_importer.py b/tests/unit/insights/test_importer.py index 1271d56c9c..cfb55f5c95 100644 --- a/tests/unit/insights/test_importer.py +++ b/tests/unit/insights/test_importer.py @@ -1168,6 +1168,70 @@ def test_is_conflicting_insight(self): 
ProductInsight(value_tag="tag1"), ProductInsight(value_tag="tag2") ) + def test_is_prediction_valid(self): + base_values = { + "type": PredictionType.brand, + "value_tag": "carrefour", + } + assert ( + BrandInsightImporter.is_prediction_valid( + Prediction( + barcode="3560070880973", # This is a Carrefour product + predictor="curated-list", + **base_values, + ) + ) + is True + ) + assert ( + BrandInsightImporter.is_prediction_valid( + Prediction( + barcode="3510070880973", # This is *not* a Carrefour product + predictor="curated-list", + **base_values, + ) + ) + is False + ) + # We don't check for barcode range if the predictor is not curated-list or + # taxonomy + assert ( + BrandInsightImporter.is_prediction_valid( + Prediction( + barcode="3510070880973", # This is *not* a Carrefour product + predictor="google-cloud-vision", + **base_values, + ) + ) + is True + ) + # We don't check the inclusion of the brand in the blacklist if the predictor + # is not curated-list or taxonomy + assert ( + BrandInsightImporter.is_prediction_valid( + Prediction( + barcode="3510070880973", + predictor="google-cloud-vision", + type=PredictionType.brand, + value_tag="asia", # This brand is in the blacklist + ) + ) + is True + ) + # We check the inclusion of the brand in the blacklist if the predictor is + # curated-list or taxonomy + assert ( + BrandInsightImporter.is_prediction_valid( + Prediction( + barcode="3510070880973", + predictor="taxonomy", + type=PredictionType.brand, + value_tag="asia", # This brand is in the blacklist + ) + ) + is False + ) + class TestStoreInsightImporter: def test_get_type(self):