Skip to content

Commit

Permalink
experimenting with quartile regressor functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Jun 26, 2024
1 parent 7e72a7f commit 7a86b3e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
13 changes: 9 additions & 4 deletions src/sageworks/algorithms/dataframe/quantile_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ class QuantileRegressor(BaseEstimator, TransformerMixin):
A class for training regression models over a set of quantiles. Useful for calculating confidence intervals.
"""

def __init__(self, model: Union[RegressorMixin, XGBRegressor] = XGBRegressor, quantiles: list = [0.05, 0.25, 0.50, 0.75, 0.95]):
def __init__(
self,
model: Union[RegressorMixin, XGBRegressor] = XGBRegressor,
quantiles: list = [0.05, 0.25, 0.50, 0.75, 0.95],
):
"""
Initializes the QuantileRegressor with the specified parameters.
Expand Down Expand Up @@ -114,7 +118,7 @@ def example_confidence(q_dataframe, target="target", target_sensitivity=0.25):
# If the interval width is greater than target_sensitivity we have 0 confidence
# anything below that is a linear scale from 0 to 1
confidence_interval = upper_95 - lower_05
q_conf = np.clip(1 - confidence_interval/(target_sensitivity * 4.0), 0, 1)
q_conf = np.clip(1 - confidence_interval / (target_sensitivity * 4.0), 0, 1)

# Now let's look at the IQR distance for each observation
epsilon_iqr = target_sensitivity * 0.5
Expand Down Expand Up @@ -206,8 +210,9 @@ def integration_test():
confidence_df["interval"] = confidence_df["q_95"] - confidence_df["q_05"]

# Compute the confidence
confidence_df["conf"], confidence_df["q_conf"], confidence_df["iqr_conf"] = example_confidence(confidence_df, target_column,
target_sensitivity=1.5)
confidence_df["conf"], confidence_df["q_conf"], confidence_df["iqr_conf"] = example_confidence(
confidence_df, target_column, target_sensitivity=1.5
)

# Columns of Interest
q_columns = [c for c in confidence_df.columns if c.startswith("q_")]
Expand Down
1 change: 0 additions & 1 deletion src/sageworks/core/artifacts/endpoint_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,4 +914,3 @@ def delete_endpoint_models(self):
# Run Inference and metrics for a Classification Endpoint
class_endpoint = EndpointCore("wine-classification-end")
class_endpoint.auto_inference()

29 changes: 23 additions & 6 deletions src/sageworks/utils/test_data_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A Test Data Generator Class"""

import os
import logging
import pandas as pd
Expand Down Expand Up @@ -56,8 +57,8 @@ def confidence_data(n_samples=2000, n_features=1) -> pd.DataFrame:
"""

# Generate synthetic data with even spacing from -10 to 5 and sparse spacing from 5 to 10
x_even = np.linspace(-10, 5, int(n_samples*7/8)) # Evenly spaced from -10 to 5
x_sparse = 5 + (np.linspace(0, 1, int(n_samples*1/8)) ** 2) * 5 # Increasingly sparse from 5 to 10
x_even = np.linspace(-10, 5, int(n_samples * 7 / 8)) # Evenly spaced from -10 to 5
x_sparse = 5 + (np.linspace(0, 1, int(n_samples * 1 / 8)) ** 2) * 5 # Increasingly sparse from 5 to 10
x = np.concatenate([x_even, x_sparse])

# Ensure no values are exactly zero or negative in the input to the log function
Expand Down Expand Up @@ -124,7 +125,7 @@ def aqsol_data(self) -> pd.DataFrame:
"""Generate a Pandas DataFrame of AQSol Data"""

# Define a temporary file path
temp_file_path = os.path.join(tempfile.gettempdir(), 'aqsol_data.csv')
temp_file_path = os.path.join(tempfile.gettempdir(), "aqsol_data.csv")

# First check if the data is already stored in a local temporary file
if os.path.exists(temp_file_path):
Expand All @@ -145,9 +146,25 @@ def aqsol_data(self) -> pd.DataFrame:

def aqsol_features(self) -> list:
"""Get the AQSol Feature List"""
return ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms',
'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings',
'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct']
return [
"molwt",
"mollogp",
"molmr",
"heavyatomcount",
"numhacceptors",
"numhdonors",
"numheteroatoms",
"numrotatablebonds",
"numvalenceelectrons",
"numaromaticrings",
"numsaturatedrings",
"numaliphaticrings",
"ringcount",
"tpsa",
"labuteasa",
"balabanj",
"bertzct",
]

def aqsol_target(self) -> str:
"""Get the AQSol Target"""
Expand Down

0 comments on commit 7a86b3e

Please sign in to comment.