From 762722f9fb3168abeea9173748d041fbafc6ac21 Mon Sep 17 00:00:00 2001
From: jeremyarancio
Date: Wed, 4 Sep 2024 12:44:19 +0200
Subject: [PATCH] style: :sparkles: Black on spellcheck script

---
 batch/spellcheck/main.py | 112 ++++++++++++++++++++++++++-------------
 1 file changed, 75 insertions(+), 37 deletions(-)

diff --git a/batch/spellcheck/main.py b/batch/spellcheck/main.py
index 34c7b98c1d..0ae9adeff9 100644
--- a/batch/spellcheck/main.py
+++ b/batch/spellcheck/main.py
@@ -22,18 +22,53 @@
 
 
 def parse() -> argparse.Namespace:
-    """Parse command line arguments.
-    """
+    """Parse command line arguments."""
     parser = argparse.ArgumentParser(description="Spellcheck module.")
-    parser.add_argument("--data_bucket", type=str, default="robotoff-spellcheck", help="Bucket name.")
-    parser.add_argument("--pre_data_suffix", type=str, default="data/preprocessed_data.parquet", help="Dataset suffix containing the data to be processed.")
-    parser.add_argument("--post_data_suffix", type=str, default="data/postprocessed_data.parquet", help="Dataset suffix containing the processed data.")
-    parser.add_argument("--model_path", default="openfoodfacts/spellcheck-mistral-7b", type=str, help="HF model path.")
-    parser.add_argument("--max_model_len", default=1024, type=int, help="Maximum model context length. A lower max context length reduces the memory footprint and accelerates inference.")
-    parser.add_argument("--temperature", default=0, type=float, help="Sampling temperature.")
-    parser.add_argument("--max_tokens", default=1024, type=int, help="Maximum number of tokens to generate.")
-    parser.add_argument("--quantization", default="fp8", type=str, help="Quantization type.")
-    parser.add_argument("--dtype", default="auto", type=str, help="Model weights precision. Default corresponds to the model config (float16 here).")
+    parser.add_argument(
+        "--data_bucket", type=str, default="robotoff-spellcheck", help="Bucket name."
+    )
+    parser.add_argument(
+        "--pre_data_suffix",
+        type=str,
+        default="data/preprocessed_data.parquet",
+        help="Dataset suffix containing the data to be processed.",
+    )
+    parser.add_argument(
+        "--post_data_suffix",
+        type=str,
+        default="data/postprocessed_data.parquet",
+        help="Dataset suffix containing the processed data.",
+    )
+    parser.add_argument(
+        "--model_path",
+        default="openfoodfacts/spellcheck-mistral-7b",
+        type=str,
+        help="HF model path.",
+    )
+    parser.add_argument(
+        "--max_model_len",
+        default=1024,
+        type=int,
+        help="Maximum model context length. A lower max context length reduces the memory footprint and accelerates inference.",
+    )
+    parser.add_argument(
+        "--temperature", default=0, type=float, help="Sampling temperature."
+    )
+    parser.add_argument(
+        "--max_tokens",
+        default=1024,
+        type=int,
+        help="Maximum number of tokens to generate.",
+    )
+    parser.add_argument(
+        "--quantization", default="fp8", type=str, help="Quantization type."
+    )
+    parser.add_argument(
+        "--dtype",
+        default="auto",
+        type=str,
+        help="Model weights precision. Default corresponds to the model config (float16 here).",
+    )
     return parser.parse_args()
 
 
@@ -43,7 +78,7 @@ def main():
     Original lists of ingredients are stored in a gs bucket before being loaded then processed by the model.
     The corrected lists of ingredients are then stored back in gs.
 
