Skip to content

Commit

Permalink
Clip the range of histograms when there are outliers (skrub-data#1157)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes authored Nov 27, 2024
1 parent d929329 commit a7abe4b
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ Minor changes
* The :class:`TableReport` now also reports the number of unique values for
numeric columns. :pr:`1154` by :user:`Jérôme Dockès <jeromedockes>`.

* The :class:`TableReport`, when plotting histograms, now detects outliers and
clips the range of data shown in the histogram. This allows seeing more detail
in the shown distribution. :pr:`1157` by :user:`Jérôme Dockès <jeromedockes>`.

Bug fixes
---------

Expand Down
12 changes: 12 additions & 0 deletions skrub/_reporting/_data/templates/column-summaries.css
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,18 @@
justify-self: left;
}

/* Highlight min and max values when outliers were truncated from the histogram*/

.column-summary[data-has-low-outliers] .min-value{
color: var(--red);
font-weight: bold;
}

.column-summary[data-has-high-outliers] .max-value{
color: var(--red);
font-weight: bold;
}


/* Aspects specific to the single card shown in the dataframe sample tab */
/* --------------------------------------------------------------------- */
Expand Down
14 changes: 10 additions & 4 deletions skrub/_reporting/_data/templates/column-summary.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
data-name-repr="{{ column.name.__repr__() }}"
data-column-name="{{ column.name }}"
data-column-idx="{{ column.idx }}"
{% if column['n_low_outliers'] %}
data-has-low-outliers
{% endif %}
{% if column['n_high_outliers'] %}
data-has-high-outliers
{% endif %}
data-manager="FilterableColumn {% if in_sample_tab %}SampleColumnSummary{% endif %}"
{% if column.value_is_constant %} data-constant-column {% endif %}
{% if in_sample_tab %} data-role="sample-column" data-hidden {% else %}
Expand Down Expand Up @@ -54,15 +60,15 @@ <h3 class="margin-r-m">

<dt>Min | Max</dt>
<dd>
{{ column.quantiles[0.0] | format_number }} |
{{ column.quantiles[1.0] | format_number }}
<span class="min-value">{{ column.quantiles[0.0] | format_number }}</span> |
<span class="max-value">{{ column.quantiles[1.0] | format_number }}</span>
{{ unit }}
</dd>
{% elif "min" in column %}
<dt>Min | Max</dt>
<dd>
{{ column.min | format_number }} |
{{ column.max | format_number }}
<span class="min-value">{{ column.min | format_number }}</span> |
<span class="max-value">{{ column.max | format_number }}</span>
{{ unit }}
</dd>
{% endif %}
Expand Down
69 changes: 65 additions & 4 deletions skrub/_reporting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
COLORS = _SEABORN
COLOR_0 = COLORS[0]

_RED = "#dd0000"


