diff --git a/google/colab/_quickchart.py b/google/colab/_quickchart.py index 58633631..bde9bc43 100644 --- a/google/colab/_quickchart.py +++ b/google/colab/_quickchart.py @@ -1,28 +1,8 @@ """Automated chart generation for data frames.""" -import collections.abc import itertools -import logging import IPython -import numpy as np - - -_CATEGORICAL_DTYPES = ( - np.dtype('object'), - np.dtype('bool'), -) -_DEFAULT_DATETIME_DTYPE = np.dtype('datetime64[ns]') # a.k.a. " str: def determine_charts(df, dataframe_registry, max_chart_instances=None): """Finds charts compatible with dtypes of the given data frame.""" # Lazy import to avoid loading matplotlib and transitive deps on kernel init. + from google.colab import _quickchart_dtypes # pylint: disable=g-import-not-at-top from google.colab import _quickchart_helpers # pylint: disable=g-import-not-at-top - dtype_groups = _classify_dtypes(df) + dtype_groups = _quickchart_dtypes.classify_dtypes(df) numeric_cols = dtype_groups['numeric'] categorical_cols = dtype_groups['categorical'] time_cols = dtype_groups['datetime'] + dtype_groups['timelike'] @@ -214,134 +195,3 @@ def _select_time_series_cols(time_cols, numeric_cols, categorical_cols, k=None): ), k, ) - - -def _classify_dtypes( - df, - categorical_dtypes=_CATEGORICAL_DTYPES, - datetime_dtypes=_DATETIME_DTYPES, - datetime_dtype_kinds=_DATETIME_DTYPE_KINDS, - categorical_size_threshold=_CATEGORICAL_LARGE_SIZE_THRESHOLD, -): - """Classifies each dataframe series into a datatype group. - - Args: - df: (pd.DataFrame) A dataframe. - categorical_dtypes: (iterable) Categorical data types. - datetime_dtypes: (iterable) Datetime data types. - datetime_dtype_kinds: (iterable) Datetime dtype.kind values. - categorical_size_threshold: (int) The max number of unique values for a - given categorical to be considered "small". - - Returns: - ({str: list}) A dict mapping a dtype name to the corresponding - column names. - """ - # Lazy import to avoid loading pandas and transitive deps on kernel init. - import pandas as pd # pylint: disable=g-import-not-at-top - from pandas.api.types import is_numeric_dtype # pylint: disable=g-import-not-at-top - - dtypes = ( - pd.DataFrame(df.dtypes, columns=['colname_dtype']) - .reset_index() - .rename(columns={'index': 'colname'}) - ) - - filtered_cols = [] - numeric_cols = [] - cat_cols = [] - datetime_cols = [] - timelike_cols = [] - singleton_cols = [] - for colname, colname_dtype in zip(dtypes.colname, dtypes.colname_dtype): - if not all(df[colname].apply(pd.api.types.is_hashable)): - filtered_cols.append(colname) - elif len(df[colname].unique()) <= 1: - singleton_cols.append(colname) - elif colname_dtype in categorical_dtypes: - cat_cols.append(colname) - elif (colname_dtype in datetime_dtypes) or ( - colname_dtype.kind in datetime_dtype_kinds - ): - datetime_cols.append(colname) - elif is_numeric_dtype(colname_dtype): - numeric_cols.append(colname) - else: - filtered_cols.append(colname) - if filtered_cols: - logging.warning( - 'Quickchart encountered unexpected dtypes in columns: "%r"', - (filtered_cols,), - ) - - small_cat_cols, large_cat_cols = [], [] - for colname in cat_cols: - if len(df[colname].unique()) <= categorical_size_threshold: - small_cat_cols.append(colname) - else: - large_cat_cols.append(colname) - - def _matches_datetime_pattern(colname): - colname = str(colname).lower() - return any( - colname.startswith(p) or colname.endswith(p) - for p in _DATETIME_COLNAME_PATTERNS - ) or any(colname == c for c in _DATETIME_COLNAMES) - - for colname in df.columns: - if ( - _matches_datetime_pattern(colname) - or _is_monotonically_increasing_numeric(df[colname]) - ) and _all_values_scalar(df[colname]): - timelike_cols.append(colname) - - return { - 'numeric': numeric_cols, - 'categorical': small_cat_cols, - 'large_categorical': large_cat_cols, - 'datetime': datetime_cols, - 'timelike': timelike_cols, - 'singleton': singleton_cols, - 'filtered': filtered_cols, - } - - -def _is_monotonically_increasing_numeric(series): - # Pandas extension dtypes do not extend numpy's dtype and will fail if passed - # into issubdtype. - if not isinstance(series.dtype, np.dtype): - return False - return np.issubdtype(series.dtype.base, np.number) and np.all( - np.array(series)[:-1] <= np.array(series)[1:] - ) - - -def _all_values_scalar(series): - def _is_non_scalar(x): - return isinstance(x, collections.abc.Iterable) and not isinstance( - x, (bytes, str) - ) - - return not any(_is_non_scalar(x) for x in series) - - -def _get_axis_bounds(series, padding_percent=0.05, zero_rtol=1e-3): - """Gets the min/max axis bounds for a given data series. - - Args: - series: (pd.Series) A data series. - padding_percent: (float) The amount of padding to add to the minimal domain - extent as a percentage of the domain size. - zero_rtol: (float) If either min or max bound is within this relative - tolerance to zero, don't add padding for aesthetics. - - Returns: - ( min_bound, max_bound) - """ - min_bound, max_bound = series.min(), series.max() - padding = (max_bound - min_bound) * padding_percent - if not np.allclose(0, min_bound, rtol=zero_rtol): - min_bound -= padding - if not np.allclose(0, max_bound, rtol=zero_rtol): - max_bound += padding - return min_bound, max_bound diff --git a/google/colab/_quickchart_dtypes.py b/google/colab/_quickchart_dtypes.py new file mode 100644 index 00000000..0f352678 --- /dev/null +++ b/google/colab/_quickchart_dtypes.py @@ -0,0 +1,140 @@ +"""Chart datatype inference utilities.""" + +import collections.abc +import logging + +import numpy as np + + +_CATEGORICAL_DTYPES = ( + np.dtype('object'), + np.dtype('bool'), +) +_DEFAULT_DATETIME_DTYPE = np.dtype('datetime64[ns]') # a.k.a. ") Categorical data types. + datetime_dtypes: (iterable) Datetime data types. + datetime_dtype_kinds: (iterable) Datetime dtype.kind values. + categorical_size_threshold: (int) The max number of unique values for a + given categorical to be considered "small". + + Returns: + ({str: list}) A dict mapping a dtype name to the corresponding + column names. + """ + # Lazy import to avoid loading pandas and transitive deps on kernel init. + import pandas as pd # pylint: disable=g-import-not-at-top + from pandas.api.types import is_numeric_dtype # pylint: disable=g-import-not-at-top + + dtypes = ( + pd.DataFrame(df.dtypes, columns=['colname_dtype']) + .reset_index() + .rename(columns={'index': 'colname'}) + ) + + filtered_cols = [] + numeric_cols = [] + cat_cols = [] + datetime_cols = [] + timelike_cols = [] + singleton_cols = [] + for colname, colname_dtype in zip(dtypes.colname, dtypes.colname_dtype): + if not all(df[colname].apply(pd.api.types.is_hashable)): + filtered_cols.append(colname) + elif len(df[colname].unique()) <= 1: + singleton_cols.append(colname) + elif colname_dtype in categorical_dtypes: + cat_cols.append(colname) + elif (colname_dtype in datetime_dtypes) or ( + colname_dtype.kind in datetime_dtype_kinds + ): + datetime_cols.append(colname) + elif is_numeric_dtype(colname_dtype): + numeric_cols.append(colname) + else: + filtered_cols.append(colname) + if filtered_cols: + logging.warning( + 'Quickchart encountered unexpected dtypes in columns: "%r"', + (filtered_cols,), + ) + + small_cat_cols, large_cat_cols = [], [] + for colname in cat_cols: + if len(df[colname].unique()) <= categorical_size_threshold: + small_cat_cols.append(colname) + else: + large_cat_cols.append(colname) + + def _matches_datetime_pattern(colname): + colname = str(colname).lower() + return any( + colname.startswith(p) or colname.endswith(p) + for p in _DATETIME_COLNAME_PATTERNS + ) or any(colname == c for c in _DATETIME_COLNAMES) + + for colname in df.columns: + if ( + _matches_datetime_pattern(colname) + or _is_monotonically_increasing_numeric(df[colname]) + ) and _all_values_scalar(df[colname]): + timelike_cols.append(colname) + + return { + 'numeric': numeric_cols, + 'categorical': small_cat_cols, + 'large_categorical': large_cat_cols, + 'datetime': datetime_cols, + 'timelike': timelike_cols, + 'singleton': singleton_cols, + 'filtered': filtered_cols, + } + + +def _is_monotonically_increasing_numeric(series): + # Pandas extension dtypes do not extend numpy's dtype and will fail if passed + # into issubdtype. + if not isinstance(series.dtype, np.dtype): + return False + return np.issubdtype(series.dtype.base, np.number) and np.all( + np.array(series)[:-1] <= np.array(series)[1:] + ) + + +def _all_values_scalar(series): + def _is_non_scalar(x): + return isinstance(x, collections.abc.Iterable) and not isinstance( + x, (bytes, str) + ) + + return not any(_is_non_scalar(x) for x in series)