From 0f1eddad241de2848f3e883a91025dfc790832df Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 27 Aug 2024 10:10:57 +0200 Subject: [PATCH] feat(python!): Use Altair in DataFrame.plot (#17995) --- .github/workflows/benchmark.yml | 7 +- .github/workflows/test-coverage.yml | 7 +- .github/workflows/test-python.yml | 4 + docs/requirements.txt | 3 +- .../python/user-guide/misc/visualization.py | 112 +++++--- docs/user-guide/misc/visualization.md | 58 +++- py-polars/polars/dataframe/frame.py | 64 +++-- py-polars/polars/dataframe/plotting.py | 256 ++++++++++++++++++ py-polars/polars/dependencies.py | 10 +- py-polars/polars/meta/versions.py | 4 +- py-polars/polars/series/plotting.py | 172 ++++++++++++ py-polars/polars/series/series.py | 47 ++-- py-polars/pyproject.toml | 4 +- py-polars/requirements-dev.txt | 2 +- .../unit/operations/namespaces/test_plot.py | 42 ++- 15 files changed, 659 insertions(+), 133 deletions(-) create mode 100644 py-polars/polars/dataframe/plotting.py create mode 100644 py-polars/polars/series/plotting.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 335d0a32b754..ce06b799f909 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -48,7 +48,12 @@ jobs: - name: Install Python dependencies working-directory: py-polars - run: uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose + run: | + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose - name: Set up Rust run: rustup show diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index ce642bbec306..c774bdf864c9 100644 --- 
a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -103,7 +103,12 @@ jobs: - name: Install Python dependencies working-directory: py-polars - run: uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose + run: | + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose - name: Set up Rust run: rustup component add llvm-tools-preview diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index ca717ef191f2..089ccb9f553a 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -66,6 +66,10 @@ jobs: - name: Install Python dependencies run: | pip install uv + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose - name: Set up Rust diff --git a/docs/requirements.txt b/docs/requirements.txt index db32e0f1cd39..258288d07e16 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,10 +1,11 @@ +altair pandas pyarrow graphviz +hvplot matplotlib seaborn plotly -altair numba numpy diff --git a/docs/src/python/user-guide/misc/visualization.py b/docs/src/python/user-guide/misc/visualization.py index f04288cb7812..cd256127f1dd 100644 --- a/docs/src/python/user-guide/misc/visualization.py +++ b/docs/src/python/user-guide/misc/visualization.py @@ -3,30 +3,33 @@ path = "docs/data/iris.csv" -df = 
pl.scan_csv(path).group_by("species").agg(pl.col("petal_length").mean()).collect() +df = pl.read_csv(path) print(df) # --8<-- [end:dataframe] """ # --8<-- [start:hvplot_show_plot] -df.plot.bar( - x="species", - y="petal_length", +import hvplot.polars +df.hvplot.scatter( + x="sepal_width", + y="sepal_length", + by="species", width=650, ) # --8<-- [end:hvplot_show_plot] """ # --8<-- [start:hvplot_make_plot] -import hvplot +import hvplot.polars -plot = df.plot.bar( - x="species", - y="petal_length", +plot = df.hvplot.scatter( + x="sepal_width", + y="sepal_length", + by="species", width=650, ) -hvplot.save(plot, "docs/images/hvplot_bar.html") -with open("docs/images/hvplot_bar.html", "r") as f: +hvplot.save(plot, "docs/images/hvplot_scatter.html") +with open("docs/images/hvplot_scatter.html", "r") as f: chart_html = f.read() print(f"{chart_html}") # --8<-- [end:hvplot_make_plot] @@ -35,7 +38,12 @@ # --8<-- [start:matplotlib_show_plot] import matplotlib.pyplot as plt -plt.bar(x=df["species"], height=df["petal_length"]) +fig, ax = plt.subplots() +ax.scatter( + x=df["sepal_width"], + y=df["sepal_length"], + c=df["species"].cast(pl.Categorical).to_physical(), +) # --8<-- [end:matplotlib_show_plot] """ @@ -44,9 +52,14 @@ import matplotlib.pyplot as plt -plt.bar(x=df["species"], height=df["petal_length"]) -plt.savefig("docs/images/matplotlib_bar.png") -with open("docs/images/matplotlib_bar.png", "rb") as f: +fig, ax = plt.subplots() +ax.scatter( + x=df["sepal_width"], + y=df["sepal_length"], + c=df["species"].cast(pl.Categorical).to_physical(), +) +fig.savefig("docs/images/matplotlib_scatter.png") +with open("docs/images/matplotlib_scatter.png", "rb") as f: png = base64.b64encode(f.read()).decode() print(f'') # --8<-- [end:matplotlib_make_plot] @@ -54,24 +67,28 @@ """ # --8<-- [start:seaborn_show_plot] import seaborn as sns -sns.barplot( +sns.scatterplot( df, - x="species", - y="petal_length", + x="sepal_width", + y="sepal_length", + hue="species", ) # --8<-- 
[end:seaborn_show_plot] """ # --8<-- [start:seaborn_make_plot] import seaborn as sns +import matplotlib.pyplot as plt -sns.barplot( +fig, ax = plt.subplots() +ax = sns.scatterplot( df, - x="species", - y="petal_length", + x="sepal_width", + y="sepal_length", + hue="species", ) -plt.savefig("docs/images/seaborn_bar.png") -with open("docs/images/seaborn_bar.png", "rb") as f: +fig.savefig("docs/images/seaborn_scatter.png") +with open("docs/images/seaborn_scatter.png", "rb") as f: png = base64.b64encode(f.read()).decode() print(f'') # --8<-- [end:seaborn_make_plot] @@ -80,11 +97,12 @@ # --8<-- [start:plotly_show_plot] import plotly.express as px -px.bar( +px.scatter( df, - x="species", - y="petal_length", - width=400, + x="sepal_width", + y="sepal_length", + color="species", + width=650, ) # --8<-- [end:plotly_show_plot] """ @@ -92,39 +110,47 @@ # --8<-- [start:plotly_make_plot] import plotly.express as px -fig = px.bar( +fig = px.scatter( df, - x="species", - y="petal_length", + x="sepal_width", + y="sepal_length", + color="species", width=650, ) -fig.write_html("docs/images/plotly_bar.html", full_html=False, include_plotlyjs="cdn") -with open("docs/images/plotly_bar.html", "r") as f: +fig.write_html( + "docs/images/plotly_scatter.html", full_html=False, include_plotlyjs="cdn" +) +with open("docs/images/plotly_scatter.html", "r") as f: chart_html = f.read() print(f"{chart_html}") # --8<-- [end:plotly_make_plot] """ # --8<-- [start:altair_show_plot] -import altair as alt - -alt.Chart(df, width=700).mark_bar().encode(x="species:N", y="petal_length:Q") +( + df.plot.point( + x="sepal_length", + y="sepal_width", + color="species", + ) + .properties(width=500) + .configure_scale(zero=False) +) # --8<-- [end:altair_show_plot] """ # --8<-- [start:altair_make_plot] -import altair as alt - chart = ( - alt.Chart(df, width=600) - .mark_bar() - .encode( - x="species:N", - y="petal_length:Q", + df.plot.point( + x="sepal_length", + y="sepal_width", + color="species", ) + 
.properties(width=500) + .configure_scale(zero=False) ) -chart.save("docs/images/altair_bar.html") -with open("docs/images/altair_bar.html", "r") as f: +chart.save("docs/images/altair_scatter.html") +with open("docs/images/altair_scatter.html", "r") as f: chart_html = f.read() print(f"{chart_html}") # --8<-- [end:altair_make_plot] diff --git a/docs/user-guide/misc/visualization.md b/docs/user-guide/misc/visualization.md index 88dcd83a18a6..3f7574c07a2e 100644 --- a/docs/user-guide/misc/visualization.md +++ b/docs/user-guide/misc/visualization.md @@ -2,7 +2,8 @@ Data in a Polars `DataFrame` can be visualized using common visualization libraries. -We illustrate plotting capabilities using the Iris dataset. We scan a CSV and then do a group-by on the `species` column and get the mean of the `petal_length`. +We illustrate plotting capabilities using the Iris dataset. We read a CSV and then +plot one column against another, colored by a yet another column. {{code_block('user-guide/misc/visualization','dataframe',[])}} @@ -10,9 +11,39 @@ We illustrate plotting capabilities using the Iris dataset. We scan a CSV and th --8<-- "python/user-guide/misc/visualization.py:dataframe" ``` -## Built-in plotting with hvPlot +## Built-in plotting with Altair -Polars has a `plot` method to create interactive plots using [hvPlot](https://hvplot.holoviz.org/). +Polars has a `plot` method to create plots using [Altair](https://altair-viz.github.io/): + +{{code_block('user-guide/misc/visualization','altair_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:altair_make_plot" +``` + +This is shorthand for: + +```python +import altair as alt + +( + alt.Chart(df).mark_point().encode( + x="sepal_length", + y="sepal_width", + color="species", + ) + .properties(width=500) + .configure_scale(zero=False) +) +``` + +and is only provided for convenience, and to signal that Altair is known to work well with +Polars. 
+ +## hvPlot + +If you import `hvplot.polars`, then it registers a `hvplot` +method which you can use to create interactive plots using [hvPlot](https://hvplot.holoviz.org/). {{code_block('user-guide/misc/visualization','hvplot_show_plot',[])}} @@ -22,8 +53,12 @@ Polars has a `plot` method to create interactive plots using [hvPlot](https://hv ## Matplotlib -To create a bar chart we can pass columns of a `DataFrame` directly to Matplotlib as a `Series` for each column. Matplotlib does not have explicit support for Polars objects but Matplotlib can accept a Polars `Series` because it can convert each Series to a numpy array, which is zero-copy for numeric -data without null values. +To create a scatter plot we can pass columns of a `DataFrame` directly to Matplotlib as a `Series` for each column. +Matplotlib does not have explicit support for Polars objects but can accept a Polars `Series` by +converting it to a NumPy array (which is zero-copy for numeric data without null values). + +Note that because the column `'species'` isn't numeric, we need to first convert it to numeric values so that +it can be passed as an argument to `c`. {{code_block('user-guide/misc/visualization','matplotlib_show_plot',[])}} @@ -31,9 +66,10 @@ data without null values. --8<-- "python/user-guide/misc/visualization.py:matplotlib_make_plot" ``` -## Seaborn, Plotly & Altair +## Seaborn and Plotly -[Seaborn](https://seaborn.pydata.org/), [Plotly](https://plotly.com/) & [Altair](https://altair-viz.github.io/) can accept a Polars `DataFrame` by leveraging the [dataframe interchange protocol](https://data-apis.org/dataframe-api/), which offers zero-copy conversion where possible. +[Seaborn](https://seaborn.pydata.org/) and [Plotly](https://plotly.com/) can accept a Polars `DataFrame` by leveraging the [dataframe interchange protocol](https://data-apis.org/dataframe-api/), which offers zero-copy conversion where possible. Note +that the protocol does not support all Polars data types (e.g. 
`List`) so your mileage may vary here. ### Seaborn @@ -50,11 +86,3 @@ data without null values. ```python exec="on" session="user-guide/misc/visualization" --8<-- "python/user-guide/misc/visualization.py:plotly_make_plot" ``` - -### Altair - -{{code_block('user-guide/misc/visualization','altair_show_plot',[])}} - -```python exec="on" session="user-guide/misc/visualization" ---8<-- "python/user-guide/misc/visualization.py:altair_make_plot" -``` diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index d065bc24ce90..8a550f9e5904 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -66,6 +66,7 @@ from polars._utils.wrap import wrap_expr, wrap_ldf, wrap_s from polars.dataframe._html import NotebookFormatter from polars.dataframe.group_by import DynamicGroupBy, GroupBy, RollingGroupBy +from polars.dataframe.plotting import DataFramePlot from polars.datatypes import ( N_INFER_DEFAULT, Boolean, @@ -82,15 +83,15 @@ ) from polars.datatypes.group import INTEGER_DTYPES from polars.dependencies import ( + _ALTAIR_AVAILABLE, _GREAT_TABLES_AVAILABLE, - _HVPLOT_AVAILABLE, _PANDAS_AVAILABLE, _PYARROW_AVAILABLE, _check_for_numpy, _check_for_pandas, _check_for_pyarrow, + altair, great_tables, - hvplot, import_optional, ) from polars.dependencies import numpy as np @@ -123,7 +124,6 @@ import numpy.typing as npt import torch from great_tables import GT - from hvplot.plotting.core import hvPlotTabularPolars from xlsxwriter import Workbook, Worksheet from polars import DataType, Expr, LazyFrame, Series @@ -603,7 +603,7 @@ def _replace(self, column: str, new_column: Series) -> DataFrame: @property @unstable() - def plot(self) -> hvPlotTabularPolars: + def plot(self) -> DataFramePlot: """ Create a plot namespace. @@ -611,9 +611,28 @@ def plot(self) -> hvPlotTabularPolars: This functionality is currently considered **unstable**. It may be changed at any point without it being considered a breaking change. + .. 
versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + Polars does not implement plotting logic itself, but instead defers to - hvplot. Please see the `hvplot reference gallery `_ - for more information and documentation. + `Altair `_: + + - `df.plot.line(**kwargs)` + is shorthand for + `alt.Chart(df).mark_line().encode(**kwargs).interactive()` + - `df.plot.point(**kwargs)` + is shorthand for + `alt.Chart(df).mark_point().encode(**kwargs).interactive()` (and + `plot.scatter` is provided as an alias) + - `df.plot.bar(**kwargs)` + is shorthand for + `alt.Chart(df).mark_bar().encode(**kwargs).interactive()` + - for any other attribute `attr`, `df.plot.attr(**kwargs)` + is shorthand for + `alt.Chart(df).mark_attr().encode(**kwargs).interactive()` Examples -------- @@ -626,32 +645,37 @@ def plot(self) -> hvPlotTabularPolars: ... "species": ["setosa", "setosa", "versicolor"], ... } ... ) - >>> df.plot.scatter(x="length", y="width", by="species") # doctest: +SKIP + >>> df.plot.point(x="length", y="width", color="species") # doctest: +SKIP Line plot: >>> from datetime import date >>> df = pl.DataFrame( ... { - ... "date": [date(2020, 1, 2), date(2020, 1, 3), date(2020, 1, 4)], - ... "stock_1": [1, 4, 6], - ... "stock_2": [1, 5, 2], + ... "date": [date(2020, 1, 2), date(2020, 1, 3), date(2020, 1, 4)] * 2, + ... "price": [1, 4, 6, 1, 5, 2], + ... "stock": ["a", "a", "a", "b", "b", "b"], ... } ... ) - >>> df.plot.line(x="date", y=["stock_1", "stock_2"]) # doctest: +SKIP + >>> df.plot.line(x="date", y="price", color="stock") # doctest: +SKIP - For more info on what you can pass, you can use ``hvplot.help``: + Bar plot: - >>> import hvplot # doctest: +SKIP - >>> hvplot.help("scatter") # doctest: +SKIP + >>> df = pl.DataFrame( + ... { + ... 
class DataFramePlot:
    """
    DataFrame.plot namespace.

    Wraps an `alt.Chart` built from the DataFrame. `bar`, `line` and `point`
    (alias `scatter`) are explicit helpers; any other attribute `attr` is
    resolved dynamically to Altair's `mark_<attr>` method via `__getattr__`.
    """

    def __init__(self, df: DataFrame) -> None:
        # Imported here rather than at module level, so Altair is only needed
        # once a plot namespace is actually created.
        import altair as alt

        self._chart = alt.Chart(df)

    # NOTE: in the methods below the named channels are positional-only (`/`).
    # A channel passed by keyword (e.g. `x="day"`) therefore lands in `kwargs`
    # and is forwarded to `encode` from there.
    def bar(
        self,
        x: ChannelX | None = None,
        y: ChannelY | None = None,
        color: ChannelColor | None = None,
        tooltip: ChannelTooltip | None = None,
        /,
        **kwargs: Unpack[EncodeKwds],
    ) -> alt.Chart:
        """
        Draw bar plot.

        Polars does not implement plotting logic itself but instead defers to
        `Altair <https://altair-viz.github.io/>`_.

        `df.plot.bar(**kwargs)` is shorthand for
        `alt.Chart(df).mark_bar().encode(**kwargs).interactive()`,
        and is provided for convenience - for full customisability, use a plotting
        library directly.

        .. versionchanged:: 1.6.0
            In prior versions of Polars, HvPlot was the plotting backend. If you would
            like to restore the previous plotting functionality, all you need to do
            is add `import hvplot.polars` at the top of your script and replace
            `df.plot` with `df.hvplot`.

        Parameters
        ----------
        x
            Column with x-coordinates of bars.
        y
            Column with y-coordinates of bars.
        color
            Column to color bars by.
        tooltip
            Columns to show values of when hovering over bars with pointer.
        **kwargs
            Additional keyword arguments passed to Altair.

        Examples
        --------
        >>> df = pl.DataFrame(
        ...     {
        ...         "day": ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] * 2,
        ...         "group": ["a"] * 7 + ["b"] * 7,
        ...         "value": [1, 3, 2, 4, 5, 6, 1, 1, 3, 2, 4, 5, 1, 2],
        ...     }
        ... )
        >>> df.plot.bar(
        ...     x="day", y="value", color="day", column="group"
        ... )  # doctest: +SKIP
        """
        # Only forward channels that were actually supplied positionally.
        encodings: Encodings = {}
        if x is not None:
            encodings["x"] = x
        if y is not None:
            encodings["y"] = y
        if color is not None:
            encodings["color"] = color
        if tooltip is not None:
            encodings["tooltip"] = tooltip
        return self._chart.mark_bar().encode(**encodings, **kwargs).interactive()

    def line(
        self,
        x: ChannelX | None = None,
        y: ChannelY | None = None,
        color: ChannelColor | None = None,
        order: ChannelOrder | None = None,
        tooltip: ChannelTooltip | None = None,
        /,
        **kwargs: Unpack[EncodeKwds],
    ) -> alt.Chart:
        """
        Draw line plot.

        Polars does not implement plotting logic itself but instead defers to
        `Altair <https://altair-viz.github.io/>`_.

        `df.plot.line(**kwargs)` is shorthand for
        `alt.Chart(df).mark_line().encode(**kwargs).interactive()`,
        and is provided for convenience - for full customisability, use a plotting
        library directly.

        .. versionchanged:: 1.6.0
            In prior versions of Polars, HvPlot was the plotting backend. If you would
            like to restore the previous plotting functionality, all you need to do
            is add `import hvplot.polars` at the top of your script and replace
            `df.plot` with `df.hvplot`.

        Parameters
        ----------
        x
            Column with x-coordinates of lines.
        y
            Column with y-coordinates of lines.
        color
            Column to color lines by.
        order
            Column to use for order of data points in lines.
        tooltip
            Columns to show values of when hovering over lines with pointer.
        **kwargs
            Additional keyword arguments passed to Altair.

        Examples
        --------
        >>> from datetime import date
        >>> df = pl.DataFrame(
        ...     {
        ...         "date": [date(2020, 1, 2), date(2020, 1, 3), date(2020, 1, 4)] * 2,
        ...         "price": [1, 4, 6, 1, 5, 2],
        ...         "stock": ["a", "a", "a", "b", "b", "b"],
        ...     }
        ... )
        >>> df.plot.line(x="date", y="price", color="stock")  # doctest: +SKIP
        """
        # Only forward channels that were actually supplied positionally.
        encodings: Encodings = {}
        if x is not None:
            encodings["x"] = x
        if y is not None:
            encodings["y"] = y
        if color is not None:
            encodings["color"] = color
        if order is not None:
            encodings["order"] = order
        if tooltip is not None:
            encodings["tooltip"] = tooltip
        return self._chart.mark_line().encode(**encodings, **kwargs).interactive()

    def point(
        self,
        x: ChannelX | None = None,
        y: ChannelY | None = None,
        color: ChannelColor | None = None,
        size: ChannelSize | None = None,
        tooltip: ChannelTooltip | None = None,
        /,
        **kwargs: Unpack[EncodeKwds],
    ) -> alt.Chart:
        """
        Draw scatter plot.

        Polars does not implement plotting logic itself but instead defers to
        `Altair <https://altair-viz.github.io/>`_.

        `df.plot.point(**kwargs)` is shorthand for
        `alt.Chart(df).mark_point().encode(**kwargs).interactive()`,
        and is provided for convenience - for full customisability, use a plotting
        library directly.

        .. versionchanged:: 1.6.0
            In prior versions of Polars, HvPlot was the plotting backend. If you would
            like to restore the previous plotting functionality, all you need to do
            is add `import hvplot.polars` at the top of your script and replace
            `df.plot` with `df.hvplot`.

        Parameters
        ----------
        x
            Column with x-coordinates of points.
        y
            Column with y-coordinates of points.
        color
            Column to color points by.
        size
            Column which determines points' sizes.
        tooltip
            Columns to show values of when hovering over points with pointer.
        **kwargs
            Additional keyword arguments passed to Altair.

        Examples
        --------
        >>> df = pl.DataFrame(
        ...     {
        ...         "length": [1, 4, 6],
        ...         "width": [4, 5, 6],
        ...         "species": ["setosa", "setosa", "versicolor"],
        ...     }
        ... )
        >>> df.plot.point(x="length", y="width", color="species")  # doctest: +SKIP
        """
        # Only forward channels that were actually supplied positionally.
        encodings: Encodings = {}
        if x is not None:
            encodings["x"] = x
        if y is not None:
            encodings["y"] = y
        if color is not None:
            encodings["color"] = color
        if size is not None:
            encodings["size"] = size
        if tooltip is not None:
            encodings["tooltip"] = tooltip
        return (
            self._chart.mark_point()
            .encode(
                **encodings,
                **kwargs,
            )
            .interactive()
        )

    # Alias to `point` because of how common it is.
    scatter = point

    def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
        """
        Resolve `df.plot.<attr>` to Altair's `mark_<attr>`.

        Only called when normal attribute lookup fails. The returned callable
        accepts keyword arguments only and is shorthand for
        `alt.Chart(df).mark_<attr>().encode(**kwargs).interactive()`.

        Raises
        ------
        AttributeError
            If `attr` is private/dunder, or the chart has no `mark_<attr>` method.
        """
        # Private/dunder names are never Altair marks. Bail out before touching
        # `self._chart`: if `_chart` itself is not set yet (e.g. while the object
        # is being copied), reading it would re-enter `__getattr__` and recurse
        # until RecursionError.
        if attr.startswith("_"):
            msg = f"{type(self).__name__!r} object has no attribute {attr!r}"
            raise AttributeError(msg)
        method = getattr(self._chart, f"mark_{attr}", None)
        if method is None:
            msg = f"Altair has no method 'mark_{attr}'"
            raise AttributeError(msg)
        return lambda **kwargs: method().encode(**kwargs).interactive()
