Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586702633
  • Loading branch information
drewbryant authored and colaboratory-team committed Nov 30, 2023
1 parent bc057a3 commit 9bec4f8
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 153 deletions.
156 changes: 3 additions & 153 deletions google/colab/_quickchart.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,8 @@
"""Automated chart generation for data frames."""
import collections.abc
import itertools
import logging

import IPython
import numpy as np


# Numpy dtypes whose columns are treated as categorical during classification.
_CATEGORICAL_DTYPES = (
np.dtype('object'),
np.dtype('bool'),
)
# Datetime dtypes matched by exact dtype equality during classification.
_DEFAULT_DATETIME_DTYPE = np.dtype('datetime64[ns]') # a.k.a. "<M8[ns]".
_DATETIME_DTYPES = (_DEFAULT_DATETIME_DTYPE,)
# Fallback match on `dtype.kind` for datetime dtypes not listed above.
_DATETIME_DTYPE_KINDS = ('M',) # More general set of datetime dtypes.
# Column-name heuristics for "timelike" columns; names are lower-cased before
# being compared against these values.
_DATETIME_COLNAME_PATTERNS = (
'date',
'datetime',
'time',
'timestamp',
) # Prefix/suffix matches.
_DATETIME_COLNAMES = ('dt', 't', 'ts', 'year') # Exact matches.
# NOTE(review): not referenced by any code visible in this section.
_EXPECTED_DTYPES = _CATEGORICAL_DTYPES + _DATETIME_DTYPES
# Categorical columns with more unique values than this are reported as
# "large" categoricals.
_CATEGORICAL_LARGE_SIZE_THRESHOLD = 8 # Facet-friendly size limit.


# NOTE(review): presumably populated lazily elsewhere in this module; the
# initialisation is not visible here -- confirm.
_DATAFRAME_REGISTRY = None

Expand Down Expand Up @@ -85,9 +65,10 @@ def get_or_register_varname(self, _) -> str:
def determine_charts(df, dataframe_registry, max_chart_instances=None):
"""Finds charts compatible with dtypes of the given data frame."""
# Lazy import to avoid loading matplotlib and transitive deps on kernel init.
from google.colab import _quickchart_dtypes # pylint: disable=g-import-not-at-top
from google.colab import _quickchart_helpers # pylint: disable=g-import-not-at-top

dtype_groups = _classify_dtypes(df)
dtype_groups = _quickchart_dtypes.classify_dtypes(df)
numeric_cols = dtype_groups['numeric']
categorical_cols = dtype_groups['categorical']
time_cols = dtype_groups['datetime'] + dtype_groups['timelike']
Expand Down Expand Up @@ -214,134 +195,3 @@ def _select_time_series_cols(time_cols, numeric_cols, categorical_cols, k=None):
),
k,
)