def _plot(plotting_fun):
"""Set the maptlotib config & silence some warnings for all report plots.
Expand Down Expand Up @@ -115,6 +117,67 @@ def _adjust_fig_size(fig, ax, target_w, target_h):
fig.set_size_inches((w, h))


def _get_range(values, frac=0.2, factor=3.0):
min_value, low_p, high_p, max_value = np.quantile(
values, [0.0, frac, 1.0 - frac, 1.0]
)
delta = high_p - low_p
if not delta:
return min_value, max_value
margin = factor * delta
low = low_p - margin
high = high_p + margin

# Chosen low bound should be max(low, min_value). Moreover, we add a small
# tolerance: if the clipping value is close to the actual minimum, extend
# it (so we don't clip right above the minimum which looks a bit silly).
if low - margin * 0.15 < min_value:
low = min_value
if max_value < high + margin * 0.15:
high = max_value
return low, high


def _robust_hist(values, ax, color):
low, high = _get_range(values)
inliers = values[(low <= values) & (values <= high)]
n_low_outliers = (values < low).sum()
n_high_outliers = (high < values).sum()
n, bins, patches = ax.hist(inliers)
n_out = n_low_outliers + n_high_outliers
if not n_out:
return 0, 0
width = bins[1] - bins[0]
start, stop = bins[0], bins[-1]
line_params = dict(color=_RED, linestyle="--", ymax=0.95)
if n_low_outliers:
start = bins[0] - width
ax.stairs([n_low_outliers], [start, bins[0]], color=_RED, fill=True)
ax.axvline(bins[0], **line_params)
if n_high_outliers:
stop = bins[-1] + width
ax.stairs([n_high_outliers], [bins[-1], stop], color=_RED, fill=True)
ax.axvline(bins[-1], **line_params)
ax.text(
# we place the text offset from the left rather than centering it to
# make room for the factor matplotlib sometimes places on the right of
# the axis eg "1e6" when the ticks are labelled in millions.
0.15,
1.0,
(
f"{_utils.format_number(n_out)} outliers "
f"({_utils.format_percent(n_out / len(values))})"
),
transform=ax.transAxes,
ha="left",
va="baseline",
fontweight="bold",
color=_RED,
)
ax.set_xlim(start, stop)
return n_low_outliers, n_high_outliers


@_plot
def histogram(col, duration_unit=None, color=COLOR_0):
"""Histogram for a numeric column."""
Expand All @@ -125,17 +188,15 @@ def histogram(col, duration_unit=None, color=COLOR_0):
# version if there are inf or nan)
col = sbd.to_float32(col)
values = sbd.to_numpy(col)
if np.issubdtype(values.dtype, np.floating):
values = values[np.isfinite(values)]
fig, ax = plt.subplots()
_despine(ax)
ax.hist(values, color=color)
n_low_outliers, n_high_outliers = _robust_hist(values, ax, color=color)
if duration_unit is not None:
ax.set_xlabel(f"{duration_unit.capitalize()}s")
if sbd.is_any_date(col):
_rotate_ticklabels(ax)
_adjust_fig_size(fig, ax, 2.0, 1.0)
return _serialize(fig)
return _serialize(fig), n_low_outliers, n_high_outliers


@_plot
Expand Down
14 changes: 10 additions & 4 deletions skrub/_reporting/_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ def _add_datetime_summary(summary, column, with_plots):
summary["min"] = min_date.isoformat()
summary["max"] = max_date.isoformat()
if with_plots:
summary["histogram_plot"] = _plotting.histogram(
column, color=_plotting.COLORS[0]
)
(
summary["histogram_plot"],
summary["n_low_outliers"],
summary["n_high_outliers"],
) = _plotting.histogram(column, color=_plotting.COLORS[0])


def _add_numeric_summary(
Expand Down Expand Up @@ -243,7 +245,11 @@ def _add_numeric_summary(
if not with_plots:
return
if order_by_column is None:
summary["histogram_plot"] = _plotting.histogram(
(
summary["histogram_plot"],
summary["n_low_outliers"],
summary["n_high_outliers"],
) = _plotting.histogram(
column, duration_unit=duration_unit, color=_plotting.COLORS[0]
)
else:
Expand Down
30 changes: 30 additions & 0 deletions skrub/_reporting/tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import pandas as pd

from skrub._reporting import _plotting


def test_histogram():
rng = np.random.default_rng(0)
x = rng.normal(size=200)
o = rng.uniform(-100, 100, size=10)

data = pd.Series(np.concatenate([x, o]))
_, n_low, n_high = _plotting.histogram(data)
assert (n_low, n_high) == (5, 4)

data = pd.Series(np.concatenate([x, o - 1000]))
_, n_low, n_high = _plotting.histogram(data)
assert (n_low, n_high) == (10, 0)

data = pd.Series(np.concatenate([x, o + 1000]))
_, n_low, n_high = _plotting.histogram(data)
assert (n_low, n_high) == (0, 10)

data = pd.Series(x)
_, n_low, n_high = _plotting.histogram(data)
assert (n_low, n_high) == (0, 0)

data = pd.Series([0.0])
_, n_low, n_high = _plotting.histogram(data)
assert (n_low, n_high) == (0, 0)

0 comments on commit a7abe4b

Please sign in to comment.