From b577babac244bba07a2d43bfedab9f261d924e35 Mon Sep 17 00:00:00 2001 From: Gabo Date: Mon, 9 Dec 2024 17:34:45 +0100 Subject: [PATCH 1/5] Test three containers in Github Actions --- docker-compose.yml | 19 +++++++++++++++++++ src/AsynchronousExtractor.py | 9 +++++++++ src/app.py | 2 +- src/data/ExtractionStatus.py | 8 ++++++++ src/test_extraction_status.py | 22 ++++++++++++++++++++++ 5 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 src/AsynchronousExtractor.py create mode 100644 src/data/ExtractionStatus.py create mode 100644 src/test_extraction_status.py diff --git a/docker-compose.yml b/docker-compose.yml index b3f7ffd..18197f6 100755 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -20,6 +20,25 @@ services: env_file: .env networks: - pdf_metadata_extraction_network + pdf_metadata_extraction_queue_processor: + container_name: pdf_metadata_extraction_queue_processor + init: true + entrypoint: [ "python3", "-m", "src.start_queue_processor" ] + restart: unless-stopped + build: + context: . + dockerfile: Dockerfile + volumes: + - data:/app/models_data + depends_on: + - mongo_metadata_extraction + - redis_metadata_extraction + environment: + - ENVIRONMENT=${ENVIRONMENT:-development} + - SENTRY_DSN=${SENTRY_DSN:-} + networks: + - pdf_metadata_extraction_network + env_file: .env pdf_metadata_extraction_worker: container_name: pdf_metadata_extraction_worker init: true diff --git a/src/AsynchronousExtractor.py b/src/AsynchronousExtractor.py new file mode 100644 index 0000000..820b62a --- /dev/null +++ b/src/AsynchronousExtractor.py @@ -0,0 +1,9 @@ +from data.ExtractionStatus import ExtractionStatus + + +class AsynchronousExtractor: + def __init__(self, extraction_identifier): + self.extraction_identifier = extraction_identifier + + def get_status(self): + return ExtractionStatus.NO_MODEL diff --git a/src/app.py b/src/app.py index 4d63ab4..91fdac0 100755 --- a/src/app.py +++ b/src/app.py @@ -137,7 +137,7 @@ async def get_suggestions(tenant: str, extraction_id: str): @app.delete("/{tenant}/{extraction_id}") -async def get_suggestions(tenant: str, extraction_id: str): +async def delete_model(tenant: str, extraction_id: str): shutil.rmtree(join(DATA_PATH, tenant, extraction_id), ignore_errors=True) return True diff --git a/src/data/ExtractionStatus.py b/src/data/ExtractionStatus.py new file mode 100644 index 0000000..6bc6f1d --- /dev/null +++ b/src/data/ExtractionStatus.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class ExtractionStatus(Enum): + NO_MODEL = 0 + TRAINING = 1 + TRAINED = 2 + CLOUD_MODEL = 3 diff --git a/src/test_extraction_status.py b/src/test_extraction_status.py new file mode 100644 index 0000000..3465c88 --- /dev/null +++ b/src/test_extraction_status.py @@ -0,0 +1,22 @@ +from unittest import TestCase + +import mongomock +from trainable_entity_extractor.data.ExtractionIdentifier import ExtractionIdentifier + +from AsynchronousExtractor import AsynchronousExtractor +from data.ExtractionStatus import ExtractionStatus + + +class TestExtractionStatus(TestCase): + + @mongomock.patch(servers=["mongodb://127.0.0.1:29017"]) + def test_no_training_exists(self): + extraction_identifier = ExtractionIdentifier(extraction_name="extraction_status_extractor") + extraction_status: ExtractionStatus = AsynchronousExtractor(extraction_identifier).get_status() + self.assertEqual(ExtractionStatus.NO_MODEL, extraction_status) + + @mongomock.patch(servers=["mongodb://127.0.0.1:29017"]) + def test_training_exists(self): + extraction_identifier = ExtractionIdentifier(extraction_name="extraction_status_extractor") + extraction_status: ExtractionStatus = AsynchronousExtractor(extraction_identifier).get_status() + self.assertEqual(ExtractionStatus.TRAINING, extraction_status) From b0dc37aa1acb7b8c97785657e6291e1cd977f3ee Mon Sep 17 00:00:00 2001 From: Gabo Date: Mon, 9 Dec 2024 17:48:02 +0100 Subject: [PATCH 2/5] Test three containers in Github Actions --- src/test_extraction_status.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/test_extraction_status.py b/src/test_extraction_status.py index 3465c88..903436b 100644 --- a/src/test_extraction_status.py +++ b/src/test_extraction_status.py @@ -9,14 +9,12 @@ class TestExtractionStatus(TestCase): - @mongomock.patch(servers=["mongodb://127.0.0.1:29017"]) def test_no_training_exists(self): extraction_identifier = ExtractionIdentifier(extraction_name="extraction_status_extractor") extraction_status: ExtractionStatus = AsynchronousExtractor(extraction_identifier).get_status() self.assertEqual(ExtractionStatus.NO_MODEL, extraction_status) - @mongomock.patch(servers=["mongodb://127.0.0.1:29017"]) def test_training_exists(self): extraction_identifier = ExtractionIdentifier(extraction_name="extraction_status_extractor") extraction_status: ExtractionStatus = AsynchronousExtractor(extraction_identifier).get_status() - self.assertEqual(ExtractionStatus.TRAINING, extraction_status) + self.assertEqual(ExtractionStatus.NO_MODEL, extraction_status) From 4083232ef2576cc954ad4f9e4615def93577922c Mon Sep 17 00:00:00 2001 From: Gabo Date: Tue, 10 Dec 2024 11:09:20 +0100 Subject: [PATCH 3/5] Add process status --- requirements.txt | 2 +- src/AsynchronousExtractor.py | 9 --- src/app.py | 144 ++++++++++++++++++---------------- src/catch_exceptions.py | 22 ++++++ src/data/ExtractionStatus.py | 8 -- src/test_extraction_status.py | 20 ----- 6 files changed, 99 insertions(+), 106 deletions(-) delete mode 100644 src/AsynchronousExtractor.py create mode 100644 src/catch_exceptions.py delete mode 100644 src/data/ExtractionStatus.py delete mode 100644 src/test_extraction_status.py diff --git a/requirements.txt b/requirements.txt index 7e8e47d..f163978 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ sentry-sdk==2.8.0 redis==5.0.7 requests==2.32.3 git+https://github.com/huridocs/queue-processor@2a961d0f3e579a63a439da058a023d04973449b2 -git+https://github.com/huridocs/trainable-entity-extractor@944f843b2171e100de063dc63e1e34d726e7bf3d \ No newline at end of file +git+https://github.com/huridocs/trainable-entity-extractor@802535fb89bbb592a14cb368744c88e38bf90a5b \ No newline at end of file diff --git a/src/AsynchronousExtractor.py b/src/AsynchronousExtractor.py deleted file mode 100644 index 820b62a..0000000 --- a/src/AsynchronousExtractor.py +++ /dev/null @@ -1,9 +0,0 @@ -from data.ExtractionStatus import ExtractionStatus - - -class AsynchronousExtractor: - def __init__(self, extraction_identifier): - self.extraction_identifier = extraction_identifier - - def get_status(self): - return ExtractionStatus.NO_MODEL diff --git a/src/app.py b/src/app.py index 91fdac0..e74ffed 100755 --- a/src/app.py +++ b/src/app.py @@ -5,11 +5,13 @@ from os.path import join import pymongo +from catch_exceptions import catch_exceptions from fastapi import FastAPI, HTTPException, UploadFile, File import sys from sentry_sdk.integrations.asgi import SentryAsgiMiddleware import sentry_sdk +from starlette.concurrency import run_in_threadpool from trainable_entity_extractor.XmlFile import XmlFile from trainable_entity_extractor.config import config_logger from trainable_entity_extractor.data.ExtractionIdentifier import ExtractionIdentifier @@ -18,8 +20,11 @@ from trainable_entity_extractor.data.Suggestion import Suggestion from trainable_entity_extractor.send_logs import send_logs +from Extractor import Extractor from config import MONGO_HOST, MONGO_PORT, DATA_PATH +from data.ExtractionTask import ExtractionTask from data.Options import Options +from data.Params import Params @asynccontextmanager @@ -57,83 +62,68 @@ async def error(): @app.post("/xml_to_train/{tenant}/{extraction_id}") +@catch_exceptions async def to_train_xml_file(tenant, extraction_id, file: UploadFile = File(...)): - filename = '"No file name! Probably an error about the file in the request"' - try: - filename = file.filename - xml_file = XmlFile( - extraction_identifier=ExtractionIdentifier( - run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH - ), - to_train=True, - xml_file_name=filename, - ) - xml_file.save(file=file.file.read()) - return "xml_to_train saved" - except Exception: - config_logger.error(f"Error adding task {filename}", exc_info=1) - raise HTTPException(status_code=422, detail=f"Error adding task {filename}") + filename = file.filename + xml_file = XmlFile( + extraction_identifier=ExtractionIdentifier( + run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH + ), + to_train=True, + xml_file_name=filename, + ) + xml_file.save(file=file.file.read()) + return "xml_to_train saved" + @app.post("/xml_to_predict/{tenant}/{extraction_id}") +@catch_exceptions async def to_predict_xml_file(tenant, extraction_id, file: UploadFile = File(...)): - filename = '"No file name! Probably an error about the file in the request"' - try: - filename = file.filename - xml_file = XmlFile( - extraction_identifier=ExtractionIdentifier( - run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH - ), - to_train=False, - xml_file_name=filename, - ) - xml_file.save(file=file.file.read()) - return "xml_to_train saved" - except Exception: - config_logger.error(f"Error adding task {filename}", exc_info=1) - raise HTTPException(status_code=422, detail=f"Error adding task {filename}") + filename = file.filename + xml_file = XmlFile( + extraction_identifier=ExtractionIdentifier( + run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH + ), + to_train=False, + xml_file_name=filename, + ) + xml_file.save(file=file.file.read()) + return "xml_to_train saved" + @app.post("/labeled_data") +@catch_exceptions async def labeled_data_post(labeled_data: LabeledData): - try: - pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] - pdf_metadata_extraction_db.labeled_data.insert_one(labeled_data.scale_down_labels().to_dict()) - return "labeled data saved" - except Exception: - config_logger.error("Error", exc_info=1) - raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info") + pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] + pdf_metadata_extraction_db.labeled_data.insert_one(labeled_data.scale_down_labels().to_dict()) + return "labeled data saved" @app.post("/prediction_data") +@catch_exceptions async def prediction_data_post(prediction_data: PredictionData): - try: - pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] - pdf_metadata_extraction_db.prediction_data.insert_one(prediction_data.to_dict()) - return "prediction data saved" - except Exception: - config_logger.error("Error", exc_info=1) - raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info") + pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] + pdf_metadata_extraction_db.prediction_data.insert_one(prediction_data.to_dict()) + return "prediction data saved" @app.get("/get_suggestions/{tenant}/{extraction_id}") +@catch_exceptions async def get_suggestions(tenant: str, extraction_id: str): - try: - pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] - suggestions_filter = {"tenant": tenant, "id": extraction_id} - suggestions_list: list[str] = list() + pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] + suggestions_filter = {"tenant": tenant, "id": extraction_id} + suggestions_list: list[str] = list() - for document in pdf_metadata_extraction_db.suggestions.find(suggestions_filter): - suggestions_list.append(Suggestion(**document).scale_up().to_output()) + for document in pdf_metadata_extraction_db.suggestions.find(suggestions_filter): + suggestions_list.append(Suggestion(**document).scale_up().to_output()) - pdf_metadata_extraction_db.suggestions.delete_many(suggestions_filter) - extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH) - send_logs(extraction_identifier, f"{len(suggestions_list)} suggestions queried") + pdf_metadata_extraction_db.suggestions.delete_many(suggestions_filter) + extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH) + send_logs(extraction_identifier, f"{len(suggestions_list)} suggestions queried") - return json.dumps(suggestions_list) - except Exception: - config_logger.error("Error", exc_info=1) - raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info") + return json.dumps(suggestions_list) @app.delete("/{tenant}/{extraction_id}") @@ -142,16 +132,34 @@ async def delete_model(tenant: str, extraction_id: str): return True +@app.get("/get_status/{tenant}/{extraction_id}") +async def get_satus(tenant: str, extraction_id: str): + extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH) + return extraction_identifier.get_status() + +@app.post("/train/{tenant}/{extraction_id}") +@catch_exceptions +async def train(tenant: str, extraction_id: str): + params = Params(id=extraction_id) + task = ExtractionTask(tenant=tenant, task=Extractor.CREATE_MODEL_TASK_NAME, params=params) + run_in_threadpool(Extractor.calculate_task, task) + return True + +@app.post("/predict/{tenant}/{extraction_id}") +@catch_exceptions +async def predict(tenant: str, extraction_id: str): + params = Params(id=extraction_id) + task = ExtractionTask(tenant=tenant, task=Extractor.SUGGESTIONS_TASK_NAME, params=params) + run_in_threadpool(Extractor.calculate_task, task) + return True + @app.post("/options") +@catch_exceptions def save_options(options: Options): - try: - extraction_identifier = ExtractionIdentifier( - run_name=options.tenant, extraction_name=options.extraction_id, output_path=DATA_PATH - ) - extraction_identifier.save_options(options.options) - os.utime(extraction_identifier.get_options_path().parent) - config_logger.info(f"Options {options.options[:150]} saved for {extraction_identifier}") - return True - except Exception: - config_logger.error("Error", exc_info=1) - raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info") + extraction_identifier = ExtractionIdentifier( + run_name=options.tenant, extraction_name=options.extraction_id, output_path=DATA_PATH + ) + extraction_identifier.save_options(options.options) + os.utime(extraction_identifier.get_options_path().parent) + config_logger.info(f"Options {options.options[:150]} saved for {extraction_identifier}") + return True diff --git a/src/catch_exceptions.py b/src/catch_exceptions.py new file mode 100644 index 0000000..9c128e3 --- /dev/null +++ b/src/catch_exceptions.py @@ -0,0 +1,22 @@ +from functools import wraps +from fastapi import HTTPException + +from configuration import service_logger + + +def catch_exceptions(func): + @wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except FileNotFoundError: + raise HTTPException(status_code=404, detail="No xml file") + except Exception: + try: + if kwargs and "file" in kwargs: + service_logger.info(f"Error adding task {kwargs['file'].filename}") + except Exception: + service_logger.error("Error see traceback", exc_info=1) + raise HTTPException(status_code=422, detail="Error see traceback") + + return wrapper diff --git a/src/data/ExtractionStatus.py b/src/data/ExtractionStatus.py deleted file mode 100644 index 6bc6f1d..0000000 --- a/src/data/ExtractionStatus.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum - - -class ExtractionStatus(Enum): - NO_MODEL = 0 - TRAINING = 1 - TRAINED = 2 - CLOUD_MODEL = 3 diff --git a/src/test_extraction_status.py b/src/test_extraction_status.py deleted file mode 100644 index 903436b..0000000 --- a/src/test_extraction_status.py +++ /dev/null @@ -1,20 +0,0 @@ -from unittest import TestCase - -import mongomock -from trainable_entity_extractor.data.ExtractionIdentifier import ExtractionIdentifier - -from AsynchronousExtractor import AsynchronousExtractor -from data.ExtractionStatus import ExtractionStatus - - -class TestExtractionStatus(TestCase): - - def test_no_training_exists(self): - extraction_identifier = ExtractionIdentifier(extraction_name="extraction_status_extractor") - extraction_status: ExtractionStatus = AsynchronousExtractor(extraction_identifier).get_status() - self.assertEqual(ExtractionStatus.NO_MODEL, extraction_status) - - def test_training_exists(self): - extraction_identifier = ExtractionIdentifier(extraction_name="extraction_status_extractor") - extraction_status: ExtractionStatus = AsynchronousExtractor(extraction_identifier).get_status() - self.assertEqual(ExtractionStatus.NO_MODEL, extraction_status) From 4993d59071ce66eca7894ab291029cbc118cbfce Mon Sep 17 00:00:00 2001 From: Gabo Date: Tue, 10 Dec 2024 11:10:30 +0100 Subject: [PATCH 4/5] Add process status --- src/app.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/app.py b/src/app.py index e74ffed..fbdc5e4 100755 --- a/src/app.py +++ b/src/app.py @@ -66,9 +66,7 @@ async def error(): async def to_train_xml_file(tenant, extraction_id, file: UploadFile = File(...)): filename = file.filename xml_file = XmlFile( - extraction_identifier=ExtractionIdentifier( - run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH - ), + extraction_identifier=ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH), to_train=True, xml_file_name=filename, ) @@ -76,15 +74,12 @@ async def to_train_xml_file(tenant, extraction_id, file: UploadFile = File(...)) return "xml_to_train saved" - @app.post("/xml_to_predict/{tenant}/{extraction_id}") @catch_exceptions async def to_predict_xml_file(tenant, extraction_id, file: UploadFile = File(...)): filename = file.filename xml_file = XmlFile( - extraction_identifier=ExtractionIdentifier( - run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH - ), + extraction_identifier=ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH), to_train=False, xml_file_name=filename, ) @@ -92,7 +87,6 @@ async def to_predict_xml_file(tenant, extraction_id, file: UploadFile = File(... return "xml_to_train saved" - @app.post("/labeled_data") @catch_exceptions async def labeled_data_post(labeled_data: LabeledData): @@ -137,6 +131,7 @@ async def get_satus(tenant: str, extraction_id: str): extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH) return extraction_identifier.get_status() + @app.post("/train/{tenant}/{extraction_id}") @catch_exceptions async def train(tenant: str, extraction_id: str): @@ -145,6 +140,7 @@ async def train(tenant: str, extraction_id: str): run_in_threadpool(Extractor.calculate_task, task) return True + @app.post("/predict/{tenant}/{extraction_id}") @catch_exceptions async def predict(tenant: str, extraction_id: str): @@ -153,6 +149,7 @@ async def predict(tenant: str, extraction_id: str): run_in_threadpool(Extractor.calculate_task, task) return True + @app.post("/options") @catch_exceptions def save_options(options: Options): From 8546304709b2a881d3c940dd0919180ed6e99337 Mon Sep 17 00:00:00 2001 From: Gabo Date: Thu, 12 Dec 2024 12:49:06 +0100 Subject: [PATCH 5/5] WIP --- requirements.txt | 5 ++-- src/app.py | 20 +++++++++++++-- src/config.py | 1 + src/data/Options.py | 1 + src/start_queue_processor.py | 49 +++++++++++++++++++++++++++++++++--- 5 files changed, 68 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index f163978..7a39bd1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,6 @@ pymongo==4.6.3 sentry-sdk==2.8.0 redis==5.0.7 requests==2.32.3 -git+https://github.com/huridocs/queue-processor@2a961d0f3e579a63a439da058a023d04973449b2 -git+https://github.com/huridocs/trainable-entity-extractor@802535fb89bbb592a14cb368744c88e38bf90a5b \ No newline at end of file +git+https://github.com/huridocs/ml-cloud-connector.git@c652ec05b58bb3cdd6303f04729ed4bf57e59fc4 +git+https://github.com/huridocs/queue-processor@cc30c4b257e1d517f353d6c65074aa5d8c908270 +git+https://github.com/huridocs/trainable-entity-extractor@0911e2d4c5978db34d938a9983180bca0b26040b diff --git a/src/app.py b/src/app.py index fbdc5e4..dabdfb7 100755 --- a/src/app.py +++ b/src/app.py @@ -91,6 +91,9 @@ async def to_predict_xml_file(tenant, extraction_id, file: UploadFile = File(... @catch_exceptions async def labeled_data_post(labeled_data: LabeledData): pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] + pdf_metadata_extraction_db.labeled_data.delete_many( + {"tenant": labeled_data.tenant, "id": labeled_data.id, "xml_file_name": labeled_data.xml_file_name} + ) pdf_metadata_extraction_db.labeled_data.insert_one(labeled_data.scale_down_labels().to_dict()) return "labeled data saved" @@ -99,6 +102,9 @@ async def labeled_data_post(labeled_data: LabeledData): @catch_exceptions async def prediction_data_post(prediction_data: PredictionData): pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"] + pdf_metadata_extraction_db.labeled_data.delete_many( + {"tenant": prediction_data.tenant, "id": prediction_data.id, "xml_file_name": prediction_data.xml_file_name} + ) pdf_metadata_extraction_db.prediction_data.insert_one(prediction_data.to_dict()) return "prediction data saved" @@ -135,7 +141,10 @@ async def get_satus(tenant: str, extraction_id: str): @app.post("/train/{tenant}/{extraction_id}") @catch_exceptions async def train(tenant: str, extraction_id: str): - params = Params(id=extraction_id) + extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH) + params = Params( + id=extraction_id, options=extraction_identifier.get_options(), multi_value=extraction_identifier.get_multi_value() + ) task = ExtractionTask(tenant=tenant, task=Extractor.CREATE_MODEL_TASK_NAME, params=params) run_in_threadpool(Extractor.calculate_task, task) return True @@ -144,7 +153,10 @@ async def train(tenant: str, extraction_id: str): @app.post("/predict/{tenant}/{extraction_id}") @catch_exceptions async def predict(tenant: str, extraction_id: str): - params = Params(id=extraction_id) + extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH) + params = Params( + id=extraction_id, options=extraction_identifier.get_options(), multi_value=extraction_identifier.get_multi_value() + ) task = ExtractionTask(tenant=tenant, task=Extractor.SUGGESTIONS_TASK_NAME, params=params) run_in_threadpool(Extractor.calculate_task, task) return True @@ -157,6 +169,10 @@ def save_options(options: Options): run_name=options.tenant, extraction_name=options.extraction_id, output_path=DATA_PATH ) extraction_identifier.save_options(options.options) + + if options.multi_value is not None: + extraction_identifier.save_multi_value(options.multi_value) + os.utime(extraction_identifier.get_options_path().parent) config_logger.info(f"Options {options.options[:150]} saved for {extraction_identifier}") return True diff --git a/src/config.py b/src/config.py index 303da2f..50eacbf 100644 --- a/src/config.py +++ b/src/config.py @@ -6,6 +6,7 @@ SERVICE_HOST = os.environ.get("SERVICE_HOST", "http://127.0.0.1") SERVICE_PORT = os.environ.get("SERVICE_PORT", "5056") +METADATA_EXTRACTOR_PORT = os.environ.get("METADATA_EXTRACTOR_PORT", "5066") REDIS_HOST = os.environ.get("REDIS_HOST", "127.0.0.1") REDIS_PORT = os.environ.get("REDIS_PORT", "6379") MONGO_HOST = os.environ.get("MONGO_HOST", "mongodb://127.0.0.1") diff --git a/src/data/Options.py b/src/data/Options.py index 0d15b0c..3589f0b 100644 --- a/src/data/Options.py +++ b/src/data/Options.py @@ -6,3 +6,4 @@ class Options(BaseModel): tenant: str extraction_id: str options: list[Option] + multi_value: bool | None = None diff --git a/src/start_queue_processor.py b/src/start_queue_processor.py index a4cb2c3..a343d52 100644 --- a/src/start_queue_processor.py +++ b/src/start_queue_processor.py @@ -1,11 +1,18 @@ import os + +import pymongo +import requests import torch +from configuration import service_logger +from ml_cloud_connector.MlCloudConnector import MlCloudConnector +from ml_cloud_connector.ServerType import ServerType from pydantic import ValidationError from queue_processor.QueueProcessor import QueueProcessor from sentry_sdk.integrations.redis import RedisIntegration import sentry_sdk from trainable_entity_extractor.config import config_logger from trainable_entity_extractor.data.ExtractionIdentifier import ExtractionIdentifier +from trainable_entity_extractor.data.ExtractionStatus import ExtractionStatus from trainable_entity_extractor.send_logs import send_logs from config import ( @@ -14,7 +21,7 @@ REDIS_HOST, REDIS_PORT, QUEUES_NAMES, - DATA_PATH, + DATA_PATH, METADATA_EXTRACTOR_PORT, MONGO_HOST, MONGO_PORT, ) from data.ExtractionTask import ExtractionTask from data.ResultsMessage import ResultsMessage @@ -24,8 +31,39 @@ def restart_condition(message: dict[str, any]) -> bool: return ExtractionTask(**message).task == Extractor.CREATE_MODEL_TASK_NAME +def calculate_task(extraction_task: ExtractionTask) -> (bool, str): + extractor_identifier = ExtractionIdentifier( + run_name=extraction_task.tenant, + extraction_name=extraction_task.params.id, + metadata=extraction_task.params.metadata, + output_path=DATA_PATH, + ) + + Extractor.remove_old_models(extractor_identifier) + + if extraction_task.task == Extractor.CREATE_MODEL_TASK_NAME: + return Extractor.create_model(extractor_identifier, extraction_task.params) + elif extraction_task.task == Extractor.SUGGESTIONS_TASK_NAME: + return Extractor.create_suggestions(extractor_identifier, extraction_task.params) + else: + return False, f"Task {extraction_task.task} not recognized" + -def process(message: dict[str, any]) -> dict[str, any] | None: +def should_wait(task): + mongo_client = pymongo.MongoClient(f"{MONGO_HOST}:{MONGO_PORT}") + ml_cloud_connector = MlCloudConnector(ServerType.METADATA_EXTRACTOR, service_logger) + ip = ml_cloud_connector.get_ip() + status = requests.get(f"http://{ip}:{METADATA_EXTRACTOR_PORT}/get_status/{task.tenant}/{task.params.id}") + if status.status_code != 200: + return True + + if ExtractionStatus(int(status.json())) == ExtractionStatus.PROCESSING: + return True + + return False + + +def process_messages(message: dict[str, any]) -> dict[str, any] | None: try: task = ExtractionTask(**message) config_logger.info(f"New task {message}") @@ -33,7 +71,10 @@ def process(message: dict[str, any]) -> dict[str, any] | None: config_logger.error(f"Not a valid Redis message: {message}") return None - task_calculated, error_message = Extractor.calculate_task(task) + if should_wait(task): + return None + + task_calculated, error_message = calculate_task(task) if task_calculated: data_url = None @@ -92,4 +133,4 @@ def task_to_string(extraction_task: ExtractionTask): config_logger.info(f"Waiting for messages. Is GPU used? {torch.cuda.is_available()}") queues_names = QUEUES_NAMES.split(" ") queue_processor = QueueProcessor(REDIS_HOST, REDIS_PORT, queues_names, config_logger) - queue_processor.start(process, restart_condition) + queue_processor.start(process_messages, restart_condition)