Skip to content

Commit

Permalink
adding docstring for output vars
Browse files Browse the repository at this point in the history
  • Loading branch information
namera9 committed Apr 4, 2024
1 parent da8d538 commit d896f8f
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/sasctl/pzmm/write_score_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def write_score_wrapper_function_input(cls,
function_definition: str,
function_body: str,
model_load: str,
model_name_with_file_extension: str):
model_name_with_file_extension: str,
output_variables: List[str]):
"""
Method to generate scoring code from a function and add it to cls.score_wrapper.
Expand All @@ -32,6 +33,7 @@ def write_score_wrapper_function_input(cls,
function_body (str): Function body.
model_load (str): Name of the model to load.
model_name_with_file_extension (str): Name of the model file with extension.
output_variables (List[str]): List of output variables to define in score function
Returns:
cls.score_wrapper (str): The scoring code.
Expand All @@ -49,11 +51,16 @@ def write_score_wrapper_function_input(cls,

# Define the score function and add the function body specified
cls.score_wrapper += f"{function_definition}:\n"
cls.score_wrapper += '\t"'
cls.score_wrapper += "Output Variables: " + ", ".join(output_variables) # Join output variables with comma
cls.score_wrapper += '"\n'
cls.score_wrapper += "\tglobal model\n"
cls.score_wrapper += "\ttry:\n"
cls.score_wrapper += f"\t\t{function_body}\n"
cls.score_wrapper += "\texcept Exception as e:\n"
cls.score_wrapper += "\t\tprint(f'Error: {e}')\n"
cls.score_wrapper += "\t\treturn None\n"

# Validate syntax before returning
if not cls.validate_score_wrapper_syntax(cls.score_wrapper):
raise SyntaxError("Syntax error in generated code.")
Expand All @@ -67,6 +74,7 @@ def write_score_wrapper_file_input(cls,
model_load: str,
model_name_with_file_extension: str,
score_function_body: str,
output_variables: List[str],
):
"""
Method to generate scoring code from a file and add it to cls.score_wrapper.
Expand All @@ -77,6 +85,7 @@ def write_score_wrapper_file_input(cls,
model_load (str): Name of the model to load.
model_name_with_file_extension (str): Name of the model file with extension.
score_function_body (str): The code needed to evaluate the model.
output_variables (List[str]): List of output variables to define in score function
Returns:
cls.score_wrapper (str): The scoring code.
Expand All @@ -98,11 +107,18 @@ def write_score_wrapper_file_input(cls,

# define the generic score function, and append the score_function_body to evaluate the model.
cls.score_wrapper += f"def score(input_data):\n"
cls.score_wrapper += '\t"'
cls.score_wrapper += "Output Variables: " + ", ".join(output_variables) # Join output variables with comma
cls.score_wrapper += '"\n'

cls.score_wrapper += "\tglobal model\n"
cls.score_wrapper += "\ttry:"
cls.score_wrapper += f"\n{score_function_body}\n"
cls.score_wrapper += "\n\texcept Exception as e:\n"
cls.score_wrapper += "\t\tprint(f'Error: {e}')\n"
cls.score_wrapper += "\t\treturn None\n"
# Need some kind of return value here
return cls.score_wrapper

# Validate Syntax before returning
if not cls.validate_score_wrapper_syntax(cls.score_wrapper):
Expand Down Expand Up @@ -147,4 +163,3 @@ def validate_score_wrapper_syntax(cls, code: str) -> bool:




0 comments on commit d896f8f

Please sign in to comment.