class SeriesPlot:
    """
    Series.plot namespace.

    The Series is converted to a single-column DataFrame (unnamed Series get the
    column name `"value"`). `hist`, `kde` and `line` are explicit helpers; any
    other attribute `attr` is resolved to Altair's `mark_<attr>` method and
    plotted against a generated row index.
    """

    _accessor = "plot"

    def __init__(self, s: Series) -> None:
        # Fall back to "value" for unnamed Series, so the column used in the
        # encodings always has a non-empty name.
        name = s.name or "value"
        self._df = s.to_frame(name)
        self._series_name = name

    def hist(
        self,
        /,
        **kwargs: Unpack[EncodeKwds],
    ) -> alt.Chart:
        """
        Draw histogram.

        Polars does not implement plotting logic itself but instead defers to
        `Altair <https://altair-viz.github.io/>`_.

        `s.plot.hist(**kwargs)` is shorthand for
        `alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`,
        and is provided for convenience - for full customisability, use a plotting
        library directly.

        .. versionchanged:: 1.6.0
            In prior versions of Polars, HvPlot was the plotting backend. If you would
            like to restore the previous plotting functionality, all you need to do
            is add `import hvplot.polars` at the top of your script and replace
            `df.plot` with `df.hvplot`.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments passed to Altair.

        Raises
        ------
        ValueError
            If the Series is named `'count()'`, which is the name used for the
            y encoding.

        Examples
        --------
        >>> s = pl.Series("price", [1, 3, 3, 3, 5, 2, 6, 5, 5, 5, 7])
        >>> s.plot.hist()  # doctest: +SKIP
        """  # noqa: W505
        if self._series_name == "count()":
            msg = "Cannot use `plot.hist` when Series name is `'count()'`"
            raise ValueError(msg)
        return (
            alt.Chart(self._df)
            .mark_bar()
            .encode(x=alt.X(f"{self._series_name}:Q", bin=True), y="count()", **kwargs)  # type: ignore[misc]
            .interactive()
        )

    def kde(
        self,
        /,
        **kwargs: Unpack[EncodeKwds],
    ) -> alt.Chart:
        """
        Draw kernel density estimate plot.

        Polars does not implement plotting logic itself but instead defers to
        `Altair <https://altair-viz.github.io/>`_.

        `s.plot.kde(**kwargs)` is shorthand for
        `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()`,
        and is provided for convenience - for full customisability, use a plotting
        library directly.

        .. versionchanged:: 1.6.0
            In prior versions of Polars, HvPlot was the plotting backend. If you would
            like to restore the previous plotting functionality, all you need to do
            is add `import hvplot.polars` at the top of your script and replace
            `df.plot` with `df.hvplot`.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments passed to Altair.

        Raises
        ------
        ValueError
            If the Series is named `'density'`, which is the name given to the
            column produced by the density transform.

        Examples
        --------
        >>> s = pl.Series("price", [1, 3, 3, 3, 5, 2, 6, 5, 5, 5, 7])
        >>> s.plot.kde()  # doctest: +SKIP
        """  # noqa: W505
        if self._series_name == "density":
            msg = "Cannot use `plot.kde` when Series name is `'density'`"
            raise ValueError(msg)
        return (
            alt.Chart(self._df)
            .transform_density(self._series_name, as_=[self._series_name, "density"])
            .mark_area()
            .encode(x=self._series_name, y="density:Q", **kwargs)  # type: ignore[misc]
            .interactive()
        )

    def line(
        self,
        /,
        **kwargs: Unpack[EncodeKwds],
    ) -> alt.Chart:
        """
        Draw line plot.

        Polars does not implement plotting logic itself but instead defers to
        `Altair <https://altair-viz.github.io/>`_.

        `s.plot.line(**kwargs)` is shorthand for
        `alt.Chart(s.to_frame().with_row_index()).mark_line().encode(x='index', y=s.name, **kwargs).interactive()`,
        and is provided for convenience - for full customisability, use a plotting
        library directly.

        .. versionchanged:: 1.6.0
            In prior versions of Polars, HvPlot was the plotting backend. If you would
            like to restore the previous plotting functionality, all you need to do
            is add `import hvplot.polars` at the top of your script and replace
            `df.plot` with `df.hvplot`.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments passed to Altair.

        Raises
        ------
        ValueError
            If the Series is named `'index'`, which is the name used for the
            x encoding (the column added by `with_row_index`).

        Examples
        --------
        >>> s = pl.Series("price", [1, 3, 3, 3, 5, 2, 6, 5, 5, 5, 7])
        >>> s.plot.line()  # doctest: +SKIP
        """  # noqa: W505
        if self._series_name == "index":
            msg = "Cannot call `plot.line` when Series name is 'index'"
            raise ValueError(msg)
        return (
            alt.Chart(self._df.with_row_index())
            .mark_line()
            .encode(x="index", y=self._series_name, **kwargs)  # type: ignore[misc]
            .interactive()
        )

    def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
        """
        Resolve `s.plot.<attr>` to Altair's `mark_<attr>`, plotted against row index.

        Only called when normal attribute lookup fails. `scatter` is treated as an
        alias for `point`. The returned callable accepts keyword arguments only.

        Raises
        ------
        AttributeError
            If `attr` is private/dunder, or the chart has no `mark_<attr>` method.
        ValueError
            If the Series is named `'index'`, which is the name used for the
            x encoding.
        """
        # Private/dunder names are never Altair marks. Bail out before reading any
        # instance attribute: if `_series_name`/`_df` are not set yet (e.g. while
        # the object is being copied), reading them would re-enter `__getattr__`
        # and recurse until RecursionError.
        if attr.startswith("_"):
            msg = f"{type(self).__name__!r} object has no attribute {attr!r}"
            raise AttributeError(msg)
        if self._series_name == "index":
            msg = f"Cannot call `plot.{attr}` when Series name is 'index'"
            raise ValueError(msg)
        if attr == "scatter":
            # alias `scatter` to `point` because of how common it is
            attr = "point"
        method = getattr(alt.Chart(self._df.with_row_index()), f"mark_{attr}", None)
        if method is None:
            msg = f"Altair has no method 'mark_{attr}'"
            raise AttributeError(msg)
        return (
            lambda **kwargs: method()
            .encode(x="index", y=self._series_name, **kwargs)
            .interactive()
        )
