From 8f26fd25b40107a1c073e812b6b61c6e391c5195 Mon Sep 17 00:00:00 2001 From: djm21 Date: Wed, 31 Jul 2024 17:46:52 -0500 Subject: [PATCH] updated write_json_files to fix model card issues --- src/sasctl/pzmm/write_json_files.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/sasctl/pzmm/write_json_files.py b/src/sasctl/pzmm/write_json_files.py index 10e87f1a..b9bb1d6e 100644 --- a/src/sasctl/pzmm/write_json_files.py +++ b/src/sasctl/pzmm/write_json_files.py @@ -965,6 +965,7 @@ def format_max_differences( maxdiff_df = maxdiff_df.rename( columns={"Value": "maxdiff", "Base": "BASE", "Compare": "COMPARE"} ) + maxdiff_df["maxdiff"] = maxdiff_df["maxdiff"].apply(str) maxdiff_df["VLABEL"] = "" maxdiff_df["_DATAROLE_"] = datarole @@ -2257,6 +2258,7 @@ def generate_model_card( interval_vars: Optional[list] = [], class_vars: Optional[list] = [], selection_statistic: str = None, + training_table_name: str = None, server: str = "cas-shared-default", caslib: str = "Public", ): @@ -2336,7 +2338,7 @@ def generate_model_card( # Upload training table to CAS. The location of the training table is returned. training_table = cls.upload_training_data( - conn, model_prefix, train_data, server, caslib + conn, model_prefix, train_data, training_table_name, server, caslib ) # Generates the event percentage for Classification targets, and the event average @@ -2378,6 +2380,7 @@ def upload_training_data( conn, model_prefix: str, train_data: pd.DataFrame, + train_data_name: str, server: str = "cas-shared-default", caslib: str = "Public", ): @@ -2404,15 +2407,18 @@ def upload_training_data( Returns a string that represents the location of the training table within CAS. """ # Upload raw training data to caslib so that data can be analyzed - train_data_name = model_prefix + "_train_data" + if not train_data_name: + train_data_name = model_prefix + "_train_data" upload_train_data = conn.upload( train_data, casout={"name": train_data_name, "caslib": caslib}, promote=True ) if upload_train_data.status is not None: - raise RuntimeError( - f"A table with the name {train_data_name} already exists in the specified caslib. Please " - "either delete/rename the old table or give a new name to the current table." + # raise RuntimeError( + warnings.warn( + f"A table with the name {train_data_name} already exists in the specified caslib. If this " + f"is not intentional, please either rename the training data file or remove the duplicate from " + f"the caslib." ) return server + "/" + caslib + "/" + train_data_name.upper() @@ -2762,6 +2768,9 @@ def generate_misc(cls, model_files: Union[str, Path, dict]): roc_table = json.load(roc_file) correct_text = ["CORRECT", "INCORRECT", "CORRECT", "INCORRECT"] outcome_values = ["1", "0", "0", "1"] + target_texts = ["Event", "Event", "NEvent", "NEvent"] + target_values = ["1", "1", "0", "0"] + misc_data = list() # Iterates through ROC table to get TRAIN, TEST, and VALIDATE data with a cutoff of .5 for i in range(50, 300, 100): @@ -2772,8 +2781,8 @@ def generate_misc(cls, model_files: Union[str, Path, dict]): roc_data["_TN_"], roc_data["_FN_"], ] - for c_text, c_val, o_val in zip( - correct_text, correctness_values, outcome_values + for c_text, c_val, o_val, t_txt, t_val in zip( + correct_text, correctness_values, outcome_values, target_texts, target_values ): misc_data.append( { @@ -2784,6 +2793,8 @@ def generate_misc(cls, model_files: Union[str, Path, dict]): "_DataRole_": roc_data["_DataRole_"], "_cutoffSource_": "Default", "_cutoff_": "0.5", + "TargetText": t_txt, + "Target": t_val }, "rowNumber": len(misc_data) + 1, }