def _classify_dtypes(
    df,
    categorical_dtypes=_CATEGORICAL_DTYPES,
    datetime_dtypes=_DATETIME_DTYPES,
    datetime_dtype_kinds=_DATETIME_DTYPE_KINDS,
    categorical_size_threshold=_CATEGORICAL_LARGE_SIZE_THRESHOLD,
):
  """Classifies each dataframe series into a datatype group.

  Every column lands in exactly one of the 'numeric', 'categorical',
  'large_categorical', 'datetime', 'singleton' or 'filtered' groups. The
  'timelike' group is computed independently over all columns, so a column may
  appear there in addition to its primary group.

  Args:
    df: (pd.DataFrame) A dataframe.
    categorical_dtypes: (iterable<np.dtype>) Categorical data types.
    datetime_dtypes: (iterable<np.dtype>) Datetime data types.
    datetime_dtype_kinds: (iterable<str>) Datetime dtype.kind values.
    categorical_size_threshold: (int) The max number of unique values for a
      given categorical to be considered "small".

  Returns:
    ({str: list<str>}) A dict mapping a dtype group name to the corresponding
    column names.
  """
  # Lazy import to avoid loading pandas and transitive deps on kernel init.
  from pandas.api.types import is_hashable  # pylint: disable=g-import-not-at-top
  from pandas.api.types import is_numeric_dtype  # pylint: disable=g-import-not-at-top

  filtered_cols = []
  numeric_cols = []
  small_cat_cols, large_cat_cols = [], []
  datetime_cols = []
  timelike_cols = []
  singleton_cols = []

  # Iterate the dtype Series directly. Building a frame and calling
  # reset_index() only yields an 'index' column when the columns index is
  # unnamed and single-level, so named or MultiIndex columns used to fail.
  for colname, colname_dtype in df.dtypes.items():
    series = df[colname]
    # Unhashable values (lists, dicts, ...) cannot be passed to unique().
    if not all(series.apply(is_hashable)):
      filtered_cols.append(colname)
      continue
    # Computed once and reused for the singleton and small/large decisions.
    num_unique = len(series.unique())
    if num_unique <= 1:
      singleton_cols.append(colname)
    elif colname_dtype in categorical_dtypes:
      if num_unique <= categorical_size_threshold:
        small_cat_cols.append(colname)
      else:
        large_cat_cols.append(colname)
    elif (colname_dtype in datetime_dtypes) or (
        colname_dtype.kind in datetime_dtype_kinds
    ):
      datetime_cols.append(colname)
    elif is_numeric_dtype(colname_dtype):
      numeric_cols.append(colname)
    else:
      filtered_cols.append(colname)

  if filtered_cols:
    # Pass the list itself; wrapping it in a tuple rendered as "([...],)".
    logging.warning(
        'Quickchart encountered unexpected dtypes in columns: "%r"',
        filtered_cols,
    )

  def _matches_datetime_pattern(colname):
    colname = str(colname).lower()
    return (
        colname.startswith(_DATETIME_COLNAME_PATTERNS)
        or colname.endswith(_DATETIME_COLNAME_PATTERNS)
        or colname in _DATETIME_COLNAMES
    )

  # "Timelike" is a heuristic overlay: a date-ish column name or a
  # non-decreasing numeric column, provided every value is a scalar.
  for colname in df.columns:
    if (
        _matches_datetime_pattern(colname)
        or _is_monotonically_increasing_numeric(df[colname])
    ) and _all_values_scalar(df[colname]):
      timelike_cols.append(colname)

  return {
      'numeric': numeric_cols,
      'categorical': small_cat_cols,
      'large_categorical': large_cat_cols,
      'datetime': datetime_cols,
      'timelike': timelike_cols,
      'singleton': singleton_cols,
      'filtered': filtered_cols,
  }


def _is_monotonically_increasing_numeric(series):
  """Returns whether `series` is numeric with non-decreasing values.

  Args:
    series: (pd.Series) A data series.

  Returns:
    False when the dtype is not a numpy dtype or not a numpy number type;
    otherwise the result of checking that each value is <= its successor.
  """
  dtype = series.dtype
  # Pandas extension dtypes do not extend numpy's dtype and will fail if passed
  # into issubdtype.
  if not isinstance(dtype, np.dtype):
    return False
  if not np.issubdtype(dtype.base, np.number):
    return False
  values = np.array(series)
  return np.all(values[:-1] <= values[1:])


def _all_values_scalar(series):
  """Returns True when no value in `series` is a non-string iterable.

  Args:
    series: An iterable of values (e.g. a pd.Series).

  Returns:
    (bool) False as soon as a value is found that is iterable and is neither
    `bytes` nor `str`; True otherwise, including for empty input.
  """
  for value in series:
    # Strings and bytes are iterable but count as scalars here.
    if isinstance(value, (bytes, str)):
      continue
    if isinstance(value, collections.abc.Iterable):
      return False
  return True


