Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pickle model import on API server #34

Merged
merged 11 commits into from
Aug 2, 2024
29 changes: 25 additions & 4 deletions api_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import json
import falcon

from multi_classifier import MultiClassifier
from build_db import (
open_block_db,
Expand All @@ -17,12 +16,21 @@
count_false_positives,
count_false_negatives,
)
import __main__
from classifier import (
Classifier,
import_classifier,
)

__main__.Classifier = Classifier


DATA_DIR = "./data/mainnet/training"
BLOCK_DB = os.environ.get("BLOCK_DB") or "./block_db.sqlite"
BN_URL = "http://localhost:5052"
SELF_URL = "http://localhost:8000"
DISABLE_CLASSIFIER = "DISABLE_CLASSIFIER" in os.environ
MODEL_PATH = os.environ.get("MODEL_PATH") or ""


class Classify:
Expand Down Expand Up @@ -202,9 +210,22 @@ def on_get(self, req, resp, client, start_slot, end_slot=None):

classifier = None
if not DISABLE_CLASSIFIER:
print("Initialising classifier, this could take a moment...")
classifier = MultiClassifier(DATA_DIR) if not DISABLE_CLASSIFIER else None
print("Done")
if MODEL_PATH != "":
if MODEL_PATH.endswith(".pkl"):
classifier = import_classifier(MODEL_PATH)

else:
print("model path must end with .pkl")
exit(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use a non-zero exit code for errors

Suggested change
exit(0)
exit(1)


else:
print("Initialising classifier, this could take a moment...")
classifier = MultiClassifier(DATA_DIR) if not DISABLE_CLASSIFIER else None
print("Done")

if classifier is None:
print("The classifier was not loaded")
exit(0)
Comment on lines +226 to +228
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to delete this because it conflicts with the intention of DISABLE_CLASSIFIER. Sometimes it is useful to start the API server without a classifier, e.g. if you just want to serve info from the sqlite DB.

If we want to simplify this in future and always require a classifier, we should remove the DISABLE_CLASSIFIER option

Suggested change
if classifier is None:
print("The classifier was not loaded")
exit(0)


block_db = open_block_db(BLOCK_DB)

Expand Down
12 changes: 12 additions & 0 deletions classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,18 @@ def persist_classifier(classifier: Classifier, name: str) -> None:
print(f"Failed to persist classifier due to {e}")


def import_classifier(model_path: str) -> Classifier:
print(f"""Loading classifier from {model_path}""")

try:
classifier = pickle.load(open(model_path, "rb"))
print("Loaded classifier into memory")
return classifier

except Exception as e:
print(f"Failed to import classifier due to {e}")


def main():
args = parse_args()
data_dir = args.data_dir
Expand Down
Loading