Skip to content

Commit

Permalink
Merge pull request #188 from sassoftware/model_cards
Browse files Browse the repository at this point in the history
Model cards
  • Loading branch information
djm21 authored Apr 9, 2024
2 parents 28a3e22 + fa499fc commit cdb3a52
Show file tree
Hide file tree
Showing 4 changed files with 854 additions and 40 deletions.
20 changes: 17 additions & 3 deletions examples/pzmm_binary_classification_model_import.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@
],
"source": [
"import getpass\n",
"def write_model_stats(x_train, y_train, test_predict, test_proba, y_test, model, path):\n",
"def write_model_stats(x_train, y_train, test_predict, test_proba, y_test, model, path, prefix):\n",
" # Calculate train predictions\n",
" train_predict = model.predict(x_train)\n",
" train_proba = model.predict_proba(x_train)\n",
Expand All @@ -757,6 +757,20 @@
" test_data=test_data, \n",
" json_path=path\n",
" )\n",
"\n",
" full_training_data = pd.concat([y_train.reset_index(drop=True), x_train.reset_index(drop=True)], axis=1)\n",
"\n",
" pzmm.JSONFiles.generate_model_card(\n",
" model_prefix=prefix,\n",
" model_files = path,\n",
" algorithm = str(type(model).__name__),\n",
" train_data = full_training_data,\n",
" train_predictions=train_predict,\n",
" target_type='classification',\n",
" target_value=1,\n",
" interval_vars=predictor_columns,\n",
" selection_statistic='_RASE_',\n",
" )\n",
" \n",
"username = getpass.getpass()\n",
"password = getpass.getpass()\n",
Expand All @@ -766,8 +780,8 @@
"\n",
"test_predict = [y_dtc_predict, y_rfc_predict, y_gbc_predict]\n",
"test_proba = [y_dtc_proba, y_rfc_proba, y_gbc_proba]\n",
"for (mod, pred, proba, path) in zip(model, test_predict, test_proba, zip_folder):\n",
" write_model_stats(x_train, y_train, pred, proba, y_test, mod, path)"
"for (mod, pred, proba, path, prefix) in zip(model, test_predict, test_proba, zip_folder, model_prefix):\n",
" write_model_stats(x_train, y_train, pred, proba, y_test, mod, path, prefix)"
]
},
{
Expand Down
58 changes: 58 additions & 0 deletions src/sasctl/pzmm/template_files/dmcas_relativeimportance.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"creationTimeStamp" : "0001-01-01T00:00:00Z",
"modifiedTimeStamp" : "0001-01-01T00:00:00Z",
"revision" : 0,
"name" : "dmcas_relativeimportance",
"version" : 0,
"order" : 0,
"parameterMap" : {
"LABEL" : {
"label" : "Variable Label",
"length" : 256,
"order" : 1,
"parameter" : "LABEL",
"preformatted" : false,
"type" : "char",
"values" : [ "LABEL" ]
},
"LEVEL" : {
"label" : "Variable Level",
"length" : 10,
"order" : 5,
"parameter" : "LEVEL",
"preformatted" : false,
"type" : "char",
"values" : [ "LEVEL" ]
},
"ROLE" : {
"label" : "Role",
"length" : 32,
"order" : 4,
"parameter" : "ROLE",
"preformatted" : false,
"type" : "char",
"values" : [ "ROLE" ]
},
"RelativeImportance" : {
"label" : "Relative Importance",
"length" : 8,
"order" : 3,
"parameter" : "RelativeImportance",
"preformatted" : false,
"type" : "num",
"values" : [ "RelativeImportance" ]
},
"Variable" : {
"label" : "Variable Name",
"length" : 255,
"order" : 2,
"parameter" : "Variable",
"preformatted" : false,
"type" : "char",
"values" : [ "Variable" ]
}
},
"data" : [],
"xInteger" : false,
"yInteger" : false
}
Loading

0 comments on commit cdb3a52

Please sign in to comment.