From 1b1f776ca21df1a209f1b6fd357ad81ef9359b91 Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Sun, 7 Jul 2024 15:06:04 -0600 Subject: [PATCH] adding quantile regressor models for the residuals --- .../quant_regression.template | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/src/sageworks/core/transforms/features_to_model/light_quant_regression/quant_regression.template b/src/sageworks/core/transforms/features_to_model/light_quant_regression/quant_regression.template index fcedf940e..624f73e19 100644 --- a/src/sageworks/core/transforms/features_to_model/light_quant_regression/quant_regression.template +++ b/src/sageworks/core/transforms/features_to_model/light_quant_regression/quant_regression.template @@ -105,7 +105,29 @@ if __name__ == "__main__": result_df["mean"] = rmse_model.predict(X) result_df["prediction"] = result_df["mean"] - # Save the quantile and RMSE predictions to S3 + # Now compute residuals on the prediction and compute quantiles on the residuals + result_df["residual"] = result_df[target] - result_df["prediction"] + result_df["residual_abs"] = result_df["residual"].abs() + + # Change the target to the residual + y = result_df["residual"] + + # Train models for each of the residual quantiles + for q in quantiles: + params = { + "objective": "reg:quantileerror", + "quantile_alpha": q, + } + model = xgb.XGBRegressor(**params) + model.fit(X, y) + + # Convert quantile to string + q_str = f"qr_{int(q * 100):02}" + + # Store the model + q_models[q_str] = model + + # Save the target quantiles and residual quantiles to S3 wr.s3.to_csv( result_df, path=f"{model_metrics_s3_path}/validation_predictions.csv", @@ -121,7 +143,7 @@ if __name__ == "__main__": print(f"R2: {r2:.3f}") print(f"NumRows: {len(result_df)}") - # Now save the Quantile models to the standard place + # Now save both the target quantile and residual quantile models for name, model in q_models.items(): model_path = os.path.join(args.model_dir, 
f"{name}.json") print(f"Saving model: {model_path}") @@ -150,7 +172,7 @@ def model_fn(model_dir) -> dict: # Load ALL the Quantile models from the model directory models = {} for file in os.listdir(model_dir): - if file.startswith("q_") and file.endswith(".json"): # The Quantile models + if file.startswith("q") and file.endswith(".json"): # The Quantile models # Load the model model_path = os.path.join(model_dir, file) print(f"Loading model: {model_path}")