style: ✨ Black on spellcheck script
jeremyarancio committed Sep 4, 2024
1 parent be475bd commit 762722f
Showing 1 changed file with 75 additions and 37 deletions.
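
This commit is a pure formatting pass with Black, Python's opinionated code formatter: long calls are split across lines once they exceed the line-length limit, trailing commas are added, single quotes become double quotes, and trailing whitespace is stripped. A minimal sketch of the same transformation through Black's Python API, assuming the `black` package is installed (the exact Black version and invocation used for this commit are not recorded):

import black

# One of the pre-commit lines; at roughly 100 characters it exceeds Black's
# default 88-character limit, so Black explodes the call across lines,
# matching the style seen throughout the diff below.
src = (
    'parser.add_argument("--data_bucket", type=str, '
    'default="robotoff-spellcheck", help="Bucket name.")\n'
)
print(black.format_str(src, mode=black.Mode()))

The command-line equivalent would be `black batch/spellcheck/main.py`.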
batch/spellcheck/main.py
@@ -22,18 +22,53 @@
 
 
 def parse() -> argparse.Namespace:
-    """Parse command line arguments.
-    """
+    """Parse command line arguments."""
     parser = argparse.ArgumentParser(description="Spellcheck module.")
-    parser.add_argument("--data_bucket", type=str, default="robotoff-spellcheck", help="Bucket name.")
-    parser.add_argument("--pre_data_suffix", type=str, default="data/preprocessed_data.parquet", help="Dataset suffix containing the data to be processed.")
-    parser.add_argument("--post_data_suffix", type=str, default="data/postprocessed_data.parquet", help="Dataset suffix containing the processed data.")
-    parser.add_argument("--model_path", default="openfoodfacts/spellcheck-mistral-7b", type=str, help="HF model path.")
-    parser.add_argument("--max_model_len", default=1024, type=int, help="Maximum model context length. A lower max context length reduces the memory footprint and accelerates inference.")
-    parser.add_argument("--temperature", default=0, type=float, help="Sampling temperature.")
-    parser.add_argument("--max_tokens", default=1024, type=int, help="Maximum number of tokens to generate.")
-    parser.add_argument("--quantization", default="fp8", type=str, help="Quantization type.")
-    parser.add_argument("--dtype", default="auto", type=str, help="Model weights precision. Default corresponds to the model config (float16 here).")
+    parser.add_argument(
+        "--data_bucket", type=str, default="robotoff-spellcheck", help="Bucket name."
+    )
+    parser.add_argument(
+        "--pre_data_suffix",
+        type=str,
+        default="data/preprocessed_data.parquet",
+        help="Dataset suffix containing the data to be processed.",
+    )
+    parser.add_argument(
+        "--post_data_suffix",
+        type=str,
+        default="data/postprocessed_data.parquet",
+        help="Dataset suffix containing the processed data.",
+    )
+    parser.add_argument(
+        "--model_path",
+        default="openfoodfacts/spellcheck-mistral-7b",
+        type=str,
+        help="HF model path.",
+    )
+    parser.add_argument(
+        "--max_model_len",
+        default=1024,
+        type=int,
+        help="Maximum model context length. A lower max context length reduces the memory footprint and accelerates inference.",
+    )
+    parser.add_argument(
+        "--temperature", default=0, type=float, help="Sampling temperature."
+    )
+    parser.add_argument(
+        "--max_tokens",
+        default=1024,
+        type=int,
+        help="Maximum number of tokens to generate.",
+    )
+    parser.add_argument(
+        "--quantization", default="fp8", type=str, help="Quantization type."
+    )
+    parser.add_argument(
+        "--dtype",
+        default="auto",
+        type=str,
+        help="Model weights precision. Default corresponds to the model config (float16 here).",
+    )
     return parser.parse_args()


@@ -43,7 +78,7 @@ def main():
     Original lists of ingredients are stored in a GCS bucket before being loaded, then processed by the model.
     The corrected lists of ingredients are then stored back in GCS.
-    We use vLLM to process the batch optimally. The model is loaded from the Open Food Facts Hugging Face model repository.
+    We use vLLM to process the batch optimally. The model is loaded from the Open Food Facts Hugging Face model repository.
     """
     logger.info("Starting batch processing job.")
     args = parse()
@@ -52,32 +87,35 @@ def main():
     data = load_gcs(bucket_name=args.data_bucket, suffix=args.pre_data_suffix)
     logger.info(f"Features in uploaded data: {data.columns}")
     if not all(feature in data.columns for feature in FEATURES_VALIDATION):
-        raise ValueError(f"Data should contain the following features: {FEATURES_VALIDATION}. Current features: {data.columns}")
+        raise ValueError(
+            f"Data should contain the following features: {FEATURES_VALIDATION}. Current features: {data.columns}"
+        )
 
     instructions = [prepare_instruction(text) for text in data["text"]]
     llm = LLM(
-        model=args.model_path,
-        max_model_len=args.max_model_len,
+        model=args.model_path,
+        max_model_len=args.max_model_len,
         dtype=args.dtype,
         quantization=args.quantization,
     )
     sampling_params = SamplingParams(
-        temperature=args.temperature,
-        max_tokens=args.max_tokens
+        temperature=args.temperature, max_tokens=args.max_tokens
     )
 
-    logger.info(f"Starting batch inference:\n {llm}.\n\nSampling parameters: {sampling_params}")
-    data["correction"] = batch_inference(instructions, llm=llm, sampling_params=sampling_params)
+    logger.info(
+        f"Starting batch inference:\n {llm}.\n\nSampling parameters: {sampling_params}"
+    )
+    data["correction"] = batch_inference(
+        instructions, llm=llm, sampling_params=sampling_params
+    )
 
     logger.info(f"Uploading data to GCS: {args.data_bucket}/{args.post_data_suffix}")
     # Save the DataFrame as Parquet to a temporary file
-    with tempfile.NamedTemporaryFile(delete=True, suffix='.parquet') as temp_file:
+    with tempfile.NamedTemporaryFile(delete=True, suffix=".parquet") as temp_file:
         data.to_parquet(temp_file.name)
         temp_file_name = temp_file.name
         upload_gcs(
-            temp_file_name,
-            bucket_name=args.data_bucket,
-            suffix=args.post_data_suffix
+            temp_file_name, bucket_name=args.data_bucket, suffix=args.post_data_suffix
         )
 
     logger.info("Request Robotoff API batch import endpoint.")
@@ -96,18 +134,14 @@ def prepare_instruction(text: str) -> str:
         str: Instruction.
     """
     instruction = (
-        "###Correct the list of ingredients:\n"
-        + text
-        + "\n\n###Correction:\n"
+        "###Correct the list of ingredients:\n" + text + "\n\n###Correction:\n"
     )
     return instruction
 
 
 def batch_inference(
-    texts: List[str],
-    llm: LLM,
-    sampling_params: SamplingParams
-) -> List[str]:
+    texts: List[str], llm: LLM, sampling_params: SamplingParams
+) -> List[str]:
     """Process batch of texts with vLLM.
 
     Args:
@@ -118,7 +152,10 @@ def batch_inference(
 
     Returns:
         List[str]: Processed batch of texts
     """
-    outputs = llm.generate(texts, sampling_params,)
+    outputs = llm.generate(
+        texts,
+        sampling_params,
+    )
     corrections = [output.outputs[0].text for output in outputs]
     return corrections

@@ -127,7 +164,7 @@ def load_gcs(bucket_name: str, suffix: str) -> pd.DataFrame:
     """Load data from Google Cloud Storage bucket.
 
     Args:
-        bucket_name (str):
+        bucket_name (str):
         suffix (str): Path inside the bucket
 
     Returns:
@@ -156,24 +193,25 @@ def upload_gcs(file_path: str, bucket_name: str, suffix: str) -> None:
 
 
 def run_robotoff_endpoint_batch_import():
-    """Run Robotoff API endpoint to import batch data into tables.
-    """
+    """Run Robotoff API endpoint to import batch data into tables."""
     url = "https://robotoff.openfoodfacts.org/api/v1/batch/import"
     data = {"job_type": "ingredients_spellcheck"}
     headers = {
         "Authorization": f"Bearer {os.getenv('BATCH_JOB_KEY')}",
-        "Content-Type": "application/json"
+        "Content-Type": "application/json",
     }
     try:
         response = requests.post(
             url,
             data=data,
             headers=headers,
         )
-        logger.info(f"Import batch Robotoff API endpoint successfully requested: {response.text}")
+        logger.info(
+            f"Import batch Robotoff API endpoint successfully requested: {response.text}"
+        )
     except requests.exceptions.RequestException as e:
         raise SystemExit(e)
 
 
 if __name__ == "__main__":
     main()
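
Taken together, the reformatted script loads ingredient lists from GCS, runs them through the spellcheck model with vLLM, and writes the corrections back. A condensed sketch of the core inference flow, assuming a GPU host with vLLM installed (model path and sampling defaults are the ones defined in parse() above; the sample text is made up for illustration):

from vllm import LLM, SamplingParams

# Stand-in for the rows normally loaded from the GCS parquet file.
texts = ["list of ingredint: water, sugr, salt"]
# Same prompt template as prepare_instruction().
instructions = [
    "###Correct the list of ingredients:\n" + text + "\n\n###Correction:\n"
    for text in texts
]
llm = LLM(
    model="openfoodfacts/spellcheck-mistral-7b",
    max_model_len=1024,
    dtype="auto",
    quantization="fp8",
)
sampling_params = SamplingParams(temperature=0, max_tokens=1024)
outputs = llm.generate(instructions, sampling_params)
corrections = [output.outputs[0].text for output in outputs]

The full job then saves the corrections as Parquet, uploads them to the bucket, and requests the Robotoff batch import endpoint, as main() shows.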