@property
@unstable()
def plot(self) -> SeriesPlot:
    """
    Create a plot namespace.

    .. warning::
        This functionality is currently considered **unstable**. It may be
        changed at any point without it being considered a breaking change.

    .. versionchanged:: 1.6.0
        In prior versions of Polars, HvPlot was the plotting backend. If you would
        like to restore the previous plotting functionality, all you need to do
        is add `import hvplot.polars` at the top of your script and replace
        `s.plot` with `s.hvplot`.

    Polars does not implement plotting logic itself, but instead defers to
    Altair:

    - `s.plot.hist(**kwargs)`
      is shorthand for
      `alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`
    - `s.plot.kde(**kwargs)`
      is shorthand for
      `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()`
    - for any other attribute `attr`, `s.plot.attr(**kwargs)`
      is shorthand for
      `alt.Chart(s.to_frame().with_row_index()).mark_attr().encode(x='index', y=s.name, **kwargs).interactive()`

    Raises
    ------
    ModuleUpgradeRequiredError
        If Altair is not available, or the installed version is older than 5.4.0.

    Examples
    --------
    Histogram:

    >>> s = pl.Series([1, 4, 4, 6, 2, 4, 3, 5, 5, 7, 1])
    >>> s.plot.hist()  # doctest: +SKIP

    KDE plot:

    >>> s.plot.kde()  # doctest: +SKIP

    Line plot:

    >>> s.plot.line()  # doctest: +SKIP
    """  # noqa: W505
    requirement_msg = "altair>=5.4.0 is required for `.plot`"
    # Check availability first: the version attribute is only read when Altair
    # is importable.
    if not _ALTAIR_AVAILABLE:
        raise ModuleUpgradeRequiredError(requirement_msg)
    installed_version = parse_version(altair.__version__)
    if installed_version < (5, 4, 0):
        raise ModuleUpgradeRequiredError(requirement_msg)
    return SeriesPlot(self)
def test_series_plot() -> None:
    """Dry-run each Series chart kind and serialise it to JSON; nothing may raise."""
    series = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6])
    # Same order as before: kde, hist, line, then point (resolved dynamically).
    for chart_kind in ("kde", "hist", "line", "point"):
        getattr(series.plot, chart_kind)().to_json()


def test_empty_dataframe() -> None:
    """Building a point chart from a frame with two empty columns must not raise."""
    empty_frame = pl.DataFrame({"a": [], "b": []})
    empty_frame.plot.point(x="a", y="b")


def test_nameless_series() -> None:
    """A Series created without a name can still produce and serialise a KDE chart."""
    unnamed = pl.Series([1, 2, 3])
    unnamed.plot.kde().to_json()