From d55f57dbf1eab1d9202b5826c2d86f5f6deb57ac Mon Sep 17 00:00:00 2001 From: djm21 Date: Mon, 28 Oct 2024 14:36:23 -0700 Subject: [PATCH] updates to write_score_code to correct formatting errors --- src/sasctl/pzmm/write_score_code.py | 35 ++++++++++++++++------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/sasctl/pzmm/write_score_code.py b/src/sasctl/pzmm/write_score_code.py index e990d461..b44df206 100644 --- a/src/sasctl/pzmm/write_score_code.py +++ b/src/sasctl/pzmm/write_score_code.py @@ -250,6 +250,7 @@ def score(var1, var2, var3, var4): input_var_list, missing_values=missing_values, dtype_list=input_dtypes_list, + preprocess_function=preprocess_function ) self._predictions_to_metrics( score_metrics, @@ -266,6 +267,7 @@ def score(var1, var2, var3, var4): missing_values=missing_values, statsmodels_model="statsmodels_model" in kwargs, tf_model="tf_keras_model" in kwargs or "tf_core_model" in kwargs, + preprocess_function=preprocess_function ) # Include check for numpy values and a conversion operation as needed self.score_code += ( @@ -814,14 +816,15 @@ def _predict_method( input_frame = f'{{{", ".join(input_dict)}}}, index=index' self.score_code += self._wrap_indent_string(input_frame, 8) self.score_code += f"\n{'':4})\n" - if preprocess_function: - self.score_code += ( - f"{'':4}input_array = {preprocess_function.__name__}(input_array)" - ) + if missing_values: self.score_code += ( f"{'':4}input_array = impute_missing_values(input_array)\n" ) + if preprocess_function: + self.score_code += ( + f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n" + ) self.score_code += ( f"{'':4}column_types = {column_types}\n" f"{'':4}h2o_array = h2o.H2OFrame(input_array, " @@ -860,14 +863,14 @@ def _predict_method( input_frame = f'{{{", ".join(input_dict)}}}, index=index' self.score_code += self._wrap_indent_string(input_frame, 8) self.score_code += f"\n{'':4})\n" - if preprocess_function: - self.score_code += ( - f"{'':4}input_array = {preprocess_function.__name__}(input_array)" - ) if missing_values: self.score_code += ( f"{'':4}input_array = impute_missing_values(input_array)\n" ) + if preprocess_function: + self.score_code += ( + f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n" + ) self.score_code += ( f"{'':4}prediction = model.{method.__name__}(input_array)\n" ) @@ -885,14 +888,14 @@ def _predict_method( input_frame = f'{{{", ".join(input_dict)}}}, index=index' self.score_code += self._wrap_indent_string(input_frame, 8) self.score_code += f"\n{'':4})\n" - if preprocess_function: - self.score_code += ( - f"{'':4}input_array = {preprocess_function.__name__}(input_array)" - ) if missing_values: self.score_code += ( f"{'':4}input_array = impute_missing_values(input_array)\n" ) + if preprocess_function: + self.score_code += ( + f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n" + ) self.score_code += ( f"{'':4}prediction = model.{method.__name__}(input_array)\n\n" f"{'':4} # Check if model returns logits or probabilities\n" @@ -921,14 +924,14 @@ def _predict_method( input_frame = f'{{{", ".join(input_dict)}}}, index=index' self.score_code += self._wrap_indent_string(input_frame, 8) self.score_code += f"\n{'':4})\n" - if preprocess_function: - self.score_code += ( - f"{'':4}input_array = {preprocess_function.__name__}(input_array)" - ) if missing_values: self.score_code += ( f"{'':4}input_array = impute_missing_values(input_array)\n" ) + if preprocess_function: + self.score_code += ( + f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n" + ) self.score_code += ( f"{'':4}prediction = model.{method.__name__}(input_array).tolist()\n" )