From 830140b2c23bbe04e3c2c8ae9f3152424fdca410 Mon Sep 17 00:00:00 2001 From: Nabeel Merali Date: Mon, 29 Jan 2024 08:59:04 -0500 Subject: [PATCH] added params for pytorch model in viya4model load, import model, --- src/sasctl/pzmm/write_score_code.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/sasctl/pzmm/write_score_code.py b/src/sasctl/pzmm/write_score_code.py index 970a81ac..c1dbd5c9 100644 --- a/src/sasctl/pzmm/write_score_code.py +++ b/src/sasctl/pzmm/write_score_code.py @@ -405,6 +405,7 @@ def _write_imports( mojo_model: Optional[bool] = False, binary_h2o_model: Optional[bool] = False, tf_model: Optional[bool] = False, + pytorch_model: Optional[bool] = False, binary_string: Optional[str] = None, ) -> None: """ @@ -427,6 +428,9 @@ def _write_imports( tf_model : bool, optional Flag to indicate that the model is a tensorflow model. The default value is None. + pytorch_model : bool, optional + Flag to indicate that the model is a pytorch model. The default value + is None. binary_string : str, optional A binary representation of the Python model object. The default value is None. @@ -475,6 +479,8 @@ def _write_imports( import tensorflow as tf """ + elif pytorch_model: + cls.score_code += "import math\nimport torch\nimport pandas as pd\nimport numpy as np\nfrom pathlib import Path\n\n" elif binary_string: cls.score_code += ( f'import codecs\n\nbinary_string = "{binary_string}"' @@ -578,6 +584,7 @@ def _viya4_model_load( pickle_type: Optional[str] = None, mojo_model: Optional[bool] = False, binary_h2o_model: Optional[bool] = False, + pytorch_model: Optional[bool] = False, tf_keras_model: Optional[bool] = False, tf_core_model: Optional[bool] = False, ) -> str: @@ -598,6 +605,9 @@ def _viya4_model_load( binary_h2o_model : boolean, optional Flag to indicate that the model is a H2O.ai binary model. The default value is None. + pytorch_model : boolean, optional + Flag to indicate that the model is a pytorch model. The default value + is None. tf_keras_model : boolean, optional Flag to indicate that the model is a tensorflow keras model. The default value is False. @@ -633,6 +643,8 @@ def _viya4_model_load( f"{'':8}model = h2o.load(str(Path(settings.pickle_path) / " f"{model_file_name}))\n\n" ) + elif pytorch_model: + cls.score_code += ( "model = torch.load(path) ") elif tf_keras_model: cls.score_code += ( f"model = tf.keras.models.load_model(Path(settings.pickle_path) / "