-    We use vLLM to process the batch optimally. The model is loaded from the Open Food Facts Hugging Face model repository. 
+    We use vLLM to process the batch optimally. The model is loaded from the Open Food Facts Hugging Face model repository.
     """
     logger.info("Starting batch processing job.")
     args = parse()
@@ -52,32 +87,35 @@ def main():
     data = load_gcs(bucket_name=args.data_bucket, suffix=args.pre_data_suffix)
     logger.info(f"Feature in uploaded data: {data.columns}")
     if not all(feature in data.columns for feature in FEATURES_VALIDATION):
-        raise ValueError(f"Data should contain the following features: {FEATURES_VALIDATION}. Current features: {data.columns}")
+        raise ValueError(
+            f"Data should contain the following features: {FEATURES_VALIDATION}. Current features: {data.columns}"
+        )
 
     instructions = [prepare_instruction(text) for text in data["text"]]
     llm = LLM(
-        model=args.model_path, 
-        max_model_len=args.max_model_len, 
+        model=args.model_path,
+        max_model_len=args.max_model_len,
         dtype=args.dtype,
         quantization=args.quantization,
     )
     sampling_params = SamplingParams(
-        temperature=args.temperature, 
-        max_tokens=args.max_tokens
+        temperature=args.temperature, max_tokens=args.max_tokens
     )
 
-    logger.info(f"Starting batch inference:\n {llm}.\n\nSampling parameters: {sampling_params}")
-    data["correction"] = batch_inference(instructions, llm=llm, sampling_params=sampling_params)
+    logger.info(
+        f"Starting batch inference:\n {llm}.\n\nSampling parameters: {sampling_params}"
+    )
+    data["correction"] = batch_inference(
+        instructions, llm=llm, sampling_params=sampling_params
+    )
 
     logger.info(f"Uploading data to GCS: {args.data_bucket}/{args.post_data_suffix}")
     # Save DataFrame as Parquet to a temporary file
-    with tempfile.NamedTemporaryFile(delete=True, suffix='.parquet') as temp_file:
+    with tempfile.NamedTemporaryFile(delete=True, suffix=".parquet") as temp_file:
         data.to_parquet(temp_file.name)
         temp_file_name = temp_file.name
         upload_gcs(
-            temp_file_name, 
-            bucket_name=args.data_bucket, 
-            suffix=args.post_data_suffix
+            temp_file_name, bucket_name=args.data_bucket, suffix=args.post_data_suffix
         )
 
     logger.info("Request Robotoff API batch import endpoint.")
@@ -96,18 +134,14 @@ def prepare_instruction(text: str) -> str:
         str: Instruction.
     """
     instruction = (
-        "###Correct the list of ingredients:\n"
-        + text
-        + "\n\n###Correction:\n"
+        "###Correct the list of ingredients:\n" + text + "\n\n###Correction:\n"
     )
     return instruction
 
 
 def batch_inference(
-        texts: List[str], 
-        llm: LLM, 
-        sampling_params: SamplingParams
-    ) -> List[str]:
+    texts: List[str], llm: LLM, sampling_params: SamplingParams
+) -> List[str]:
     """Process batch of texts with vLLM.
 
     Args:
@@ -118,7 +152,10 @@ def batch_inference(
     Returns:
         List[str]: Processed batch of texts
     """
-    outputs = llm.generate(texts, sampling_params,)
+    outputs = llm.generate(
+        texts,
+        sampling_params,
+    )
     corrections = [output.outputs[0].text for output in outputs]
     return corrections
 
@@ -127,7 +164,7 @@ def load_gcs(bucket_name: str, suffix: str) -> pd.DataFrame:
     """Load data from Google Cloud Storage bucket.
 
     Args:
-        bucket_name (str): 
+        bucket_name (str):
         suffix (str): Path inside the bucket
 
     Returns:
@@ -156,13 +193,12 @@ def upload_gcs(file_path: str, bucket_name: str, suffix: str) -> None:
 
 
 def run_robotoff_endpoint_batch_import():
-    """Run Robotoff API endpoint to import batch data into tables. 
-    """
+    """Run Robotoff API endpoint to import batch data into tables."""
     url = "https://robotoff.openfoodfacts.org/api/v1/batch/import"
     data = {"job_type": "ingredients_spellcheck"}
     headers = {
         "Authorization": f"Bearer {os.getenv('BATCH_JOB_KEY')}",
-        "Content-Type": "application/json"
+        "Content-Type": "application/json",
    }
     try:
         response = requests.post(
@@ -170,10 +206,12 @@ def run_robotoff_endpoint_batch_import():
             data=data,
             headers=headers,
         )
-        logger.info(f"Import batch Robotoff API endpoint successfully requested: {response.text}")
+        logger.info(
+            f"Import batch Robotoff API endpoint successfully requested: {response.text}"
+        )
     except requests.exceptions.RequestException as e:
         raise SystemExit(e)
-    
+
 
 if __name__ == "__main__":
     main()