Skip to content

Commit

Permalink
fix: Plotting was not interacting well with Altair schema wrappers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Oct 13, 2024
1 parent 0628e0e commit 207ddb0
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 53 deletions.
2 changes: 1 addition & 1 deletion docs/source/user-guide/misc/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ This is shorthand for:
import altair as alt

(
alt.Chart(df).mark_point().encode(
alt.Chart(df).mark_point(tooltip=True).encode(
x="sepal_length",
y="sepal_width",
color="species",
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,17 +627,17 @@ def plot(self) -> DataFramePlot:
- `df.plot.line(**kwargs)`
is shorthand for
`alt.Chart(df).mark_line().encode(**kwargs).interactive()`
`alt.Chart(df).mark_line(tooltip=True).encode(**kwargs).interactive()`
- `df.plot.point(**kwargs)`
is shorthand for
`alt.Chart(df).mark_point().encode(**kwargs).interactive()` (and
`alt.Chart(df).mark_point(tooltip=True).encode(**kwargs).interactive()` (and
`plot.scatter` is provided as an alias)
- `df.plot.bar(**kwargs)`
is shorthand for
`alt.Chart(df).mark_bar().encode(**kwargs).interactive()`
`alt.Chart(df).mark_bar(tooltip=True).encode(**kwargs).interactive()`
- for any other attribute `attr`, `df.plot.attr(**kwargs)`
is shorthand for
`alt.Chart(df).mark_attr().encode(**kwargs).interactive()`
`alt.Chart(df).mark_attr(tooltip=True).encode(**kwargs).interactive()`
Examples
--------
Expand Down
35 changes: 12 additions & 23 deletions py-polars/polars/dataframe/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,6 @@
Encodings: TypeAlias = dict[str, Encoding]


def _maybe_extract_shorthand(encoding: Encoding) -> Encoding:
if isinstance(encoding, alt.SchemaBase):
# e.g. for `alt.X('x:Q', axis=alt.Axis(labelAngle=30))`, return `'x:Q'`
return getattr(encoding, "shorthand", encoding)
return encoding


def _add_tooltip(encodings: Encodings, /, **kwargs: Unpack[EncodeKwds]) -> None:
if "tooltip" not in kwargs:
encodings["tooltip"] = [
*[_maybe_extract_shorthand(x) for x in encodings.values()],
*[_maybe_extract_shorthand(x) for x in kwargs.values()], # type: ignore[arg-type]
] # type: ignore[assignment]


class DataFramePlot:
"""DataFrame.plot namespace."""

Expand Down Expand Up @@ -107,8 +92,11 @@ def bar(
encodings["y"] = y
if color is not None:
encodings["color"] = color
_add_tooltip(encodings, **kwargs)
return self._chart.mark_bar().encode(**encodings, **kwargs).interactive()
return (
self._chart.mark_bar(tooltip=True)
.encode(**encodings, **kwargs)
.interactive()
)

def line(
self,
Expand Down Expand Up @@ -169,8 +157,11 @@ def line(
encodings["color"] = color
if order is not None:
encodings["order"] = order
_add_tooltip(encodings, **kwargs)
return self._chart.mark_line().encode(**encodings, **kwargs).interactive()
return (
self._chart.mark_line(tooltip=True)
.encode(**encodings, **kwargs)
.interactive()
)

def point(
self,
Expand Down Expand Up @@ -231,9 +222,8 @@ def point(
encodings["color"] = color
if size is not None:
encodings["size"] = size
_add_tooltip(encodings, **kwargs)
return (
self._chart.mark_point()
self._chart.mark_point(tooltip=True)
.encode(
**encodings,
**kwargs,
Expand All @@ -252,7 +242,6 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
encodings: Encodings = {}

def func(**kwargs: EncodeKwds) -> alt.Chart:
_add_tooltip(encodings, **kwargs)
return method().encode(**encodings, **kwargs).interactive()
return method(tooltip=True).encode(**encodings, **kwargs).interactive()

return func
22 changes: 10 additions & 12 deletions py-polars/polars/series/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import TYPE_CHECKING, Callable

from polars.dataframe.plotting import _add_tooltip
from polars.dependencies import altair as alt

if TYPE_CHECKING:
Expand Down Expand Up @@ -42,7 +41,7 @@ def hist(
`Altair <https://altair-viz.github.io/>`_.
`s.plot.hist(**kwargs)` is shorthand for
`alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`,
`alt.Chart(s.to_frame()).mark_bar(tooltip=True).encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`,
and is provided for convenience - for full customisatibility, use a plotting
library directly.
Expand All @@ -69,9 +68,11 @@ def hist(
"x": alt.X(f"{self._series_name}:Q", bin=True),
"y": "count()",
}
_add_tooltip(encodings, **kwargs)
return (
alt.Chart(self._df).mark_bar().encode(**encodings, **kwargs).interactive()
alt.Chart(self._df)
.mark_bar(tooltip=True)
.encode(**encodings, **kwargs)
.interactive()
)

def kde(
Expand All @@ -86,7 +87,7 @@ def kde(
`Altair <https://altair-viz.github.io/>`_.
`s.plot.kde(**kwargs)` is shorthand for
`alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()`,
`alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area(tooltip=True).encode(x=s.name, y='density:Q', **kwargs).interactive()`,
and is provided for convenience - for full customisatibility, use a plotting
library directly.
Expand All @@ -110,11 +111,10 @@ def kde(
msg = "Cannot use `plot.kde` when Series name is `'density'`"
raise ValueError(msg)
encodings: Encodings = {"x": self._series_name, "y": "density:Q"}
_add_tooltip(encodings, **kwargs)
return (
alt.Chart(self._df)
.transform_density(self._series_name, as_=[self._series_name, "density"])
.mark_area()
.mark_area(tooltip=True)
.encode(**encodings, **kwargs)
.interactive()
)
Expand All @@ -131,7 +131,7 @@ def line(
`Altair <https://altair-viz.github.io/>`_.
`s.plot.line(**kwargs)` is shorthand for
`alt.Chart(s.to_frame().with_row_index()).mark_line().encode(x='index', y=s.name, **kwargs).interactive()`,
`alt.Chart(s.to_frame().with_row_index()).mark_line(tooltip=True).encode(x='index', y=s.name, **kwargs).interactive()`,
and is provided for convenience - for full customisatibility, use a plotting
library directly.
Expand All @@ -155,10 +155,9 @@ def line(
msg = "Cannot call `plot.line` when Series name is 'index'"
raise ValueError(msg)
encodings: Encodings = {"x": "index", "y": self._series_name}
_add_tooltip(encodings, **kwargs)
return (
alt.Chart(self._df.with_row_index())
.mark_line()
.mark_line(tooltip=True)
.encode(**encodings, **kwargs)
.interactive()
)
Expand All @@ -177,7 +176,6 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
encodings: Encodings = {"x": "index", "y": self._series_name}

def func(**kwargs: EncodeKwds) -> alt.Chart:
_add_tooltip(encodings, **kwargs)
return method().encode(**encodings, **kwargs).interactive()
return method(tooltip=True).encode(**encodings, **kwargs).interactive()

return func
6 changes: 3 additions & 3 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7477,13 +7477,13 @@ def plot(self) -> SeriesPlot:
- `s.plot.hist(**kwargs)`
is shorthand for
`alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`
`alt.Chart(s.to_frame()).mark_bar(tooltip=True).encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`
- `s.plot.kde(**kwargs)`
is shorthand for
`alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()`
`alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area(tooltip=True).encode(x=s.name, y='density:Q', **kwargs).interactive()`
- for any other attribute `attr`, `s.plot.attr(**kwargs)`
is shorthand for
`alt.Chart(s.to_frame().with_row_index()).mark_attr().encode(x='index', y=s.name, **kwargs).interactive()`
`alt.Chart(s.to_frame().with_row_index()).mark_attr(tooltip=True).encode(x='index', y=s.name, **kwargs).interactive()`
Examples
--------
Expand Down
13 changes: 3 additions & 10 deletions py-polars/tests/unit/operations/namespaces/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ def test_dataframe_plot_tooltip() -> None:
}
)
result = df.plot.line(x="length", y="width", color="species").to_dict()
assert result["encoding"]["tooltip"] == [
{"field": "length", "type": "quantitative"},
{"field": "width", "type": "quantitative"},
{"field": "species", "type": "nominal"},
]
assert result["mark"]["tooltip"] is True
result = df.plot.line(
x="length", y="width", color="species", tooltip=["length", "width"]
).to_dict()
Expand All @@ -54,10 +50,7 @@ def test_series_plot() -> None:
def test_series_plot_tooltip() -> None:
s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6])
result = s.plot.line().to_dict()
assert result["encoding"]["tooltip"] == [
{"field": "index", "type": "quantitative"},
{"field": "a", "type": "quantitative"},
]
assert result["mark"]["tooltip"] is True
result = s.plot.line(tooltip=["a"]).to_dict()
assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}]

Expand All @@ -73,4 +66,4 @@ def test_nameless_series() -> None:
def test_x_with_axis_18830() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
result = df.plot.line(x=alt.X("a", axis=alt.Axis(labelAngle=-90))).to_dict()
assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}]
assert result["mark"]["tooltip"] is True

0 comments on commit 207ddb0

Please sign in to comment.