From 1e57f01caf9db4d6bf1341e44124f36abb0f8afe Mon Sep 17 00:00:00 2001
From: Santiago Somoza <45318759+santi1234567@users.noreply.github.com>
Date: Thu, 1 Aug 2024 22:37:52 -0300
Subject: [PATCH] Support pickle model import on API server (#34)

* rename knn_classifier to classifier

* remove deprecated file

* support changing classifier type

* add classifier type flag

* update readme

* linting

* Add method to import persisted model and allow user to set env variable

* pickle is not needed anymore in the server file

* add workaround for using pickle with gunicorn

* linting

---------

Co-authored-by: Tarun
---
 api_server.py | 32 ++++++++++++++++++++++++++++----
 classifier.py | 23 +++++++++++++++++++++++
 2 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/api_server.py b/api_server.py
index 2d9ba21..1bf06d1 100644
--- a/api_server.py
+++ b/api_server.py
@@ -1,7 +1,6 @@
 import os
 import json
 import falcon
-
 from multi_classifier import MultiClassifier
 from build_db import (
     open_block_db,
@@ -17,12 +16,24 @@
     count_false_positives,
     count_false_negatives,
 )
+import __main__
+from classifier import (
+    Classifier,
+    import_classifier,
+)
+
+# Workaround for unpickling under gunicorn: pickle resolves classes by module
+# path, and a model persisted while classifier.py ran as a script presumably
+# records `__main__.Classifier`, so expose the class under that name here.
+__main__.Classifier = Classifier
+
 
 DATA_DIR = "./data/mainnet/training"
 BLOCK_DB = os.environ.get("BLOCK_DB") or "./block_db.sqlite"
 BN_URL = "http://localhost:5052"
 SELF_URL = "http://localhost:8000"
 DISABLE_CLASSIFIER = "DISABLE_CLASSIFIER" in os.environ
+MODEL_PATH = os.environ.get("MODEL_PATH") or ""
 
 
 class Classify:
@@ -202,9 +213,22 @@ def on_get(self, req, resp, client, start_slot, end_slot=None):
 
 classifier = None
 if not DISABLE_CLASSIFIER:
-    print("Initialising classifier, this could take a moment...")
-    classifier = MultiClassifier(DATA_DIR) if not DISABLE_CLASSIFIER else None
-    print("Done")
+    if MODEL_PATH:
+        # Only pickled models are supported; reject anything else up front.
+        if not MODEL_PATH.endswith(".pkl"):
+            raise SystemExit("model path must end with .pkl")
+        classifier = import_classifier(MODEL_PATH)
+    else:
+        print("Initialising classifier, this could take a moment...")
+        classifier = MultiClassifier(DATA_DIR)
+        print("Done")
+
+    # import_classifier returns None when loading fails. Abort with a
+    # non-zero status in that case, but only when a classifier was actually
+    # requested: with DISABLE_CLASSIFIER set, `classifier` stays None on
+    # purpose and the server must still start.
+    if classifier is None:
+        raise SystemExit("The classifier was not loaded")
 
 block_db = open_block_db(BLOCK_DB)
 
diff --git a/classifier.py b/classifier.py
index 9463b0f..072d2a4 100755
--- a/classifier.py
+++ b/classifier.py
@@ -290,6 +290,29 @@ def persist_classifier(classifier: Classifier, name: str) -> None:
         print(f"Failed to persist classifier due to {e}")
 
 
+def import_classifier(model_path: str) -> Classifier:
+    """Load a pickled classifier from `model_path`.
+
+    Best-effort: any failure is printed and None is returned instead of the
+    classifier, so callers must check the result.
+
+    SECURITY: unpickling can execute arbitrary code. Only load model files
+    from a trusted source.
+    """
+    print(f"Loading classifier from {model_path}")
+
+    try:
+        # Context manager so the file handle is closed even if loading fails.
+        with open(model_path, "rb") as model_file:
+            classifier = pickle.load(model_file)
+    except Exception as e:
+        print(f"Failed to import classifier due to {e}")
+        return None
+
+    print("Loaded classifier into memory")
+    return classifier
+
+
 def main():
     args = parse_args()
     data_dir = args.data_dir