-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support pickle model import on API server #34
Changes from all commits
1dff16c
58ee232
268a3d0
52a3545
9c2c8a0
3eac126
b5f3e2d
e494bc8
632cf87
4f991bb
bd4b890
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -1,7 +1,6 @@ | ||||||||
import os | ||||||||
import json | ||||||||
import falcon | ||||||||
|
||||||||
from multi_classifier import MultiClassifier | ||||||||
from build_db import ( | ||||||||
open_block_db, | ||||||||
|
@@ -17,12 +16,21 @@ | |||||||
count_false_positives, | ||||||||
count_false_negatives, | ||||||||
) | ||||||||
import __main__ | ||||||||
from classifier import ( | ||||||||
Classifier, | ||||||||
import_classifier, | ||||||||
) | ||||||||
|
||||||||
__main__.Classifier = Classifier | ||||||||
|
||||||||
|
||||||||
DATA_DIR = "./data/mainnet/training" | ||||||||
BLOCK_DB = os.environ.get("BLOCK_DB") or "./block_db.sqlite" | ||||||||
BN_URL = "http://localhost:5052" | ||||||||
SELF_URL = "http://localhost:8000" | ||||||||
DISABLE_CLASSIFIER = "DISABLE_CLASSIFIER" in os.environ | ||||||||
MODEL_PATH = os.environ.get("MODEL_PATH") or "" | ||||||||
|
||||||||
|
||||||||
class Classify: | ||||||||
|
@@ -202,9 +210,22 @@ def on_get(self, req, resp, client, start_slot, end_slot=None): | |||||||
|
||||||||
classifier = None | ||||||||
if not DISABLE_CLASSIFIER: | ||||||||
print("Initialising classifier, this could take a moment...") | ||||||||
classifier = MultiClassifier(DATA_DIR) if not DISABLE_CLASSIFIER else None | ||||||||
print("Done") | ||||||||
if MODEL_PATH != "": | ||||||||
if MODEL_PATH.endswith(".pkl"): | ||||||||
classifier = import_classifier(MODEL_PATH) | ||||||||
|
||||||||
else: | ||||||||
print("model path must end with .pkl") | ||||||||
exit(0) | ||||||||
|
||||||||
else: | ||||||||
print("Initialising classifier, this could take a moment...") | ||||||||
classifier = MultiClassifier(DATA_DIR) if not DISABLE_CLASSIFIER else None | ||||||||
print("Done") | ||||||||
|
||||||||
if classifier is None: | ||||||||
print("The classifier was not loaded") | ||||||||
exit(0) | ||||||||
Comment on lines
+226
to
+228
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm going to delete this because it conflicts with the intention of If we want to simplify this in future and always require a classifier, we should remove the
Suggested change
|
||||||||
|
||||||||
block_db = open_block_db(BLOCK_DB) | ||||||||
|
||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should use a non-zero exit code for errors