diff --git a/CHANGES.rst b/CHANGES.rst index 56a7835e3..51ab2bca4 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -96,7 +96,7 @@ Bug fixes * The :class:`TableReport` would raise an exception when a column contained infinite values. This has been fixed in :pr:`1150` by :user:`Jérôme Dockès - `. + ` and :pr:`1151` by Jérôme Dockès. Release 0.3.1 ============= diff --git a/skrub/_reporting/_plotting.py b/skrub/_reporting/_plotting.py index 2d7e6214d..1fcbea7af 100644 --- a/skrub/_reporting/_plotting.py +++ b/skrub/_reporting/_plotting.py @@ -9,6 +9,7 @@ import warnings import matplotlib +import numpy as np from matplotlib import pyplot as plt from skrub import _dataframe as sbd @@ -118,7 +119,14 @@ def _adjust_fig_size(fig, ax, target_w, target_h): def histogram(col, color=COLOR_0): """Histogram for a numeric column.""" col = sbd.drop_nulls(col) + if sbd.is_float(col): + # avoid any issues with pandas nullable dtypes + # (to_numpy can yield a numpy array with object dtype in old pandas + # version if there are inf or nan) + col = sbd.to_float32(col) values = sbd.to_numpy(col) + if np.issubdtype(values.dtype, np.floating): + values = values[np.isfinite(values)] fig, ax = plt.subplots() _despine(ax) ax.hist(values, color=color) diff --git a/skrub/_reporting/tests/test_table_report.py b/skrub/_reporting/tests/test_table_report.py index 2fae33397..859457e02 100644 --- a/skrub/_reporting/tests/test_table_report.py +++ b/skrub/_reporting/tests/test_table_report.py @@ -1,5 +1,6 @@ import json import re +import warnings from skrub import TableReport, ToDatetime from skrub import _dataframe as sbd @@ -99,3 +100,16 @@ def test_duplicate_columns(pd_module): df = pd_module.make_dataframe({"a": [1, 2], "b": [3, 4]}) df.columns = ["a", "a"] TableReport(df).html() + + +def test_infinite_values(df_module): + # Non-regression for https://github.com/skrub-data/skrub/issues/1134 + # (histogram plot failing with infinite values) + with warnings.catch_warnings(): + # convert_dtypes() emits spurious warning while deciding whether to cast to int + warnings.filterwarnings("ignore", message="invalid value encountered in cast") + df = df_module.make_dataframe( + dict(a=[float("inf"), 1.0, 2.0], b=[0.0, 1.0, 2.0]) + ) + + TableReport(df).html()