-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding load and save posterior software to evaluate.py
- Loading branch information
1 parent
94c0f36
commit b4ac468
Showing
1 changed file
with
39 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,58 @@ | ||
""" | ||
Simple stub functions to use in inference | ||
Simple stub functions to use in evaluating inference from a previously trained inference model. | ||
""" | ||
|
||
import argparse | ||
import pickle | ||
|
||
|
||
def load_model(checkpoint_path): | ||
""" | ||
Load the entire model for prediction with an input | ||
class InferenceModel: | ||
def save_model_pkl(self, path, model_name, posterior): | ||
""" | ||
Save the pkl'ed saved posterior model | ||
:param checkpoint_path: location | ||
:return: loaded model object that can be used with the predict function | ||
""" | ||
pass | ||
:param path: Location to save the model | ||
:param model_name: Name of the model | ||
:param posterior: Model object to be saved | ||
""" | ||
file_name = path + model_name + ".pkl" | ||
with open(file_name, "wb") as file: | ||
pickle.dump(posterior, file) | ||
|
||
def load_model_pkl(self, path, model_name): | ||
""" | ||
Load the pkl'ed saved posterior model | ||
def predict(input, model): | ||
""" | ||
:param path: Location to load the model from | ||
:param model_name: Name of the model | ||
:return: Loaded model object that can be used with the predict function | ||
""" | ||
with open(path + model_name + ".pkl", 'rb') as file: | ||
posterior = pickle.load(file) | ||
return posterior | ||
|
||
:param input: loaded object used for inference | ||
:param model: loaded model | ||
:return: Prediction | ||
""" | ||
return 0 | ||
|
||
def load_inference_object(input_path): | ||
""" | ||
def predict(input, model): | ||
""" | ||
:param input_path: path to the object you want to predict | ||
:return: loaded object | ||
""" | ||
return 0 | ||
:param input: loaded object used for inference | ||
:param model: loaded model | ||
:return: Prediction | ||
""" | ||
return 0 | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--checkpoint", type=str, help="Checkpoint to unloaded model checkpoint, either weights or the compressed model object") | ||
parser.add_argument("--input", type=str, help="path to object to predict quality of") | ||
parser.add_argument("--path", type=str, help="path to saved posterior") | ||
parser.add_argument("--name", type=str, help="saved posterior name") | ||
args = parser.parse_args() | ||
|
||
model = load_model(args.checkpoint) | ||
pred_obj = load_inference_object(args.input) | ||
# Create an instance of InferenceModel | ||
inference_model = InferenceModel() | ||
|
||
# Load the model | ||
model = inference_model.load_model_pkl(args.path, args.name) | ||
|
||
prediction = predict(pred_obj, model) | ||
print(prediction) | ||
inference_obj = inference_model.predict(model) |