Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grouped barchart for <= 12 rows #13

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,8 @@ venv.bak/

# mypy
.mypy_cache/


# specific
examples/_matplotlib.ipynb
.vscode/*
24 changes: 23 additions & 1 deletion altair_pandas/_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import altair as alt
import pandas as pd

CHART_GROUPED_MAX = 12 # TODO: arbitrary - for discussion


def _valid_column(column_name):
return str(column_name)
Expand Down Expand Up @@ -97,6 +99,9 @@ def _preprocess_data(self, with_index=True, usecols=None):
if isinstance(data.index, pd.MultiIndex):
data.index = pd.Index(
[str(i) for i in data.index], name=data.index.name)

if data.index.name is None:
data.index.name = 'index'
return data.reset_index()
return data

Expand Down Expand Up @@ -136,13 +141,30 @@ def area(self, x=None, y=None, **kwargs):

# TODO: bars should be grouped, not stacked.
def bar(self, x=None, y=None, **kwargs):
return self._xy(

chart = self._xy(
{'type': 'bar', 'orient': 'vertical'}, x, y, **kwargs)

if len(self._data) <= CHART_GROUPED_MAX:

return chart.encode(
x=alt.X('column:N', title=None),
column=x or self._data.index.name or 'index'
)

return chart

def barh(self, x=None, y=None, **kwargs):
chart = self._xy(
{'type': 'bar', 'orient': 'horizontal'}, x, y, **kwargs)
chart.encoding.x, chart.encoding.y = chart.encoding.y, chart.encoding.x

if len(self._data) <= CHART_GROUPED_MAX:
return chart.encode(
y=alt.Y('column:N', title=None),
row=x or self._data.index.name or 'index'
)

return chart

def scatter(self, x, y, c=None, s=None, **kwargs):
Expand Down
26 changes: 24 additions & 2 deletions altair_pandas/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ def dataframe():
return pd.DataFrame({'x': range(5), 'y': range(5)})


@pytest.fixture
def large_dataframe():
return pd.DataFrame({'x': range(20), 'y': range(20)})


@pytest.mark.parametrize('data', [
pd.Series(
range(6),
Expand All @@ -27,9 +32,13 @@ def test_multiindex(data, with_plotting_backend):
chart = data.plot.bar()
spec = chart.to_dict()
assert list(chart.data.iloc[:, 0]) == [str(i) for i in data.index]
assert spec['encoding']['x']['field'] == 'index'
assert spec['encoding']['x']['type'] == 'nominal'

if isinstance(data, pd.Series):
assert spec['encoding']['x']['field'] == 'index'
else:
assert spec['encoding']['x']['field'] == 'column'


def test_nonstring_column_names(with_plotting_backend):
data = pd.DataFrame(np.ones((3, 4)))
Expand Down Expand Up @@ -65,9 +74,11 @@ def test_dataframe_basic_plot(dataframe, kind, with_plotting_backend):
spec = chart.to_dict()
if kind == 'bar':
assert spec['mark'] == {'type': 'bar', 'orient': 'vertical'}
assert spec['encoding']['x']['field'] == 'column'
else:
assert spec['mark'] == kind
assert spec['encoding']['x']['field'] == 'index'
assert spec['encoding']['x']['field'] == 'index'

assert spec['encoding']['y']['field'] == 'value'
assert spec['encoding']['color']['field'] == 'column'
assert spec['transform'][0]['fold'] == ['x', 'y']
Expand All @@ -85,6 +96,17 @@ def test_dataframe_barh(dataframe, with_plotting_backend):
chart = dataframe.plot.barh()
spec = chart.to_dict()
assert spec['mark'] == {'type': 'bar', 'orient': 'horizontal'}
assert spec['encoding']['y']['field'] == 'column'
assert spec['encoding']['x']['field'] == 'value'
assert spec['encoding']['color']['field'] == 'column'
assert spec['transform'][0]['fold'] == ['x', 'y']
assert spec['encoding']['row']['field'] == 'index'


def test_dataframe_large_barh(large_dataframe, with_plotting_backend):
chart = large_dataframe.plot.barh()
spec = chart.to_dict()
assert spec['mark'] == {'type': 'bar', 'orient': 'horizontal'}
assert spec['encoding']['y']['field'] == 'index'
assert spec['encoding']['x']['field'] == 'value'
assert spec['encoding']['color']['field'] == 'column'
Expand Down
Loading