From a7abe4b6038648583d44f2dca8d84efc1e76a9ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Dock=C3=A8s?= Date: Wed, 27 Nov 2024 15:08:12 +0100 Subject: [PATCH] Clip the range of histograms when there are outliers (#1157) --- CHANGES.rst | 4 ++ .../_data/templates/column-summaries.css | 12 ++++ .../_data/templates/column-summary.html | 14 ++-- skrub/_reporting/_plotting.py | 69 +++++++++++++++++-- skrub/_reporting/_summarize.py | 14 ++-- skrub/_reporting/tests/test_plotting.py | 30 ++++++++ 6 files changed, 131 insertions(+), 12 deletions(-) create mode 100644 skrub/_reporting/tests/test_plotting.py diff --git a/CHANGES.rst b/CHANGES.rst index 4f58d2fd3..1cf702c64 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -86,6 +86,10 @@ Minor changes * The :class:`TableReport` now also reports the number of unique values for numeric columns. :pr:`1154` by :user:`Jérôme Dockès `. +* The :class:`TableReport`, when plotting histograms, now detects outliers and + clips the range of data shown in the histogram. This allows seeing more detail + in the shown distribution. :pr:`1157` by :user:`Jérôme Dockès `. + Bug fixes --------- diff --git a/skrub/_reporting/_data/templates/column-summaries.css b/skrub/_reporting/_data/templates/column-summaries.css index 2f04ab776..6314734fe 100644 --- a/skrub/_reporting/_data/templates/column-summaries.css +++ b/skrub/_reporting/_data/templates/column-summaries.css @@ -94,6 +94,18 @@ justify-self: left; } +/* Highlight min and max values when outliers were truncated from the histogram*/ + +.column-summary[data-has-low-outliers] .min-value{ + color: var(--red); + font-weight: bold; +} + +.column-summary[data-has-high-outliers] .max-value{ + color: var(--red); + font-weight: bold; +} + /* Aspects specific to the single card shown in the dataframe sample tab */ /* --------------------------------------------------------------------- */ diff --git a/skrub/_reporting/_data/templates/column-summary.html b/skrub/_reporting/_data/templates/column-summary.html index c01c3444a..06fb10f38 100644 --- a/skrub/_reporting/_data/templates/column-summary.html +++ b/skrub/_reporting/_data/templates/column-summary.html @@ -4,6 +4,12 @@ data-name-repr="{{ column.name.__repr__() }}" data-column-name="{{ column.name }}" data-column-idx="{{ column.idx }}" + {% if column['n_low_outliers'] %} + data-has-low-outliers + {% endif %} + {% if column['n_high_outliers'] %} + data-has-high-outliers + {% endif %} data-manager="FilterableColumn {% if in_sample_tab %}SampleColumnSummary{% endif %}" {% if column.value_is_constant %} data-constant-column {% endif %} {% if in_sample_tab %} data-role="sample-column" data-hidden {% else %} @@ -54,15 +60,15 @@

Min | Max
- {{ column.quantiles[0.0] | format_number }} | - {{ column.quantiles[1.0] | format_number }} + {{ column.quantiles[0.0] | format_number }} | + {{ column.quantiles[1.0] | format_number }} {{ unit }}
{% elif "min" in column %}
Min | Max
- {{ column.min | format_number }} | - {{ column.max | format_number }} + {{ column.min | format_number }} | + {{ column.max | format_number }} {{ unit }}
{% endif %} diff --git a/skrub/_reporting/_plotting.py b/skrub/_reporting/_plotting.py index d7476b376..f8f36143d 100644 --- a/skrub/_reporting/_plotting.py +++ b/skrub/_reporting/_plotting.py @@ -39,6 +39,8 @@ COLORS = _SEABORN COLOR_0 = COLORS[0] +_RED = "#dd0000" + def _plot(plotting_fun): """Set the maptlotib config & silence some warnings for all report plots. @@ -115,6 +117,67 @@ def _adjust_fig_size(fig, ax, target_w, target_h): fig.set_size_inches((w, h)) +def _get_range(values, frac=0.2, factor=3.0): + min_value, low_p, high_p, max_value = np.quantile( + values, [0.0, frac, 1.0 - frac, 1.0] + ) + delta = high_p - low_p + if not delta: + return min_value, max_value + margin = factor * delta + low = low_p - margin + high = high_p + margin + + # Chosen low bound should be max(low, min_value). Moreover, we add a small + # tolerance: if the clipping value is close to the actual minimum, extend + # it (so we don't clip right above the minimum which looks a bit silly). + if low - margin * 0.15 < min_value: + low = min_value + if max_value < high + margin * 0.15: + high = max_value + return low, high + + +def _robust_hist(values, ax, color): + low, high = _get_range(values) + inliers = values[(low <= values) & (values <= high)] + n_low_outliers = (values < low).sum() + n_high_outliers = (high < values).sum() + n, bins, patches = ax.hist(inliers) + n_out = n_low_outliers + n_high_outliers + if not n_out: + return 0, 0 + width = bins[1] - bins[0] + start, stop = bins[0], bins[-1] + line_params = dict(color=_RED, linestyle="--", ymax=0.95) + if n_low_outliers: + start = bins[0] - width + ax.stairs([n_low_outliers], [start, bins[0]], color=_RED, fill=True) + ax.axvline(bins[0], **line_params) + if n_high_outliers: + stop = bins[-1] + width + ax.stairs([n_high_outliers], [bins[-1], stop], color=_RED, fill=True) + ax.axvline(bins[-1], **line_params) + ax.text( + # we place the text offset from the left rather than centering it to + # make room for the factor matplotlib sometimes places on the right of + # the axis eg "1e6" when the ticks are labelled in millions. + 0.15, + 1.0, + ( + f"{_utils.format_number(n_out)} outliers " + f"({_utils.format_percent(n_out / len(values))})" + ), + transform=ax.transAxes, + ha="left", + va="baseline", + fontweight="bold", + color=_RED, + ) + ax.set_xlim(start, stop) + return n_low_outliers, n_high_outliers + + @_plot def histogram(col, duration_unit=None, color=COLOR_0): """Histogram for a numeric column.""" @@ -125,17 +188,15 @@ def histogram(col, duration_unit=None, color=COLOR_0): # version if there are inf or nan) col = sbd.to_float32(col) values = sbd.to_numpy(col) - if np.issubdtype(values.dtype, np.floating): - values = values[np.isfinite(values)] fig, ax = plt.subplots() _despine(ax) - ax.hist(values, color=color) + n_low_outliers, n_high_outliers = _robust_hist(values, ax, color=color) if duration_unit is not None: ax.set_xlabel(f"{duration_unit.capitalize()}s") if sbd.is_any_date(col): _rotate_ticklabels(ax) _adjust_fig_size(fig, ax, 2.0, 1.0) - return _serialize(fig) + return _serialize(fig), n_low_outliers, n_high_outliers @_plot diff --git a/skrub/_reporting/_summarize.py b/skrub/_reporting/_summarize.py index ebecddbb0..06237b2e7 100644 --- a/skrub/_reporting/_summarize.py +++ b/skrub/_reporting/_summarize.py @@ -209,9 +209,11 @@ def _add_datetime_summary(summary, column, with_plots): summary["min"] = min_date.isoformat() summary["max"] = max_date.isoformat() if with_plots: - summary["histogram_plot"] = _plotting.histogram( - column, color=_plotting.COLORS[0] - ) + ( + summary["histogram_plot"], + summary["n_low_outliers"], + summary["n_high_outliers"], + ) = _plotting.histogram(column, color=_plotting.COLORS[0]) def _add_numeric_summary( @@ -243,7 +245,11 @@ def _add_numeric_summary( if not with_plots: return if order_by_column is None: - summary["histogram_plot"] = _plotting.histogram( + ( + summary["histogram_plot"], + summary["n_low_outliers"], + summary["n_high_outliers"], + ) = _plotting.histogram( column, duration_unit=duration_unit, color=_plotting.COLORS[0] ) else: diff --git a/skrub/_reporting/tests/test_plotting.py b/skrub/_reporting/tests/test_plotting.py new file mode 100644 index 000000000..c49eb98aa --- /dev/null +++ b/skrub/_reporting/tests/test_plotting.py @@ -0,0 +1,30 @@ +import numpy as np +import pandas as pd + +from skrub._reporting import _plotting + + +def test_histogram(): + rng = np.random.default_rng(0) + x = rng.normal(size=200) + o = rng.uniform(-100, 100, size=10) + + data = pd.Series(np.concatenate([x, o])) + _, n_low, n_high = _plotting.histogram(data) + assert (n_low, n_high) == (5, 4) + + data = pd.Series(np.concatenate([x, o - 1000])) + _, n_low, n_high = _plotting.histogram(data) + assert (n_low, n_high) == (10, 0) + + data = pd.Series(np.concatenate([x, o + 1000])) + _, n_low, n_high = _plotting.histogram(data) + assert (n_low, n_high) == (0, 10) + + data = pd.Series(x) + _, n_low, n_high = _plotting.histogram(data) + assert (n_low, n_high) == (0, 0) + + data = pd.Series([0.0]) + _, n_low, n_high = _plotting.histogram(data) + assert (n_low, n_high) == (0, 0)