Skip to content

Commit

Permalink
updated write_json_files to fix model card issues
Browse files Browse the repository at this point in the history
  • Loading branch information
djm21 committed Jul 31, 2024
1 parent a7b49ac commit 8f26fd2
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/sasctl/pzmm/write_json_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,7 @@ def format_max_differences(
maxdiff_df = maxdiff_df.rename(
columns={"Value": "maxdiff", "Base": "BASE", "Compare": "COMPARE"}
)
maxdiff_df["maxdiff"] = maxdiff_df["maxdiff"].apply(str)

maxdiff_df["VLABEL"] = ""
maxdiff_df["_DATAROLE_"] = datarole
Expand Down Expand Up @@ -2257,6 +2258,7 @@ def generate_model_card(
interval_vars: Optional[list] = [],
class_vars: Optional[list] = [],
selection_statistic: str = None,
training_table_name: str = None,
server: str = "cas-shared-default",
caslib: str = "Public",
):
Expand Down Expand Up @@ -2336,7 +2338,7 @@ def generate_model_card(

# Upload training table to CAS. The location of the training table is returned.
training_table = cls.upload_training_data(
conn, model_prefix, train_data, server, caslib
conn, model_prefix, train_data, training_table_name, server, caslib
)

# Generates the event percentage for Classification targets, and the event average
Expand Down Expand Up @@ -2378,6 +2380,7 @@ def upload_training_data(
conn,
model_prefix: str,
train_data: pd.DataFrame,
train_data_name: str,
server: str = "cas-shared-default",
caslib: str = "Public",
):
Expand All @@ -2404,15 +2407,18 @@ def upload_training_data(
Returns a string that represents the location of the training table within CAS.
"""
# Upload raw training data to caslib so that data can be analyzed
train_data_name = model_prefix + "_train_data"
if not train_data_name:
train_data_name = model_prefix + "_train_data"
upload_train_data = conn.upload(
train_data, casout={"name": train_data_name, "caslib": caslib}, promote=True
)

if upload_train_data.status is not None:
raise RuntimeError(
f"A table with the name {train_data_name} already exists in the specified caslib. Please "
"either delete/rename the old table or give a new name to the current table."
# raise RuntimeError(
warnings.warn(
f"A table with the name {train_data_name} already exists in the specified caslib. If this "
f"is not intentional, please either rename the training data file or remove the duplicate from "
f"the caslib."
)

return server + "/" + caslib + "/" + train_data_name.upper()
Expand Down Expand Up @@ -2762,6 +2768,9 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
roc_table = json.load(roc_file)
correct_text = ["CORRECT", "INCORRECT", "CORRECT", "INCORRECT"]
outcome_values = ["1", "0", "0", "1"]
target_texts = ["Event", "Event", "NEvent", "NEvent"]
target_values = ["1", "1", "0", "0"]

misc_data = list()
# Iterates through ROC table to get TRAIN, TEST, and VALIDATE data with a cutoff of .5
for i in range(50, 300, 100):
Expand All @@ -2772,8 +2781,8 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
roc_data["_TN_"],
roc_data["_FN_"],
]
for c_text, c_val, o_val in zip(
correct_text, correctness_values, outcome_values
for c_text, c_val, o_val, t_txt, t_val in zip(
correct_text, correctness_values, outcome_values, target_texts, target_values
):
misc_data.append(
{
Expand All @@ -2784,6 +2793,8 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
"_DataRole_": roc_data["_DataRole_"],
"_cutoffSource_": "Default",
"_cutoff_": "0.5",
"TargetText": t_txt,
"Target": t_val
},
"rowNumber": len(misc_data) + 1,
}
Expand Down

0 comments on commit 8f26fd2

Please sign in to comment.