style: ✨ make lint
jeremyarancio committed Sep 2, 2024
1 parent 01d884a commit fda7b5d
Showing 10 changed files with 80 additions and 75 deletions.
24 changes: 11 additions & 13 deletions robotoff/app/api.py
@@ -26,9 +26,9 @@
from robotoff import settings
from robotoff.app import schema
from robotoff.app.auth import (
APITokenError,
BasicAuthDecodeError,
APITokenError,
basic_decode,
basic_decode,
validate_token,
)
from robotoff.app.core import (
@@ -45,6 +45,7 @@
validate_params,
)
from robotoff.app.middleware import DBConnectionMiddleware
from robotoff.batch import BatchJobType, import_batch_predictions
from robotoff.elasticsearch import get_es_client
from robotoff.insights.extraction import (
DEFAULT_OCR_PREDICTION_TYPES,
@@ -91,10 +92,6 @@
from robotoff.utils.text import get_tag
from robotoff.workers.queues import enqueue_job, get_high_queue, low_queue
from robotoff.workers.tasks import download_product_dataset_job
from robotoff.batch import (
BatchJobType,
import_batch_predictions,
)

logger = get_logger()

@@ -311,7 +308,7 @@ def parse_valid_token(req: falcon.Request, ref_token_name: str) -> bool:
:param req: Request.
:type req: falcon.Request
:param ref_token_name: Secret environment variable name.
:param ref_token_name: Secret environment variable name.
:type ref_token_name: str
:return: Token valid or not.
"""
@@ -321,11 +318,13 @@ def parse_valid_token(req: falcon.Request, ref_token_name: str) -> bool:
scheme, token = auth_header.split()
except APITokenError:
raise falcon.HTTPUnauthorized("Invalid authentication scheme.")
if scheme.lower() != 'bearer':
raise falcon.HTTPUnauthorized("Invalid authentication scheme: 'Bearer Token' expected.")
if scheme.lower() != "bearer":
raise falcon.HTTPUnauthorized(
"Invalid authentication scheme: 'Bearer Token' expected."
)
is_token_valid = validate_token(token, ref_token_name)
if not is_token_valid:
raise falcon.HTTPUnauthorized('Invalid token.')
raise falcon.HTTPUnauthorized("Invalid token.")
else:
return True

@@ -1779,14 +1778,13 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
resp.media = response



class BatchJobImportResource:
def on_post(self, req: falcon.Request, resp: falcon.Response):
job_type_str: str = req.get_param("job_type", required=True)

try:
job_type = BatchJobType[job_type_str]
except KeyError:
except KeyError:
raise falcon.HTTPBadRequest(
description=f"invalid job_type: {job_type_str}. Valid job_types are: {[elt.value for elt in BatchJobType]}"
)
@@ -1804,7 +1802,7 @@ def on_post(self, req: falcon.Request, resp: falcon.Response):
)
logger.info("Batch import %s has been queued.", job_type)


class RobotsTxtResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
# Completely disallow indexation: otherwise web crawlers send millions
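
Note: the new BatchJobImportResource above expects a "Bearer Token" Authorization header and a job_type query parameter matching a BatchJobType name. A minimal client sketch, assuming the resource is routed at /api/v1/batch/import and the secret is exported as BATCH_JOB_KEY (both names are illustrative; the route registration is not shown in this diff):

import os
import requests

# Illustrative endpoint URL and env var name; adjust to the deployed route and secret.
url = "https://robotoff.openfoodfacts.org/api/v1/batch/import"
token = os.environ["BATCH_JOB_KEY"]

response = requests.post(
    url,
    params={"job_type": "ingredients_spellcheck"},  # BatchJobType *name*, not value
    headers={"Authorization": f"Bearer {token}"},
)
response.raise_for_status()
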
52 changes: 26 additions & 26 deletions robotoff/batch/__init__.py
@@ -1,21 +1,15 @@
import os
import tempfile

from robotoff.utils import get_logger
from robotoff.types import (
BatchJobType,
Prediction,
ServerType,
)
from robotoff.models import db
from robotoff.insights.importer import import_insights
from robotoff import settings
from robotoff.types import PredictionType
from robotoff.insights.importer import import_insights
from robotoff.models import db
from robotoff.types import BatchJobType, Prediction, PredictionType, ServerType
from robotoff.utils import get_logger

from .launch import launch_job, GoogleBatchJobConfig
from .buckets import fetch_dataframe_from_gcs, upload_file_to_gcs
from .extraction import extract_from_dataset
from .buckets import upload_file_to_gcs, fetch_dataframe_from_gcs

from .launch import GoogleBatchJobConfig, launch_job

logger = get_logger(__name__)

@@ -28,7 +22,7 @@ def launch_batch_job(job_type: BatchJobType) -> None:
launch_spellcheck_batch_job()
else:
raise NotImplementedError(f"Batch job type {job_type} not implemented.")


def import_batch_predictions(job_type: BatchJobType) -> None:
"""Import batch predictions once the job finished.
@@ -41,12 +35,13 @@ def import_batch_predictions(job_type: BatchJobType) -> None:


def launch_spellcheck_batch_job() -> None:
"""Launch spellcheck batch job.
"""
"""Launch spellcheck batch job."""
# Init
JOB_NAME = "ingredients-spellcheck"
QUERY_FILE_PATH = settings.BATCH_JOB_CONFIG_DIR / "sql/spellcheck.sql"
BATCH_JOB_CONFIG_PATH = settings.BATCH_JOB_CONFIG_DIR / "job_configs/spellcheck.yaml"
BATCH_JOB_CONFIG_PATH = (
settings.BATCH_JOB_CONFIG_DIR / "job_configs/spellcheck.yaml"
)
BUCKET_NAME = "robotoff-spellcheck"
SUFFIX_PREPROCESS = "data/preprocessed_data.parquet"

@@ -56,29 +51,35 @@ def launch_spellcheck_batch_job() -> None:
extract_from_dataset(QUERY_FILE_PATH, file_path)

# Upload the extracted file to the bucket
upload_file_to_gcs(file_path=file_path, bucket_name=BUCKET_NAME, suffix=SUFFIX_PREPROCESS)
upload_file_to_gcs(
file_path=file_path, bucket_name=BUCKET_NAME, suffix=SUFFIX_PREPROCESS
)
logger.debug(f"File uploaded to the bucket {BUCKET_NAME}/{SUFFIX_PREPROCESS}")

# Launch batch job
batch_job_config = GoogleBatchJobConfig.init(job_name=JOB_NAME, config_path=BATCH_JOB_CONFIG_PATH)
batch_job_config = GoogleBatchJobConfig.init(
job_name=JOB_NAME, config_path=BATCH_JOB_CONFIG_PATH
)
batch_job = launch_job(batch_job_config=batch_job_config)
logger.info(f"Batch job succesfully launched. Batch job name: {batch_job.name}.")


def import_spellcheck_batch_predictions() -> None:
"""Import spellcheck predictions from remote storage.
"""
"""Import spellcheck predictions from remote storage."""
# Init
BUCKET_NAME = "robotoff-spellcheck"
SUFFIX_POSTPROCESS = "data/postprocessed_data.parquet"
PREDICTION_TYPE = PredictionType.ingredient_spellcheck
PREDICTOR_VERSION = "1" #TODO: shard HF model version instead of manual change?
PREDICTOR_VERSION = "1" # TODO: shard HF model version instead of manual change?
PREDICTOR = "fine-tuned-mistral-7b"
SERVER_TYPE = ServerType.off

df = fetch_dataframe_from_gcs(bucket_name=BUCKET_NAME, suffix_postprocess=SUFFIX_POSTPROCESS)
logger.debug(f"Batch data downloaded from bucket {BUCKET_NAME}/{SUFFIX_POSTPROCESS}")

df = fetch_dataframe_from_gcs(
bucket_name=BUCKET_NAME, suffix_postprocess=SUFFIX_POSTPROCESS
)
logger.debug(
f"Batch data downloaded from bucket {BUCKET_NAME}/{SUFFIX_POSTPROCESS}"
)

# Generate predictions
predictions = []
Expand All @@ -97,7 +98,6 @@ def import_spellcheck_batch_predictions() -> None:
# Store predictions and insights
with db:
import_results = import_insights(
predictions=predictions,
server_type=SERVER_TYPE
predictions=predictions, server_type=SERVER_TYPE
)
logger.info("Batch import results: %s", import_results)
6 changes: 4 additions & 2 deletions robotoff/batch/buckets.py
@@ -32,8 +32,10 @@ def fetch_dataframe_from_gcs(bucket_name: str, suffix: str) -> pd.DataFrame:
bucket = client.get_bucket(bucket_name)
blob = bucket.blob(suffix)
with blob.open("rb") as f:
try:
try:
df = pd.read_parquet(f)
except Exception as e:
raise ValueError(f"Could not read parquet file from {bucket_name}/{suffix}. Error: {e}")
raise ValueError(
f"Could not read parquet file from {bucket_name}/{suffix}. Error: {e}"
)
return df
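
Note: fetch_dataframe_from_gcs streams the parquet blob through a file handle rather than downloading it to disk first. A hypothetical upload counterpart in the same style (upload_file_to_gcs itself takes a local file path; this dataframe variant is an assumption, not the repo's API):

import pandas as pd
from google.cloud import storage

def upload_dataframe_to_gcs(df: pd.DataFrame, bucket_name: str, suffix: str) -> None:
    # Write the dataframe as parquet directly into gs://<bucket_name>/<suffix>.
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(suffix)
    with blob.open("wb") as f:
        df.to_parquet(f)
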
9 changes: 2 additions & 7 deletions robotoff/batch/extraction.py
@@ -5,7 +5,6 @@
from robotoff import settings
from robotoff.utils import get_logger


logger = get_logger(__name__)


@@ -30,7 +29,6 @@ def extract_from_dataset(
logger.debug(f"Batch data succesfully extracted and saved at {output_file_path}")



def _load_query(query_file_path: Path, dataset_path: Path) -> str:
"""Load the SQL query from a corresponding file.
@@ -49,6 +47,7 @@ def _load_query(query_file_path: Path, dataset_path: Path) -> str:
logger.debug(f"Query used to extract batch from dataset: {query}")
return query


def _extract_and_save_batch_data(query: str, output_file_path: str) -> None:
"""Query and save the data.
@@ -57,8 +56,4 @@ def _extract_and_save_batch_data(query: str, output_file_path: str) -> None:
:param output_file_path: Path to save the extracted data.
:type output_file_path: str
"""
(
duckdb
.sql(query)
.write_parquet(output_file_path)
)
(duckdb.sql(query).write_parquet(output_file_path))
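
Note: _extract_and_save_batch_data is a thin wrapper around duckdb's relational API: the SQL result is written straight to parquet without materializing a dataframe in between. A standalone sketch of the same pattern (the query here is illustrative; the real one lives in sql/spellcheck.sql):

import duckdb

# Assumed local parquet file standing in for the Open Food Facts dataset.
query = "SELECT code, ingredients_text FROM read_parquet('products.parquet') LIMIT 100"
duckdb.sql(query).write_parquet("preprocessed_data.parquet")
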
19 changes: 12 additions & 7 deletions robotoff/batch/launch.py
@@ -1,17 +1,18 @@
from typing import List, Optional
import yaml
import datetime
import re
from pathlib import Path
from typing import List, Optional

import yaml
from google.cloud import batch_v1
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from robotoff import settings


class GoogleBatchJobConfig(BaseModel):
"""Batch job configuration class."""

# By default, extra fields are just ignored. We raise an error in case of extra fields.
model_config: ConfigDict = {"extra": "forbid"}

@@ -95,8 +96,10 @@ def init(cls, job_name: str, config_path: Path) -> "GoogleBatchJobConfig":
# Batch job name should respect a specific pattern, otherwise an error is raised
pattern = "^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$"
if not re.match(pattern, job_name):
raise ValueError(f"Job name should respect the pattern: {pattern}. Current job name: {job_name}")

raise ValueError(
f"Job name should respect the pattern: {pattern}. Current job name: {job_name}"
)

# Generate unique id for the job
unique_job_name = (
job_name + "-" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
@@ -113,7 +116,7 @@ def launch_job(batch_job_config: GoogleBatchJobConfig) -> batch_v1.Job:
Sources:
* https://github.com/GoogleCloudPlatform/python-docs-samples/tree/main/batch/create
* https://cloud.google.com/python/docs/reference/batch/latest/google.cloud.batch_v1.types
:param google_batch_launch_config: Config to run a job on Google Batch.
:type google_batch_launch_config: GoogleBatchLaunchConfig
:param batch_job_config: Config to run a specific job on Google Batch.
@@ -176,6 +179,8 @@ def launch_job(batch_job_config: GoogleBatchJobConfig) -> batch_v1.Job:
create_request.job = job
create_request.job_id = batch_job_config.job_name
# The job's parent is the region in which the job will run
create_request.parent = f"projects/{settings.GOOGLE_PROJECT_NAME}/locations/{batch_job_config.location}"
create_request.parent = (
f"projects/{settings.GOOGLE_PROJECT_NAME}/locations/{batch_job_config.location}"
)

return client.create_job(create_request)
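
Note: the pattern in GoogleBatchJobConfig.init encodes Google Batch's job-name rules (a lowercase letter first, then lowercase letters, digits, or hyphens, at most 63 characters) before a timestamp suffix is appended for uniqueness. The two steps in isolation:

import datetime
import re

PATTERN = "^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$"

def unique_job_name(job_name: str) -> str:
    # Reject names Google Batch would refuse, then add a timestamp for uniqueness.
    if not re.match(PATTERN, job_name):
        raise ValueError(f"Job name should respect the pattern: {PATTERN}")
    return job_name + "-" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")

print(unique_job_name("ingredients-spellcheck"))  # e.g. ingredients-spellcheck-20240902143000
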
12 changes: 8 additions & 4 deletions robotoff/cli/main.py
@@ -1000,16 +1000,20 @@ def create_migration(

@app.command()
def launch_batch_job(
job_type: str = typer.Argument(..., help="Type of job to launch. Ex: 'ingredients_spellcheck'"),
job_type: str = typer.Argument(
..., help="Type of job to launch. Ex: 'ingredients_spellcheck'"
),
) -> None:
"""Launch a batch job."""
from robotoff.batch import launch_batch_job as _launch_batch_job
from robotoff.utils import get_logger
from robotoff.types import BatchJobType
from robotoff.utils import get_logger

if job_type not in BatchJobType.__members__:
raise ValueError(f"Invalid job type: {job_type}. Must be one of those: {[job.name for job in BatchJobType]}")

raise ValueError(
f"Invalid job type: {job_type}. Must be one of those: {[job.name for job in BatchJobType]}"
)

get_logger()
job_type = BatchJobType[job_type]
_launch_batch_job(job_type)
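
Note: the CLI checks the argument against BatchJobType.__members__, i.e. enum names with underscores, while the enum values use hyphens (see robotoff/types.py below). The distinction in a nutshell:

from robotoff.types import BatchJobType

BatchJobType["ingredients_spellcheck"]  # lookup by name, as the CLI and API do
BatchJobType("ingredients-spellcheck")  # lookup by value
print("ingredients_spellcheck" in BatchJobType.__members__)  # True
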
9 changes: 4 additions & 5 deletions robotoff/insights/importer.py
@@ -1480,7 +1480,7 @@ class IngredientSpellcheckImporter(InsightImporter):
@staticmethod
def get_type() -> InsightType:
return InsightType.ingredient_spellcheck

@classmethod
def get_required_prediction_types(cls) -> set[PredictionType]:
return {PredictionType.ingredient_spellcheck}
@@ -1495,15 +1495,14 @@ def generate_candidates(
# Only one prediction
for candidate in predictions:
yield ProductInsight(**candidate.to_dict())

@classmethod
def is_conflicting_insight(
cls,
candidate: ProductInsight,
reference: ProductInsight
cls, candidate: ProductInsight, reference: ProductInsight
) -> bool:
return candidate.value_tag == reference.value_tag


class PackagingElementTaxonomyException(Exception):
pass

2 changes: 1 addition & 1 deletion robotoff/settings.py
@@ -359,4 +359,4 @@ def get_package_version() -> str:
CROP_ALLOWED_DOMAINS = os.environ.get("CROP_ALLOWED_DOMAINS", "").split(",")

# Batch jobs
GOOGLE_PROJECT_NAME= "robotoff"
GOOGLE_PROJECT_NAME = "robotoff"
7 changes: 4 additions & 3 deletions robotoff/types.py
@@ -359,8 +359,9 @@ class PackagingElementProperty(enum.Enum):

InsightAnnotation = Literal[-1, 0, 1, 2]


@enum.unique
class BatchJobType(enum.Enum):
"""Each job type correspond to a task that will be executed in the batch job.
"""
ingredients_spellcheck = "ingredients-spellcheck"
"""Each job type correspond to a task that will be executed in the batch job."""

ingredients_spellcheck = "ingredients-spellcheck"
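
Note: @enum.unique makes a duplicated job value an import-time error rather than a silent alias, which is what you want for a registry of batch jobs. A tiny illustration:

import enum

try:
    @enum.unique
    class Broken(enum.Enum):
        a = "dup"
        b = "dup"
except ValueError as e:
    print(e)  # duplicate values found in <enum 'Broken'>: b -> a
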