diff --git a/src/sasctl/pzmm/write_score_wrapper.py b/src/sasctl/pzmm/write_score_wrapper.py index 573a8e1e..9aeab51e 100644 --- a/src/sasctl/pzmm/write_score_wrapper.py +++ b/src/sasctl/pzmm/write_score_wrapper.py @@ -22,7 +22,8 @@ def write_score_wrapper_function_input(cls, function_definition: str, function_body: str, model_load: str, - model_name_with_file_extension: str): + model_name_with_file_extension: str, + output_variables: List[str]): """ Method to generate scoring code from a function and add it to cls.score_wrapper. @@ -32,6 +33,7 @@ def write_score_wrapper_function_input(cls, function_body (str): Function body. model_load (str): Name of the model to load. model_name_with_file_extension (str): Name of the model file with extension. + output_variables (List[str]): List of output variables to define in score function Returns: cls.score_wrapper (str): The scoring code. @@ -49,11 +51,16 @@ def write_score_wrapper_function_input(cls, # Define the score function and add the function body specified cls.score_wrapper += f"{function_definition}:\n" + cls.score_wrapper += '\t"' + cls.score_wrapper += "Output Variables: " + ", ".join(output_variables) # Join output variables with comma + cls.score_wrapper += '"\n' cls.score_wrapper += "\tglobal model\n" cls.score_wrapper += "\ttry:\n" cls.score_wrapper += f"\t\t{function_body}\n" cls.score_wrapper += "\texcept Exception as e:\n" cls.score_wrapper += "\t\tprint(f'Error: {e}')\n" + cls.score_wrapper += "\t\treturn None\n" + # Validate syntax before returning if not cls.validate_score_wrapper_syntax(cls.score_wrapper): raise SyntaxError("Syntax error in generated code.") @@ -67,6 +74,7 @@ def write_score_wrapper_file_input(cls, model_load: str, model_name_with_file_extension: str, score_function_body: str, + output_variables: List[str], ): """ Method to generate scoring code from a file and add it to cls.score_wrapper. @@ -77,6 +85,7 @@ def write_score_wrapper_file_input(cls, model_load (str): Name of the model to load. model_name_with_file_extension (str): Name of the model file with extension. score_function_body (str): The code needed to evaluate the model. + output_variables (List[str]): List of output variables to define in score function Returns: cls.score_wrapper (str): The scoring code. @@ -98,11 +107,18 @@ def write_score_wrapper_file_input(cls, # define the generic score function, and append the score_function_body to evaluate the model. cls.score_wrapper += f"def score(input_data):\n" + cls.score_wrapper += '\t"' + cls.score_wrapper += "Output Variables: " + ", ".join(output_variables) # Join output variables with comma + cls.score_wrapper += '"\n' + cls.score_wrapper += "\tglobal model\n" cls.score_wrapper += "\ttry:" cls.score_wrapper += f"\n{score_function_body}\n" cls.score_wrapper += "\n\texcept Exception as e:\n" cls.score_wrapper += "\t\tprint(f'Error: {e}')\n" + cls.score_wrapper += "\t\treturn None\n" + # Need some kind of return value here + return cls.score_wrapper # Validate Syntax before returning if not cls.validate_score_wrapper_syntax(cls.score_wrapper): @@ -147,4 +163,3 @@ def validate_score_wrapper_syntax(cls, code: str) -> bool: -