def _get_axis_bounds(series, padding_percent=0.05, zero_rtol=1e-3):
  """Gets the min/max axis bounds for a given data series.

  Args:
    series: (pd.Series) A data series.
    padding_percent: (float) The amount of padding to add to the minimal domain
      extent as a percentage of the domain size.
    zero_rtol: (float) Passed as `rtol` to `np.allclose` when testing whether
      a bound is close to zero; bounds that pass the test are not padded.

  Returns:
    (<float> min_bound, <float> max_bound)
  """
  lower = series.min()
  upper = series.max()
  margin = (upper - lower) * padding_percent

  def _is_near_zero(bound):
    return np.allclose(0, bound, rtol=zero_rtol)

  # A bound sitting at zero is left unpadded for aesthetics.
  if not _is_near_zero(lower):
    lower -= margin
  if not _is_near_zero(upper):
    upper += margin
  return lower, upper
140 changes: 140 additions & 0 deletions google/colab/_quickchart_dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Chart datatype inference utilities."""

import collections.abc
import logging

import numpy as np


# Numpy dtypes whose columns are treated as categorical during classification.
_CATEGORICAL_DTYPES = (
np.dtype('object'),
np.dtype('bool'),
)
# Datetime dtypes matched by exact dtype equality during classification.
_DEFAULT_DATETIME_DTYPE = np.dtype('datetime64[ns]') # a.k.a. "<M8[ns]".
_DATETIME_DTYPES = (_DEFAULT_DATETIME_DTYPE,)
# Fallback match on `dtype.kind` for datetime dtypes not listed above.
_DATETIME_DTYPE_KINDS = ('M',) # More general set of datetime dtypes.
# Column-name heuristics for "timelike" columns; names are lower-cased before
# being compared against these values.
_DATETIME_COLNAME_PATTERNS = (
'date',
'datetime',
'time',
'timestamp',
) # Prefix/suffix matches.
_DATETIME_COLNAMES = ('dt', 't', 'ts', 'year') # Exact matches.
# NOTE(review): not referenced by any code visible in this section.
_EXPECTED_DTYPES = _CATEGORICAL_DTYPES + _DATETIME_DTYPES
# Categorical columns with more unique values than this are reported as
# "large" categoricals.
_CATEGORICAL_LARGE_SIZE_THRESHOLD = 8 # Facet-friendly size limit.


def is_categorical(series):
  """Returns whether `series` is a small categorical column.

  Args:
    series: (pd.Series) A data series.

  Returns:
    (bool) True when the series dtype is one of `_CATEGORICAL_DTYPES` and its
    number of unique values does not exceed
    `_CATEGORICAL_LARGE_SIZE_THRESHOLD`.
  """
  # Check the dtype first so unique() is only computed for candidate columns.
  if series.dtype not in _CATEGORICAL_DTYPES:
    return False
  num_unique = len(series.unique())
  return num_unique <= _CATEGORICAL_LARGE_SIZE_THRESHOLD


def classify_dtypes(
    df,
    categorical_dtypes=_CATEGORICAL_DTYPES,
    datetime_dtypes=_DATETIME_DTYPES,
    datetime_dtype_kinds=_DATETIME_DTYPE_KINDS,
    categorical_size_threshold=_CATEGORICAL_LARGE_SIZE_THRESHOLD,
):
  """Classifies each dataframe series into a datatype group.

  Every column lands in exactly one of the 'numeric', 'categorical',
  'large_categorical', 'datetime', 'singleton' or 'filtered' groups. The
  'timelike' group is computed independently over all columns, so a column may
  appear there in addition to its primary group.

  Args:
    df: (pd.DataFrame) A dataframe.
    categorical_dtypes: (iterable<np.dtype>) Categorical data types.
    datetime_dtypes: (iterable<np.dtype>) Datetime data types.
    datetime_dtype_kinds: (iterable<str>) Datetime dtype.kind values.
    categorical_size_threshold: (int) The max number of unique values for a
      given categorical to be considered "small".

  Returns:
    ({str: list<str>}) A dict mapping a dtype group name to the corresponding
    column names.
  """
  # Lazy import to avoid loading pandas and transitive deps on kernel init.
  from pandas.api.types import is_hashable  # pylint: disable=g-import-not-at-top
  from pandas.api.types import is_numeric_dtype  # pylint: disable=g-import-not-at-top

  filtered_cols = []
  numeric_cols = []
  small_cat_cols, large_cat_cols = [], []
  datetime_cols = []
  timelike_cols = []
  singleton_cols = []

  # Iterate the dtype Series directly. Building a frame and calling
  # reset_index() only yields an 'index' column when the columns index is
  # unnamed and single-level, so named or MultiIndex columns used to fail.
  for colname, colname_dtype in df.dtypes.items():
    series = df[colname]
    # Unhashable values (lists, dicts, ...) cannot be passed to unique().
    if not all(series.apply(is_hashable)):
      filtered_cols.append(colname)
      continue
    # Computed once and reused for the singleton and small/large decisions.
    num_unique = len(series.unique())
    if num_unique <= 1:
      singleton_cols.append(colname)
    elif colname_dtype in categorical_dtypes:
      if num_unique <= categorical_size_threshold:
        small_cat_cols.append(colname)
      else:
        large_cat_cols.append(colname)
    elif (colname_dtype in datetime_dtypes) or (
        colname_dtype.kind in datetime_dtype_kinds
    ):
      datetime_cols.append(colname)
    elif is_numeric_dtype(colname_dtype):
      numeric_cols.append(colname)
    else:
      filtered_cols.append(colname)

  if filtered_cols:
    # Pass the list itself; wrapping it in a tuple rendered as "([...],)".
    logging.warning(
        'Quickchart encountered unexpected dtypes in columns: "%r"',
        filtered_cols,
    )

  def _matches_datetime_pattern(colname):
    colname = str(colname).lower()
    return (
        colname.startswith(_DATETIME_COLNAME_PATTERNS)
        or colname.endswith(_DATETIME_COLNAME_PATTERNS)
        or colname in _DATETIME_COLNAMES
    )

  # "Timelike" is a heuristic overlay: a date-ish column name or a
  # non-decreasing numeric column, provided every value is a scalar.
  for colname in df.columns:
    if (
        _matches_datetime_pattern(colname)
        or _is_monotonically_increasing_numeric(df[colname])
    ) and _all_values_scalar(df[colname]):
      timelike_cols.append(colname)

  return {
      'numeric': numeric_cols,
      'categorical': small_cat_cols,
      'large_categorical': large_cat_cols,
      'datetime': datetime_cols,
      'timelike': timelike_cols,
      'singleton': singleton_cols,
      'filtered': filtered_cols,
  }


def _is_monotonically_increasing_numeric(series):
  """Returns whether `series` is numeric with non-decreasing values.

  Args:
    series: (pd.Series) A data series.

  Returns:
    False when the dtype is not a numpy dtype or not a numpy number type;
    otherwise the result of checking that each value is <= its successor.
  """
  dtype = series.dtype
  # Pandas extension dtypes do not extend numpy's dtype and will fail if passed
  # into issubdtype.
  if not isinstance(dtype, np.dtype):
    return False
  if not np.issubdtype(dtype.base, np.number):
    return False
  values = np.array(series)
  return np.all(values[:-1] <= values[1:])


def _all_values_scalar(series):
  """Returns True when no value in `series` is a non-string iterable.

  Args:
    series: An iterable of values (e.g. a pd.Series).

  Returns:
    (bool) False as soon as a value is found that is iterable and is neither
    `bytes` nor `str`; True otherwise, including for empty input.
  """
  for value in series:
    # Strings and bytes are iterable but count as scalars here.
    if isinstance(value, (bytes, str)):
      continue
    if isinstance(value, collections.abc.Iterable):
      return False
  return True

0 comments on commit 9bec4f8

Please sign in to comment.