diff --git a/src/sasctl/pzmm/write_score_code.py b/src/sasctl/pzmm/write_score_code.py index b29d1a1f..efd47220 100644 --- a/src/sasctl/pzmm/write_score_code.py +++ b/src/sasctl/pzmm/write_score_code.py @@ -497,7 +497,7 @@ def _write_imports( import codecs binary_string = "" -model = pickle.load(codecs.decode(binary_string.encode(), "base64")) +model = pickle.loads(codecs.decode(binary_string.encode(), "base64")) """ def _viya35_model_load( @@ -562,6 +562,26 @@ def _viya35_model_load( f'{model_id}/{model_file_name}")))' ) else: + if pickle_type.lower() == 'pickle': + self.score_code += ( + f'model_path = Path("/models/resources/viya/{model_id}' + f'")\nwith open(model_path / "{model_file_name}", ' + f"\"rb\") as pickle_model:\n{'':4}model = pd.read_pickle" + "(pickle_model)\n\n" + ) + """ +model_path = Path("/models/resources/viya/") +with open(model_path / "model.pickle", "rb") as pickle_model: + model = pd.read_pickle(pickle_model) + + """ + return ( + f"{'':8}model_path = Path(\"/models/resources/viya/{model_id}" + f"\")\n{'':8}with open(model_path / \"{model_file_name}\", " + f"\"rb\") as pickle_model:\n{'':12}model = pd.read_pickle" + "(pickle_model)" + ) + self.score_code += ( f'model_path = Path("/models/resources/viya/{model_id}' f'")\nwith open(model_path / "{model_file_name}", ' @@ -658,6 +678,23 @@ def _viya4_model_load( f"safe_mode=True)\n" ) else: + if pickle_type.lower() == "pickle": + self.score_code += ( + f"with open(Path(settings.pickle_path) / " + f'"{model_file_name}", "rb") as pickle_model:\n' + f"{'':4}model = pd.read_pickle(pickle_model)\n\n" + ) + """ + with open(Path(settings.pickle_path) / "model.pickle", "rb") as pickle_model: + model = pd.read_pickle(pickle_model) + + """ + return ( + f"{'':8}with open(Path(settings.pickle_path) / " + f'"{model_file_name}", "rb") as pickle_model:\n' + f"{'':12}model = pd.read_pickle(pickle_model)\n\n" + ) + self.score_code += ( f"with open(Path(settings.pickle_path) / " f'"{model_file_name}", "rb") as pickle_model:\n' diff --git a/tests/unit/test_write_score_code.py b/tests/unit/test_write_score_code.py index 25eb1aaf..f953b5cf 100644 --- a/tests/unit/test_write_score_code.py +++ b/tests/unit/test_write_score_code.py @@ -118,8 +118,13 @@ def test_viya35_model_load(): """ sc = ScoreCode() load_text = sc._viya35_model_load("1234", "normal") - assert "pickle.load(pickle_model)" in sc.score_code - assert "pickle.load(pickle_model)" in load_text + assert "pd.read_pickle(pickle_model)" in sc.score_code + assert "pd.read_pickle(pickle_model)" in load_text + + sc = ScoreCode() + load_text = sc._viya35_model_load("1234", "normal", pickle_type="dill") + assert "dill.load(pickle_model)" in sc.score_code + assert "dill.load(pickle_model)" in load_text sc = ScoreCode() mojo_text = sc._viya35_model_load("2345", "mojo", mojo_model=True) @@ -142,8 +147,13 @@ def test_viya4_model_load(): """ sc = ScoreCode() load_text = sc._viya4_model_load("normal") - assert "pickle.load(pickle_model)" in sc.score_code - assert "pickle.load(pickle_model)" in load_text + assert "pd.read_pickle(pickle_model)" in sc.score_code + assert "pd.read_pickle(pickle_model)" in load_text + + sc = ScoreCode() + load_text = sc._viya35_model_load("1234", "normal", pickle_type="dill") + assert "dill.load(pickle_model)" in sc.score_code + assert "dill.load(pickle_model)" in load_text sc = ScoreCode() mojo_text = sc._viya4_model_load("mojo", mojo_model=True)