Skip to content

Commit

Permalink
added params for pytorch model in viya4model load, import model,
Browse files Browse the repository at this point in the history
  • Loading branch information
namera9 committed Jan 29, 2024
1 parent ae049df commit 830140b
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/sasctl/pzmm/write_score_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def _write_imports(
mojo_model: Optional[bool] = False,
binary_h2o_model: Optional[bool] = False,
tf_model: Optional[bool] = False,
pytorch_model: Optional[bool] = False,
binary_string: Optional[str] = None,
) -> None:
"""
Expand All @@ -427,6 +428,9 @@ def _write_imports(
tf_model : bool, optional
Flag to indicate that the model is a tensorflow model. The default value
is None.
pytorch_model : bool, optional
Flag to indicate that the model is a pytorch model. The default value
is None.
binary_string : str, optional
A binary representation of the Python model object. The default value is
None.
Expand Down Expand Up @@ -475,6 +479,8 @@ def _write_imports(
import tensorflow as tf
"""
elif pytorch_model:
cls.score_code += "import math\nimport torch\nimport pandas as pd\nimport numpy as np\nfrom pathlib import Path\n\n"
elif binary_string:
cls.score_code += (
f'import codecs\n\nbinary_string = "{binary_string}"'
Expand Down Expand Up @@ -578,6 +584,7 @@ def _viya4_model_load(
pickle_type: Optional[str] = None,
mojo_model: Optional[bool] = False,
binary_h2o_model: Optional[bool] = False,
pytorch_model: Optional[bool] = False,
tf_keras_model: Optional[bool] = False,
tf_core_model: Optional[bool] = False,
) -> str:
Expand All @@ -598,6 +605,9 @@ def _viya4_model_load(
binary_h2o_model : boolean, optional
Flag to indicate that the model is a H2O.ai binary model. The default value
is None.
pytorch_model : boolean, optional
Flag to indicate that the model is a pytorch model. The default value
is None.
tf_keras_model : boolean, optional
Flag to indicate that the model is a tensorflow keras model. The default
value is False.
Expand Down Expand Up @@ -633,6 +643,8 @@ def _viya4_model_load(
f"{'':8}model = h2o.load(str(Path(settings.pickle_path) / "
f"{model_file_name}))\n\n"
)
elif pytorch_model:
cls.score_code += ( "model = torch.load(path) ")
elif tf_keras_model:
cls.score_code += (
f"model = tf.keras.models.load_model(Path(settings.pickle_path) / "
Expand Down

0 comments on commit 830140b

Please sign in to comment.