
feat: Batch job - Spellcheck #1401

Closed

Changes from all commits · 26 commits
75ebf64
feat(Batch job - Spellcheck): :zap:
jeremyarancio Aug 21, 2024
a5830e2
Merge branch 'main' into batch-job-spellcheck
jeremyarancio Aug 21, 2024
eb15bab
fix(batch-spellcheck): :lipstick: Fix Spellcheck Batch job file name …
jeremyarancio Aug 21, 2024
d36648b
feat(batch-spellcheck): :zap: Batch extraction from database before B…
jeremyarancio Aug 22, 2024
c14338d
refactor(batch-spellcheck): :green_heart: Fix some bugs: batch-extrac…
jeremyarancio Aug 23, 2024
6c83b8c
feat(batch - spellcheck): :zap: From predictions to insights
jeremyarancio Aug 24, 2024
a369a59
feat(batch - spellcheck): :zap: API endpoint batch/launch ok: Batch e…
jeremyarancio Aug 26, 2024
729d4e1
feat(batch - spellcheck): :zap: Integrate batch data from job into Ro…
jeremyarancio Aug 27, 2024
34ce80e
feat: :sparkles: Restructure code
jeremyarancio Aug 27, 2024
f381ecb
Merge branch 'main' into batch-job-spellcheck
jeremyarancio Aug 27, 2024
92cb5f3
feat: :sparkles: Change batch job launch from api endpoint to CLI
jeremyarancio Aug 28, 2024
54f1734
feat: :lock: Secure Batch Data Import endpoint with a token key
jeremyarancio Aug 28, 2024
4aabf4b
feat: :art: Add key during request by the batch job
jeremyarancio Aug 28, 2024
01d884a
feat: :sparkles: Implemenation reviews
jeremyarancio Sep 2, 2024
fda7b5d
style: :sparkles: make lint
jeremyarancio Sep 2, 2024
f8ed76a
fix: :bug: Fixed bug & Better error handling with Falcon
jeremyarancio Sep 3, 2024
85b7bfb
feat: :ambulance: Changes
jeremyarancio Sep 3, 2024
31ce875
feat: :ambulance: Credential + Importer
jeremyarancio Sep 3, 2024
7c92836
feat: :ambulance: Credentials + Importer + Test
jeremyarancio Sep 4, 2024
cb49cd9
Merge branch 'main' into batch-job-spellcheck
jeremyarancio Sep 4, 2024
be475bd
feat: :bug: Forgot a return
jeremyarancio Sep 4, 2024
762722f
style: :sparkles: Black on spellcheck script
jeremyarancio Sep 4, 2024
10791e7
docs: :memo: Add batch/import api endpoint to doc
jeremyarancio Sep 4, 2024
400818b
docs: :memo: Because perfection
jeremyarancio Sep 4, 2024
4ebfd87
fix: :art: Change predictor version to also track... the predictor ve…
jeremyarancio Sep 4, 2024
b0b15b4
Merge branch 'main' into batch-job-spellcheck
raphael0202 Sep 10, 2024
21 changes: 20 additions & 1 deletion Makefile
@@ -18,6 +18,12 @@ DOCKER_COMPOSE=docker compose --env-file=${ENV_FILE}
DOCKER_COMPOSE_TEST=COMPOSE_PROJECT_NAME=robotoff_test COMMON_NET_NAME=po_test docker compose --env-file=${ENV_FILE}
ML_OBJECT_DETECTION_MODELS := tf-universal-logo-detector tf-nutrition-table tf-nutriscore

# Spellcheck
IMAGE_NAME = spellcheck-batch-vllm
TAG = latest
GCLOUD_LOCATION = europe-west9-docker.pkg.dev
REGISTRY = ${GCLOUD_LOCATION}/robotoff/gcf-artifacts

.DEFAULT_GOAL := dev
# avoid targets corresponding to file names, so we can depend on them
.PHONY: *
@@ -290,4 +296,17 @@ create-migration: guard-args

# create network if not exists
create-po-default-network:
	docker network create po_default || true

# Spellcheck
build-spellcheck:
	docker build -f batch/spellcheck/Dockerfile -t $(IMAGE_NAME):$(TAG) batch/spellcheck

# Push the image to the registry
push-spellcheck:
	docker tag $(IMAGE_NAME):$(TAG) $(REGISTRY)/$(IMAGE_NAME):$(TAG)
	docker push $(REGISTRY)/$(IMAGE_NAME):$(TAG)

# Build and push in one command
deploy-spellcheck: build-spellcheck push-spellcheck
15 changes: 15 additions & 0 deletions batch/spellcheck/Dockerfile
@@ -0,0 +1,15 @@
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=on

WORKDIR /app

COPY main.py /app
COPY requirements.txt /app

RUN pip install --no-cache-dir -r requirements.txt

# Set the entrypoint to the batch job script
ENTRYPOINT ["python", "main.py"]
41 changes: 41 additions & 0 deletions batch/spellcheck/README.md
@@ -0,0 +1,41 @@
# Google Batch job

## Notes

* The Netherlands (europe-west4) has GPUs (A100, L4)
* Check [Cloud Logging](https://console.cloud.google.com/logs/query;query=SEARCH%2528%22spellcheck%22%2529;cursorTimestamp=2024-08-14T11:21:32.485988660Z;duration=PT1H?referrer=search&project=robotoff) for logs
* Requires a deep learning image to run: [deep learning containers list](https://cloud.google.com/deep-learning-containers/docs/choosing-container#pytorch)
* Custom storage capacity is needed to host the heavy Docker image (~24 GB): add a BootDisk to the instance policy (see the sketch after this list)
* 1000 products processed in ~1:30 min on a g2-instance-8 (overall batch job: 3:25 min)
* L4: g2-instance-8 hourly cost: $0.896306 -> ~$0.05 to process a batch of 1000
* A100: a2-highgpu-1g: $3.748064/hour
* A100/CUDA doesn't support FP8
* A100s have lower availability than L4s: the batch job can sit queued for a long time
* Don't forget to enable the **Batch & Storage APIs** if used without gcloud
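
These notes map onto a job submission with the `google-cloud-batch` Python client roughly as follows. This is a minimal sketch, not the production job spec: the machine type, boot disk size, and run duration are assumptions, and the image URI is derived from the Makefile registry variables above.

```python
from google.cloud import batch_v1


def create_spellcheck_gpu_job(project_id: str, region: str, job_name: str) -> batch_v1.Job:
    """Submit the spellcheck container as a Google Batch job on an L4 machine."""
    client = batch_v1.BatchServiceClient()

    # Run the image pushed by `make deploy-spellcheck`.
    container = batch_v1.Runnable.Container(
        image_uri="europe-west9-docker.pkg.dev/robotoff/gcf-artifacts/spellcheck-batch-vllm:latest"
    )
    runnable = batch_v1.Runnable(container=container)

    task = batch_v1.TaskSpec(runnables=[runnable])
    task.max_run_duration = "3600s"
    group = batch_v1.TaskGroup(task_spec=task, task_count=1)

    # g2 machines ship with an L4 GPU; enlarge the boot disk so the
    # ~24 GB image fits (the 200 GB figure is illustrative).
    policy = batch_v1.AllocationPolicy.InstancePolicy(
        machine_type="g2-standard-8",
        boot_disk=batch_v1.AllocationPolicy.Disk(size_gb=200),
    )
    instances = batch_v1.AllocationPolicy.InstancePolicyOrTemplate(
        policy=policy,
        install_gpu_drivers=True,
    )

    job = batch_v1.Job(
        task_groups=[group],
        allocation_policy=batch_v1.AllocationPolicy(instances=[instances]),
        logs_policy=batch_v1.LogsPolicy(
            destination=batch_v1.LogsPolicy.Destination.CLOUD_LOGGING
        ),
    )

    request = batch_v1.CreateJobRequest(
        parent=f"projects/{project_id}/locations/{region}",
        job_id=job_name,
        job=job,
    )
    return client.create_job(request)
```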

## Links

* [GPU availability per region](https://cloud.google.com/compute/docs/gpus/gpu-regions-zones)
* [Batch job with GPU](https://cloud.google.com/batch/docs/create-run-job-gpus#create-job-gpu-examples)
* [VM Instance pricing](https://cloud.google.com/compute/vm-instance-pricing#vm-instance-pricing)
* [Trigger cloud function with bucket updates](https://cloud.google.com/functions/docs/calling/storage)
* [Python Google Batch](https://github.com/GoogleCloudPlatform/python-docs-samples/tree/main/batch)

## Commands

### List GPUs per region
```bash
gcloud compute accelerator-types list
```

### List deep learning images
```bash
gcloud compute images list \
--project deeplearning-platform-release \
--format="value(NAME)" \
--no-standard-images
```

## Workflow / Orchestration

* [Workflow](https://cloud.google.com/workflows/docs/overview)
217 changes: 217 additions & 0 deletions batch/spellcheck/main.py
@@ -0,0 +1,217 @@
import os
import argparse
import tempfile
import logging
import sys
import requests
from typing import List

import pandas as pd
from vllm import LLM, SamplingParams
from google.cloud import storage


logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

FEATURES_VALIDATION = ["code", "text"]


def parse() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Spellcheck module.")
    parser.add_argument(
        "--data_bucket", type=str, default="robotoff-spellcheck", help="Bucket name."
    )
    parser.add_argument(
        "--pre_data_suffix",
        type=str,
        default="data/preprocessed_data.parquet",
        help="Dataset suffix containing the data to be processed.",
    )
    parser.add_argument(
        "--post_data_suffix",
        type=str,
        default="data/postprocessed_data.parquet",
        help="Dataset suffix containing the processed data.",
    )
    parser.add_argument(
        "--model_path",
        default="openfoodfacts/spellcheck-mistral-7b",
        type=str,
        help="HF model path.",
    )
    parser.add_argument(
        "--max_model_len",
        default=1024,
        type=int,
        help="Maximum model context length. A lower max context length reduces the memory footprint and accelerates inference.",
    )
    parser.add_argument(
        "--temperature", default=0, type=float, help="Sampling temperature."
    )
    parser.add_argument(
        "--max_tokens",
        default=1024,
        type=int,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--quantization", default="fp8", type=str, help="Quantization type."
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        type=str,
        help="Model weights precision. Default corresponds to the model config (float16 here).",
    )
    return parser.parse_args()


def main():
    """Batch processing job.

    Original lists of ingredients are stored in a GCS bucket before being loaded, then processed by the model.
    The corrected lists of ingredients are then stored back in GCS.

    We use vLLM to process the batch optimally. The model is loaded from the Open Food Facts Hugging Face model repository.
    """
    logger.info("Starting batch processing job.")
    args = parse()

    logger.info(f"Loading data from GCS: {args.data_bucket}/{args.pre_data_suffix}")
    data = load_gcs(bucket_name=args.data_bucket, suffix=args.pre_data_suffix)
    logger.info(f"Features in uploaded data: {data.columns}")
    if not all(feature in data.columns for feature in FEATURES_VALIDATION):
        raise ValueError(
            f"Data should contain the following features: {FEATURES_VALIDATION}. Current features: {data.columns}"
        )

    instructions = [prepare_instruction(text) for text in data["text"]]
    llm = LLM(
        model=args.model_path,
        max_model_len=args.max_model_len,
        dtype=args.dtype,
        quantization=args.quantization,
    )
    sampling_params = SamplingParams(
        temperature=args.temperature, max_tokens=args.max_tokens
    )

    logger.info(
        f"Starting batch inference:\n {llm}.\n\nSampling parameters: {sampling_params}"
    )
    data["correction"] = batch_inference(
        instructions, llm=llm, sampling_params=sampling_params
    )

    logger.info(f"Uploading data to GCS: {args.data_bucket}/{args.post_data_suffix}")
    # Save the DataFrame as Parquet to a temporary file and upload it before
    # the file is deleted on exiting the context manager.
    with tempfile.NamedTemporaryFile(delete=True, suffix=".parquet") as temp_file:
        data.to_parquet(temp_file.name)
        upload_gcs(
            temp_file.name, bucket_name=args.data_bucket, suffix=args.post_data_suffix
        )

    logger.info("Requesting the Robotoff API batch import endpoint.")
    run_robotoff_endpoint_batch_import()

    logger.info("Batch processing job completed.")


def prepare_instruction(text: str) -> str:
    """Prepare the instruction prompt used for fine-tuning and inference.

    Args:
        text (str): List of ingredients.

    Returns:
        str: Instruction.
    """
    instruction = (
        "###Correct the list of ingredients:\n" + text + "\n\n###Correction:\n"
    )
    return instruction


def batch_inference(
    texts: List[str], llm: LLM, sampling_params: SamplingParams
) -> List[str]:
    """Process a batch of texts with vLLM.

    Args:
        texts (List[str]): Batch of input prompts.
        llm (LLM): Model engine optimized with vLLM.
        sampling_params (SamplingParams): Generation parameters.

    Returns:
        List[str]: Processed batch of texts.
    """
    outputs = llm.generate(
        texts,
        sampling_params,
    )
    corrections = [output.outputs[0].text for output in outputs]
    return corrections


def load_gcs(bucket_name: str, suffix: str) -> pd.DataFrame:
    """Load data from a Google Cloud Storage bucket.

    Args:
        bucket_name (str): Bucket name.
        suffix (str): Path inside the bucket.

    Returns:
        pd.DataFrame: DataFrame read from the parquet file.
    """
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(suffix)
    with blob.open("rb") as f:
        df = pd.read_parquet(f)
    return df


def upload_gcs(file_path: str, bucket_name: str, suffix: str) -> None:
    """Upload data to GCS.

    Args:
        file_path (str): File path to export.
        bucket_name (str): Bucket name.
        suffix (str): Path inside the bucket.
    """
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(suffix)
    blob.upload_from_filename(filename=file_path)


def run_robotoff_endpoint_batch_import():
    """Call the Robotoff API endpoint to import the batch data into the tables."""
    url = "https://robotoff.openfoodfacts.org/api/v1/batch/import"
    data = {"job_type": "ingredients_spellcheck"}
    headers = {
        "Authorization": f"Bearer {os.getenv('BATCH_JOB_KEY')}",
        "Content-Type": "application/json",
    }
    try:
        # Send the payload as JSON so the body matches the Content-Type header.
        response = requests.post(
            url,
            json=data,
            headers=headers,
        )
        response.raise_for_status()
        logger.info(
            f"Import batch Robotoff API endpoint successfully requested: {response.text}"
        )
    except requests.exceptions.RequestException as e:
        raise SystemExit(e)


if __name__ == "__main__":
    main()
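
The script expects an input parquet exposing at least the `code` and `text` columns (`FEATURES_VALIDATION`) and writes the same table back with an added `correction` column. A minimal local sketch of that contract, reusing the default bucket and suffixes from `parse()` (the sample row is made up):

```python
import pandas as pd

# Minimal input file matching FEATURES_VALIDATION ("code", "text").
df = pd.DataFrame(
    {
        "code": ["0000000000000"],  # dummy barcode
        "text": ["Farine de blé, suc re, huile de palme"],  # deliberately misspelled
    }
)
df.to_parquet("preprocessed_data.parquet")

# After uploading the file to gs://robotoff-spellcheck/data/preprocessed_data.parquet,
# the container entrypoint runs with the matching arguments:
#   python main.py --data_bucket robotoff-spellcheck \
#       --pre_data_suffix data/preprocessed_data.parquet \
#       --post_data_suffix data/postprocessed_data.parquet
```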
2 changes: 2 additions & 0 deletions batch/spellcheck/requirements.txt
@@ -0,0 +1,2 @@
vllm==0.5.4
google-cloud-storage==2.18.0
37 changes: 37 additions & 0 deletions doc/references/api.yml
@@ -1105,6 +1105,28 @@ paths:
"400":
description: "An HTTP 400 is returned if the provided parameters are invalid"

/batch/import:
post:
tags:
- Batch Job
summary: Import batch processed data to the Robotoff database.
security:
- batch_job_key: []
description:
Trigger import of the batch processed data to the Robotoff database. A `BATCH_JOB_KEY` is expected in the authorization header.
This endpoint is mainly used by the batch job once the job is finished.
parameters:
- $ref: "#/components/parameters/job_type"
responses:
"200":
description: Data successfully imported.
content:
application/json:
status:
type: string
description: Request successful. Importing processed data.
"400":
description: "An HTTP 400 is returned if the authentification key is invalid or if the job_type is not supported."

components:
schemas:
@@ -1391,6 +1413,21 @@ components:
      schema:
        type: integer
      example: 5410041040807
    job_type:
      name: job_type
      in: query
      required: true
      description: The type of batch job launched.
      schema:
        type: string
        enum:
          - ingredients_spellcheck

  securitySchemes:
    batch_job_key:
      type: http
      scheme: bearer

tags:
  - name: Questions
  - name: Insights
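
Client-side, the documented contract looks roughly like the sketch below: `job_type` is declared as a query parameter and the key travels as a bearer token (note that `batch/spellcheck/main.py` sends `job_type` in the request body instead). `BATCH_JOB_KEY` is assumed to be set in the environment:

```python
import os

import requests

response = requests.post(
    "https://robotoff.openfoodfacts.org/api/v1/batch/import",
    params={"job_type": "ingredients_spellcheck"},
    headers={"Authorization": f"Bearer {os.environ['BATCH_JOB_KEY']}"},
    timeout=30,
)
response.raise_for_status()  # HTTP 400 on a bad key or unsupported job_type
print(response.json())
```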
6 changes: 5 additions & 1 deletion docker-compose.yml
@@ -54,7 +54,11 @@ x-robotoff-base-env:
  IMAGE_MODERATION_SERVICE_URL:
  CROP_ALLOWED_DOMAINS:
  NUM_RQ_WORKERS: 4 # Update worker service command accordingly if you change this setting
  GOOGLE_APPLICATION_CREDENTIALS: /opt/robotoff/credentials/google/credentials.json
  GOOGLE_CLOUD_PROJECT: "robotoff"
  GOOGLE_CREDENTIALS: # JSON credentials pasted as environment variable
  BATCH_JOB_KEY: # Secure the batch job import endpoint with a token key

x-robotoff-worker-base:
  &robotoff-worker
  restart: $RESTART_POLICY
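
The compose file passes raw JSON credentials through `GOOGLE_CREDENTIALS`, while `GOOGLE_APPLICATION_CREDENTIALS` points at a file path. A hypothetical startup helper showing how the two can be wired together (not necessarily how Robotoff does it):

```python
import os
from pathlib import Path


def materialize_google_credentials() -> None:
    """Write the JSON held in GOOGLE_CREDENTIALS to the file that
    GOOGLE_APPLICATION_CREDENTIALS points to, so google-cloud clients
    can discover it at startup."""
    raw_json = os.getenv("GOOGLE_CREDENTIALS")
    target = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    if raw_json and target:
        path = Path(target)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(raw_json)
```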