Skip to content

Commit

Permalink
Support pickle model import on API server (#34)
Browse files Browse the repository at this point in the history
* rename knn_classifier to classifier

* remove deprecated file

* support changing classifier type

* add classifier type flag

* update readme

* linting

* Add method to import persisted model and allow user to set env variable

* pickle is not needed anymore in the server file

* add workaround for using pickle with gunicorn

* linting

---------

Co-authored-by: Tarun <[email protected]>
  • Loading branch information
santi1234567 and tdahar authored Aug 2, 2024
1 parent 34b8cab commit 1e57f01
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
29 changes: 25 additions & 4 deletions api_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import json
import falcon

from multi_classifier import MultiClassifier
from build_db import (
open_block_db,
Expand All @@ -17,12 +16,21 @@
count_false_positives,
count_false_negatives,
)
import __main__
from classifier import (
Classifier,
import_classifier,
)

# Workaround for loading pickled models when the server runs under gunicorn:
# make Classifier reachable as an attribute of the `__main__` module as well.
# NOTE(review): presumably the persisted pickles reference
# `__main__.Classifier` because they were written while classifier.py ran as
# a script -- confirm against how the models are produced.
__main__.Classifier = Classifier


# Directory handed to MultiClassifier when no persisted model is configured.
DATA_DIR = "./data/mainnet/training"
# Path of the block database; the BLOCK_DB environment variable overrides the
# default, and an unset or empty value falls back to "./block_db.sqlite".
BLOCK_DB = os.environ.get("BLOCK_DB") or "./block_db.sqlite"
# NOTE(review): presumably the beacon node HTTP endpoint -- usage is not
# visible in this excerpt, confirm.
BN_URL = "http://localhost:5052"
# NOTE(review): presumably the address this API server is reachable at --
# usage is not visible in this excerpt, confirm.
SELF_URL = "http://localhost:8000"
# True whenever DISABLE_CLASSIFIER is present in the environment, regardless
# of its value (even an empty string or "0" disables the classifier).
DISABLE_CLASSIFIER = "DISABLE_CLASSIFIER" in os.environ
# Optional path to a persisted classifier; unset or empty becomes "".
MODEL_PATH = os.environ.get("MODEL_PATH") or ""


class Classify:
Expand Down Expand Up @@ -202,9 +210,22 @@ def on_get(self, req, resp, client, start_slot, end_slot=None):

# Build the classifier once at import time, unless DISABLE_CLASSIFIER is set,
# in which case `classifier` stays None.
# NOTE(review): the nesting below is reconstructed from the if/else pairing of
# the statements -- confirm against the real file's indentation.
classifier = None
if not DISABLE_CLASSIFIER:
    if MODEL_PATH:
        # A persisted model was requested; only pickle files are accepted.
        if not MODEL_PATH.endswith(".pkl"):
            print("model path must end with .pkl")
            # Exit with a non-zero status so supervisors see the failure, and
            # avoid the `exit` helper that only exists when `site` is loaded.
            raise SystemExit(1)
        classifier = import_classifier(MODEL_PATH)
    else:
        print("Initialising classifier, this could take a moment...")
        classifier = MultiClassifier(DATA_DIR)
        print("Done")

    # import_classifier signals a failed load by returning None instead of
    # raising, so the result has to be checked here.
    if classifier is None:
        print("The classifier was not loaded")
        raise SystemExit(1)

block_db = open_block_db(BLOCK_DB)

Expand Down
12 changes: 12 additions & 0 deletions classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,18 @@ def persist_classifier(classifier: Classifier, name: str) -> None:
print(f"Failed to persist classifier due to {e}")


def import_classifier(model_path: str) -> "Classifier | None":
    """Load a previously persisted classifier from a pickle file.

    Args:
        model_path: Path of the pickle file to read.

    Returns:
        The unpickled object, or ``None`` when the file cannot be opened or
        unpickled. The failure reason is printed rather than raised, so
        callers must check the result for ``None``.
    """
    print(f"""Loading classifier from {model_path}""")

    try:
        # SECURITY: unpickling can execute arbitrary code. Only point this at
        # model files that come from a trusted source.
        # The context manager closes the file even if unpickling fails.
        with open(model_path, "rb") as model_file:
            classifier = pickle.load(model_file)
    except Exception as e:
        # Deliberately best-effort: report the problem and let the caller
        # decide what to do with the missing classifier.
        print(f"Failed to import classifier due to {e}")
        return None

    print("Loaded classifier into memory")
    return classifier


def main():
args = parse_args()
data_dir = args.data_dir
Expand Down

0 comments on commit 1e57f01

Please sign in to comment.