From b4ac46835ad4d845324061943958680b66cda5e5 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Thu, 25 Jan 2024 20:33:55 -0600 Subject: [PATCH] adding load and save posterior software to evaluate.py --- src/scripts/evaluate.py | 66 ++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/src/scripts/evaluate.py b/src/scripts/evaluate.py index 664d20e..16e54ea 100644 --- a/src/scripts/evaluate.py +++ b/src/scripts/evaluate.py @@ -1,46 +1,58 @@ """ -Simple stub functions to use in inference +Simple stub functions to use in evaluating inference from a previously trained inference model. + """ import argparse +import pickle -def load_model(checkpoint_path): - """ - Load the entire model for prediction with an input +class InferenceModel: + def save_model_pkl(self, path, model_name, posterior): + """ + Save the pkl'ed saved posterior model - :param checkpoint_path: location - :return: loaded model object that can be used with the predict function - """ - pass + :param path: Location to save the model + :param model_name: Name of the model + :param posterior: Model object to be saved + """ + file_name = path + model_name + ".pkl" + with open(file_name, "wb") as file: + pickle.dump(posterior, file) + def load_model_pkl(self, path, model_name): + """ + Load the pkl'ed saved posterior model -def predict(input, model): - """ + :param path: Location to load the model from + :param model_name: Name of the model + :return: Loaded model object that can be used with the predict function + """ + with open(path + model_name + ".pkl", 'rb') as file: + posterior = pickle.load(file) + return posterior - :param input: loaded object used for inference - :param model: loaded model - :return: Prediction - """ - return 0 -def load_inference_object(input_path): - """ + def predict(input, model): + """ - :param input_path: path to the object you want to predict - :return: loaded object - """ - return 0 + :param input: loaded object used for inference + :param model: loaded model + :return: Prediction + """ + return 0 if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--checkpoint", type=str, help="Checkpoint to unloaded model checkpoint, either weights or the compressed model object") - parser.add_argument("--input", type=str, help="path to object to predict quality of") + parser.add_argument("--path", type=str, help="path to saved posterior") + parser.add_argument("--name", type=str, help="saved posterior name") args = parser.parse_args() - model = load_model(args.checkpoint) - pred_obj = load_inference_object(args.input) + # Create an instance of InferenceModel + inference_model = InferenceModel() + + # Load the model + model = inference_model.load_model_pkl(args.path, args.name) - prediction = predict(pred_obj, model) - print(prediction) + inference_obj = inference_model.predict(model) \ No newline at end of file