Skip to content

Commit

Permalink
adding load and save posterior software to evaluate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Jan 26, 2024
1 parent 94c0f36 commit b4ac468
Showing 1 changed file with 39 additions and 27 deletions.
66 changes: 39 additions & 27 deletions src/scripts/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,58 @@
"""
Simple stub functions to use in inference
Simple stub functions to use in evaluating inference from a previously trained inference model.
"""

import argparse
import pickle


def load_model(checkpoint_path):
"""
Load the entire model for prediction with an input
class InferenceModel:
def save_model_pkl(self, path, model_name, posterior):
"""
Save the pkl'ed saved posterior model
:param checkpoint_path: location
:return: loaded model object that can be used with the predict function
"""
pass
:param path: Location to save the model
:param model_name: Name of the model
:param posterior: Model object to be saved
"""
file_name = path + model_name + ".pkl"
with open(file_name, "wb") as file:
pickle.dump(posterior, file)

def load_model_pkl(self, path, model_name):
"""
Load the pkl'ed saved posterior model
def predict(input, model):
"""
:param path: Location to load the model from
:param model_name: Name of the model
:return: Loaded model object that can be used with the predict function
"""
with open(path + model_name + ".pkl", 'rb') as file:
posterior = pickle.load(file)
return posterior

:param input: loaded object used for inference
:param model: loaded model
:return: Prediction
"""
return 0

def load_inference_object(input_path):
"""
def predict(input, model):
"""
:param input_path: path to the object you want to predict
:return: loaded object
"""
return 0
:param input: loaded object used for inference
:param model: loaded model
:return: Prediction
"""
return 0


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, help="Checkpoint to unloaded model checkpoint, either weights or the compressed model object")
parser.add_argument("--input", type=str, help="path to object to predict quality of")
parser.add_argument("--path", type=str, help="path to saved posterior")
parser.add_argument("--name", type=str, help="saved posterior name")
args = parser.parse_args()

model = load_model(args.checkpoint)
pred_obj = load_inference_object(args.input)
# Create an instance of InferenceModel
inference_model = InferenceModel()

# Load the model
model = inference_model.load_model_pkl(args.path, args.name)

prediction = predict(pred_obj, model)
print(prediction)
inference_obj = inference_model.predict(model)

0 comments on commit b4ac468

Please sign in to comment.