Skip to content

Commit

Permalink
black reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
djm21 committed Jul 31, 2024
1 parent 8f26fd2 commit ee2cb5f
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/sasctl/pzmm/write_json_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,7 +1393,7 @@ def check_for_data(
def stat_dataset_to_dataframe(
data: Union[DataFrame, List[list], Type["numpy.array"]],
target_value: Union[str, int, float] = None,
target_type: str = 'classification'
target_type: str = "classification",
) -> DataFrame:
"""
Convert the user supplied statistical dataset from either a pandas DataFrame,
Expand Down Expand Up @@ -1441,14 +1441,14 @@ def stat_dataset_to_dataframe(
if isinstance(data, pd.DataFrame):
if len(data.columns) == 2:
data.columns = ["actual", "predict"]
if target_type == 'classification':
if target_type == "classification":
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
elif len(data.columns) == 3:
data.columns = ["actual", "predict", "predict_proba"]
elif isinstance(data, list):
if len(data) == 2:
data = pd.DataFrame({"actual": data[0], "predict": data[1]})
if target_type == 'classification':
if target_type == "classification":
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
elif len(data) == 3:
data = pd.DataFrame(
Expand All @@ -1461,7 +1461,7 @@ def stat_dataset_to_dataframe(
elif isinstance(data, np.ndarray):
if len(data) == 2:
data = pd.DataFrame({"actual": data[0, :], "predict": data[1, :]})
if target_type == 'classification':
if target_type == "classification":
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
elif len(data) == 3:
data = pd.DataFrame(
Expand Down Expand Up @@ -2372,7 +2372,7 @@ def generate_model_card(
)

# Generates dmcas_misc.json file
if target_type == 'classification':
if target_type == "classification":
cls.generate_misc(model_files)

@staticmethod
Expand Down Expand Up @@ -2782,7 +2782,11 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
roc_data["_FN_"],
]
for c_text, c_val, o_val, t_txt, t_val in zip(
correct_text, correctness_values, outcome_values, target_texts, target_values
correct_text,
correctness_values,
outcome_values,
target_texts,
target_values,
):
misc_data.append(
{
Expand All @@ -2794,7 +2798,7 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
"_cutoffSource_": "Default",
"_cutoff_": "0.5",
"TargetText": t_txt,
"Target": t_val
"Target": t_val,
},
"rowNumber": len(misc_data) + 1,
}
Expand Down

0 comments on commit ee2cb5f

Please sign in to comment.