feat: Add pyarrow-stubs #19259

Closed · wants to merge 2 commits
477 changes: 477 additions & 0 deletions py-polars/mypyout.txt

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion py-polars/polars/_utils/construction/dataframe.py
@@ -10,6 +10,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    cast,
 )
 
 import polars._reexport as pl
@@ -1115,7 +1116,7 @@ def pandas_to_pydf(
             data[col], nan_to_null=nan_to_null, length=length
         )
 
-    arrow_table = pa.table(arrow_dict)
+    arrow_table = pa.table(cast(dict[str, list[Any] | pa.Array[Any]], arrow_dict))
     return arrow_to_pydf(
         arrow_table,
         schema=schema,
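
Note: the `cast` above exists purely to satisfy the type checker and is a no-op at runtime. With pyarrow-stubs, `pa.table` is typed against a narrower mapping type than the `dict[str, Any]` built up earlier in the function. A minimal standalone sketch of the pattern, with the accepted value type assumed from this diff rather than from pyarrow's documentation:

```python
from typing import Any, cast

import pyarrow as pa

# a heterogeneous dict whose static type is too loose for the stubbed pa.table()
arrow_dict: dict[str, Any] = {"a": [1, 2, 3], "b": pa.array(["x", "y", "z"])}

# cast() rewrites only the static type; no conversion or copy happens
table = pa.table(cast("dict[str, list[Any] | pa.Array[Any]]", arrow_dict))
print(table.schema)
```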
32 changes: 15 additions & 17 deletions py-polars/polars/_utils/construction/other.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 from polars._utils.construction.utils import get_first_non_none
 from polars.dependencies import pyarrow as pa
@@ -14,7 +14,7 @@ def pandas_series_to_arrow(
     *,
     length: int | None = None,
     nan_to_null: bool = True,
-) -> pa.Array:
+) -> pa.Array[Any]:
     """
     Convert a pandas Series to an Arrow Array.
@@ -52,21 +52,19 @@
     )
 
 
-def coerce_arrow(array: pa.Array) -> pa.Array:
+def coerce_arrow(array: pa.Array[Any] | pa.ChunkedArray[Any]) -> pa.Array[Any]:
     """..."""
     import pyarrow.compute as pc
 
-    if hasattr(array, "num_chunks") and array.num_chunks > 1:
-        # small integer keys can often not be combined, so let's already cast
-        # to the uint32 used by polars
-        if pa.types.is_dictionary(array.type) and (
-            pa.types.is_int8(array.type.index_type)
-            or pa.types.is_uint8(array.type.index_type)
-            or pa.types.is_int16(array.type.index_type)
-            or pa.types.is_uint16(array.type.index_type)
-            or pa.types.is_int32(array.type.index_type)
+    if isinstance(array, pa.ChunkedArray):
+        # TODO: [pyarrow] remove explicit cast when combine_chunks is fixed
+        array = cast(pa.Array[Any], array.combine_chunks())
+    if pa.types.is_dictionary(array.type):
+        array_type = cast(pa.DictionaryType[Any, Any], array.type)
+        if (
+            pa.types.is_int8(array_type.index_type)
+            or pa.types.is_uint8(array_type.index_type)
+            or pa.types.is_int16(array_type.index_type)
+            or pa.types.is_uint16(array_type.index_type)
+            or pa.types.is_int32(array_type.index_type)
         ):
-            array = pc.cast(
-                array, pa.dictionary(pa.uint32(), pa.large_string())
-            ).combine_chunks()
+            array = array.cast(pa.dictionary(pa.uint32(), pa.large_string()))
     return array
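
Note: the rewrite swaps duck typing (`hasattr(array, "num_chunks")`) for `isinstance(array, pa.ChunkedArray)`, which type checkers can narrow, and widens the dictionary keys via `Array.cast` instead of `pyarrow.compute.cast`. A standalone sketch of the narrowing pattern on synthetic data; the explicit `cast` mirrors the diff's workaround for the current `combine_chunks` stub annotation:

```python
from __future__ import annotations

from typing import Any, cast

import pyarrow as pa

# two-chunk dictionary-encoded strings; pyarrow defaults to int32 indices
array: pa.Array[Any] | pa.ChunkedArray[Any] = pa.chunked_array(
    [["a", "b"], ["a", "c"]]
).dictionary_encode()

if isinstance(array, pa.ChunkedArray):
    # flatten to a single contiguous Array before inspecting the type
    array = cast("pa.Array[Any]", array.combine_chunks())
if pa.types.is_dictionary(array.type):
    # widen small integer keys to the uint32 representation polars uses
    array = array.cast(pa.dictionary(pa.uint32(), pa.large_string()))
print(array.type)
```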
5 changes: 3 additions & 2 deletions py-polars/polars/_utils/construction/series.py
@@ -8,6 +8,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    cast,
 )
 
 import polars._reexport as pl
@@ -391,7 +392,7 @@ def pandas_to_pyseries(
 
 def arrow_to_pyseries(
     name: str,
-    values: pa.Array,
+    values: pa.Array[Any],
     dtype: PolarsDataType | None = None,
     *,
     strict: bool = True,
@@ -404,7 +405,7 @@ def arrow_to_pyseries(
     if (
         len(array) == 0
         and isinstance(array.type, pa.DictionaryType)
-        and array.type.value_type
+        and cast(pa.DictionaryType, array.type).value_type
         in (
             pa.utf8(),
             pa.large_utf8(),
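
Note: the `cast` is added even though `isinstance(array.type, pa.DictionaryType)` appears earlier in the same condition; the narrowing apparently does not carry to the later `array.type` property access under these stubs, so the diff spells the type out before touching `value_type`. A standalone sketch of the same access pattern:

```python
from typing import cast

import pyarrow as pa

array = pa.array(["a", "b", "a"]).dictionary_encode()

if isinstance(array.type, pa.DictionaryType):
    dict_type = cast(pa.DictionaryType, array.type)
    # index_type / value_type exist on DictionaryType, not on plain DataType
    print(dict_type.index_type, dict_type.value_type)
```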
7 changes: 6 additions & 1 deletion py-polars/polars/dependencies.py
@@ -163,14 +163,15 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
     import pyarrow
     import pydantic
     import pyiceberg
+    from pyarrow import compute as pyarrow_compute
+    from pyarrow import dataset as pyarrow_dataset
 else:
     # infrequently-used builtins
     dataclasses, _ = _lazy_import("dataclasses")
     html, _ = _lazy_import("html")
     json, _ = _lazy_import("json")
     pickle, _ = _lazy_import("pickle")
     subprocess, _ = _lazy_import("subprocess")
 
     # heavy/optional third party libs
     altair, _ALTAIR_AVAILABLE = _lazy_import("altair")
     deltalake, _DELTALAKE_AVAILABLE = _lazy_import("deltalake")
@@ -180,6 +181,8 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
     numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
     pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
     pyarrow, _PYARROW_AVAILABLE = _lazy_import("pyarrow")
+    pyarrow_compute, _ = _lazy_import("pyarrow.compute")
+    pyarrow_dataset, _ = _lazy_import("pyarrow.dataset")
     pydantic, _PYDANTIC_AVAILABLE = _lazy_import("pydantic")
     pyiceberg, _PYICEBERG_AVAILABLE = _lazy_import("pyiceberg")
     zoneinfo, _ZONEINFO_AVAILABLE = (
@@ -308,6 +311,8 @@ def import_optional(
     "pydantic",
     "pyiceberg",
     "pyarrow",
+    "pyarrow_dataset",
+    "pyarrow_compute",
     "zoneinfo",
     # lazy utilities
     "_check_for_numpy",
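
Note: the new entries register `pyarrow.compute` and `pyarrow.dataset` as lazily loaded submodules, so type annotations can reference them without any import-time cost. A simplified sketch of the lazy-import idea; this is not polars' actual `_lazy_import` implementation, just the general mechanism it relies on:

```python
from __future__ import annotations

import importlib
from types import ModuleType
from typing import Any


class _LazyModule(ModuleType):
    """Proxy that defers the real import until first attribute access."""

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self._loaded: ModuleType | None = None

    def __getattr__(self, attr: str) -> Any:
        # called only when normal attribute lookup fails
        if self._loaded is None:
            self._loaded = importlib.import_module(self.__name__)
        return getattr(self._loaded, attr)


pyarrow_dataset = _LazyModule("pyarrow.dataset")
# nothing has been imported yet; the line below triggers the real import
print(pyarrow_dataset.Dataset)
```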
37 changes: 26 additions & 11 deletions py-polars/polars/io/pyarrow_dataset/anonymous_scan.py
@@ -1,17 +1,27 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypedDict
 
 import polars._reexport as pl
 from polars.dependencies import pyarrow as pa
 
 if TYPE_CHECKING:
+    import sys
+
+    from polars.dependencies import pyarrow_dataset as ds
+
+    if sys.version_info >= (3, 11):
+        from typing import NotRequired
+    else:
+        from typing_extensions import NotRequired
+    from pyarrow.compute import Expression
+
     from polars import DataFrame, LazyFrame
 
 
 def _scan_pyarrow_dataset(
-    ds: pa.dataset.Dataset,
+    dataset: ds.Dataset,
     *,
     allow_pyarrow_filter: bool = True,
     batch_size: int | None = None,
@@ -24,7 +34,7 @@ def _scan_pyarrow_dataset(
 
     Parameters
     ----------
-    ds
+    dataset
         pyarrow dataset
     allow_pyarrow_filter
         Allow predicates to be pushed down to pyarrow. This can lead to different
@@ -33,14 +43,14 @@
     batch_size
         The maximum row count for scanned pyarrow record batches.
     """
-    func = partial(_scan_pyarrow_dataset_impl, ds, batch_size=batch_size)
+    func = partial(_scan_pyarrow_dataset_impl, dataset, batch_size=batch_size)
     return pl.LazyFrame._scan_python_function(
-        ds.schema, func, pyarrow=allow_pyarrow_filter
+        dataset.schema, func, pyarrow=allow_pyarrow_filter
     )
 
 
 def _scan_pyarrow_dataset_impl(
-    ds: pa.dataset.Dataset,
+    dataset: ds.Dataset,
     with_columns: list[str] | None,
     predicate: str | None,
     n_rows: int | None,
@@ -51,7 +61,7 @@ def _scan_pyarrow_dataset_impl(
 
     Parameters
     ----------
-    ds
+    dataset
         pyarrow dataset
     with_columns
         Columns that are projected
@@ -93,11 +103,16 @@
         },
     )
 
-    common_params = {"columns": with_columns, "filter": _filter}
+    class Common_params(TypedDict):
+        columns: list[str] | None
+        filter: Expression | None
+        batch_size: NotRequired[int]
+
+    common_params: Common_params = {"columns": with_columns, "filter": _filter}
     if batch_size is not None:
         common_params["batch_size"] = batch_size
 
     if n_rows:
-        return from_arrow(ds.head(n_rows, **common_params))  # type: ignore[return-value]
-
-    return from_arrow(ds.to_table(**common_params))  # type: ignore[return-value]
+        return from_arrow(dataset.head(n_rows, **common_params))  # type: ignore[return-value]
+    # TODO: [pyarrow] remove ignore when from_arrow has annotations
+    return from_arrow(dataset.to_table(**common_params))  # type: ignore[return-value]
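
Note: the local `Common_params` TypedDict replaces a plain dict whose inferred value type (a union) would not match the parameter types of `head` / `to_table` when unpacked with `**`; `NotRequired` (PEP 655) marks `batch_size` as a key that may be absent. A standalone sketch of the pattern, where `to_table` is an invented stand-in rather than pyarrow's real signature:

```python
from __future__ import annotations

import sys
from typing import TypedDict

if sys.version_info >= (3, 11):
    from typing import NotRequired
else:
    from typing_extensions import NotRequired


class ScanParams(TypedDict):
    columns: list[str] | None
    batch_size: NotRequired[int]  # may be omitted entirely


def to_table(columns: list[str] | None = None, batch_size: int = 1024) -> str:
    # stand-in for a dataset read method; this signature is invented
    return f"columns={columns}, batch_size={batch_size}"


def scan(columns: list[str] | None, batch_size: int | None) -> str:
    params: ScanParams = {"columns": columns}
    if batch_size is not None:
        # only forward batch_size when the caller actually provided one
        params["batch_size"] = batch_size
    return to_table(**params)


print(scan(["a", "b"], None))
```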
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
@@ -74,3 +74,4 @@ flask-cors
 # Stub files
 pandas-stubs
 boto3-stubs
+pyarrow-stubs