Skip to content

Commit

Permalink
Add new plot plot_against_target_for_regression
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Aug 17, 2024
1 parent 047d231 commit babfa35
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 3 deletions.
84 changes: 81 additions & 3 deletions bluecast/eda/analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import scipy.stats as ss
import seaborn as sns
import statsmodels.api as sm
from sklearn.decomposition import PCA
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
from sklearn.manifold import TSNE
Expand All @@ -18,7 +19,7 @@
plt.set_loglevel("WARNING")


def find_nbins_with_freedman_diaconis(data: np.ndarray):
def find_bind_with_with_freedman_diaconis(data: np.ndarray):
# Calculate the IQR
iqr = np.percentile(data, 75) - np.percentile(data, 25)

Expand Down Expand Up @@ -228,7 +229,7 @@ def univariate_plots(df: pd.DataFrame, col_requires_at_least_n_values: int = 5)
np.arange(
min(df[col]),
max(df[col]),
max(find_nbins_with_freedman_diaconis(df[col].values), 0.1),
max(find_bind_with_with_freedman_diaconis(df[col].values), 0.1),
)
)

Expand Down Expand Up @@ -377,6 +378,83 @@ def correlation_to_target(df: pd.DataFrame, target: str) -> None:
plt.show()


def plot_against_target_for_regression(
    df: pd.DataFrame, num_columns: List[Union[int, float, str]], target_col: str
) -> None:
    """
    Creates scatter plots for each column in num_columns against the target_col.
    Draws an OLS regression line and annotates the p-value of its slope.

    The plots are arranged in a grid that is two subplots wide; unused grid
    cells are removed before the figure is shown.

    Parameters:
    - df: pd.DataFrame -> The input dataframe containing the data.
    - num_columns: List[Union[int, float, str]] -> List of column names to plot against the target column.
      If the list is empty, nothing is plotted.
    - target_col: str -> The target column name for regression.

    Returns:
    - None -> The function displays plots.

    Raises:
    - ValueError -> If target_col or any entry of num_columns is not a column of df.
    """
    if target_col not in df.columns:
        raise ValueError(
            f"Target column '{target_col}' must be part of the provided DataFrame"
        )

    # Validate every requested column up front, so a bad column name is
    # reported before a figure is allocated (and left behind unshown).
    for column in num_columns:
        if column not in df.columns:
            raise ValueError(f"Column '{column}' not found in DataFrame")

    num_variables = len(num_columns)
    if num_variables == 0:
        # Nothing to draw; also avoids asking matplotlib for a zero-row grid.
        return

    num_cols_grid = 2  # Number of subplot columns in the grid layout
    # Ceiling division: enough rows to hold every requested column
    num_rows = (num_variables + num_cols_grid - 1) // num_cols_grid

    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(
        num_rows, num_cols_grid, figsize=(14, 5 * num_rows), squeeze=False
    )
    flat_axes = axes.flatten()  # row-major order matches the original row/col math

    y = df[target_col]

    for ax, column in zip(flat_axes, num_columns):
        x = df[column]

        # Scatter plot
        sns.scatterplot(x=x, y=y, ax=ax)

        # Fit a regression line. has_constant="add" forces the intercept column
        # even when x itself is constant, so the slope is always the second
        # parameter; missing="drop" lets OLS ignore rows containing NaN.
        X = sm.add_constant(x, has_constant="add")
        model = sm.OLS(y, X, missing="drop").fit()
        prediction = model.predict(X)

        # Plot the regression line
        ax.plot(x, prediction, color="red", label="Regression Line")

        # p-value of the slope. Use .iloc because model.pvalues is indexed by
        # parameter name; positional access via plain [] is deprecated in pandas.
        p_value = model.pvalues.iloc[1]
        ax.annotate(
            f"p-value: {p_value:.4f}",
            xy=(0.05, 0.95),
            xycoords="axes fraction",
            fontsize=12,
            verticalalignment="top",
            bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="white"),
        )

        ax.set_title(f"Scatter Plot: {column} vs {target_col}")
        ax.set_xlabel(column)
        ax.set_ylabel(target_col)

    # Remove any empty subplots (the slice is empty when the grid is full)
    for unused_ax in flat_axes[num_variables:]:
        fig.delaxes(unused_ax)

    # Adjust the spacing between subplots
    plt.tight_layout()
    plt.show()


def plot_pca(df: pd.DataFrame, target: str, scale_data: bool = True) -> None:
"""
Plots PCA for the dataframe. The target column must be part of the provided DataFrame.
Expand Down Expand Up @@ -814,7 +892,7 @@ def plot_ecdf(
np.arange(
min(df[col]),
max(df[col]),
max(find_nbins_with_freedman_diaconis(df[col].values), 0.1),
max(find_bind_with_with_freedman_diaconis(df[col].values), 0.1),
)
)
fig, ax1 = plt.subplots()
Expand Down
9 changes: 9 additions & 0 deletions bluecast/tests/test_analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
correlation_heatmap,
correlation_to_target,
mutual_info_to_target,
plot_against_target_for_regression,
plot_andrews_curve,
plot_classification_target_distribution_within_categories,
plot_count_pairs,
Expand Down Expand Up @@ -394,3 +395,11 @@ def test_plot_andrews_curve_missing_target(sample_dataframe):
target_col = "NonExistentTarget"
with pytest.raises(KeyError):
plot_andrews_curve(sample_dataframe, target_col)


def test_plot_against_target_for_regression(synthetic_train_test_data_regression):
    """Smoke test: plotting three numerical features against the target runs through."""
    # The first element of the fixture is passed as the dataframe
    # (presumably the training split; confirm against the fixture definition).
    frame = synthetic_train_test_data_regression[0]
    feature_names = [
        "numerical_feature_1",
        "numerical_feature_2",
        "numerical_feature_3",
    ]
    plot_against_target_for_regression(
        df=frame, num_columns=feature_names, target_col="target"
    )
    # Reaching this line means the call completed without raising.
    assert True
Binary file modified dist/bluecast-1.6.0-py3-none-any.whl
Binary file not shown.
Binary file modified dist/bluecast-1.6.0.tar.gz
Binary file not shown.

0 comments on commit babfa35

Please sign in to comment.