Skip to content

Commit

Permalink
fixed some bugs to allow for model card files to generate correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
djm21 committed Mar 7, 2024
1 parent 53cb8bb commit 850f054
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions src/sasctl/pzmm/write_json_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,7 +2208,7 @@ def generate_model_card(
algorithm: str,
train_data: pd.DataFrame,
train_predictions: Union[pd.Series, list],
target_type: str = "Interval",
target_type: str = "interval",
target_value: Union[str, int, float, None] = None,
interval_vars: Optional[list] = [],
class_vars: Optional[list] = [],
Expand Down Expand Up @@ -2237,10 +2237,10 @@ def generate_model_card(
train_predictions : pandas.Series, list
List of predictions made by the model on the training data.
target_type : string
Type the model is targeting. Currently supports "Classification" and "Interval" types.
Type the model is targeting. Currently supports "classification" and "interval" types.
The default value is "interval".
target_value : string, int, float, optional
Value the model is targeting for Classification models. This argument is not needed for
Value the model is targeting for classification models. This argument is not needed for
interval models. The default value is None.
interval_vars : list, optional
A list of interval variables. The default value is an empty list.
Expand All @@ -2255,14 +2255,14 @@ def generate_model_card(
caslib: str, optional
The caslib the training data will be stored on. The default value is "Public".
"""
if not target_value and target_type == "Classification":
if not target_value and target_type == "classification":
raise RuntimeError(
"For the model card data to be properly generated on a Classification "
"For the model card data to be properly generated on a classification "
"model, a target value is required."
)
if target_type not in ["Classification", "Interval"]:
if target_type not in ["classification", "interval"]:
raise RuntimeError(
"Only Classification and Interval target types are currently accepted."
"Only classification and interval target types are currently accepted."
)
if selection_statistic not in cls.valid_params:
raise RuntimeError(
Expand Down Expand Up @@ -2396,10 +2396,10 @@ def generate_outcome_average(
Returns a dictionary with a key value pair that represents the outcome average.
"""
output_var = train_data.drop(input_variables, axis=1)
if target_type == "Classification":
if target_type == "classification":
value_counts = output_var[output_var.columns[0]].value_counts()
return {'eventPercentage': value_counts[target_value]/sum(value_counts)}
elif target_type == "Interval":
elif target_type == "interval":
return {'eventAverage': sum(value_counts[value_counts.columns[0]]) / len(value_counts)}

@staticmethod
Expand Down Expand Up @@ -2480,8 +2480,8 @@ def update_model_properties(
"The ModelProperties.json file must be generated before the model card data "
"can be generated."
)
for key, value in update_dict:
model_files[PROP][key] = value
for key in update_dict:
model_files[PROP][key] = update_dict[key]
else:
if not Path.exists(Path(model_files) / PROP):
raise RuntimeError(
Expand All @@ -2490,8 +2490,8 @@ def update_model_properties(
)
with open(Path(model_files) / PROP, 'r+') as properties_json:
model_properties = json.load(properties_json)
for key, value in update_dict:
model_properties[key] = value
for key in update_dict:
model_properties[key] = update_dict[key]
properties_json.seek(0)
properties_json.write(json.dumps(model_properties, indent=4, cls=NpEncoder))
properties_json.truncate()
Expand Down Expand Up @@ -2595,7 +2595,7 @@ def generate_variable_importance(
}
})
var_data = conn.dataPreprocess.transform(
table={"name": "test_data", "caslib": caslib},
table={"name": "train_data", "caslib": caslib},
requestPackages=request_packages,
evaluationStats=True,
percentileMaxIterations=10,
Expand Down Expand Up @@ -2623,7 +2623,10 @@ def generate_variable_importance(
},
"rowNumber": index+1
})
with open('./dmcas_relativeimportance.json', 'r') as f:
json_template_path = (
Path(__file__).resolve().parent / f"template_files/{VARIMPORTANCES}"
)
with open(json_template_path, 'r') as f:
relative_importance_json = json.load(f)
relative_importance_json['data'] = relative_importances

Expand All @@ -2641,5 +2644,4 @@ def generate_variable_importance(
print(
f"{VARIMPORTANCES} was successfully written and saved to "
f"{Path(model_files) / VARIMPORTANCES}"

)

0 comments on commit 850f054

Please sign in to comment.