
Commit

first pass
deanm0000 committed Oct 16, 2024
1 parent 900dc3b commit 60cdf87
Showing 6 changed files with 529 additions and 32 deletions.
477 changes: 477 additions & 0 deletions py-polars/mypyout.txt

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion py-polars/polars/_utils/construction/dataframe.py
@@ -10,6 +10,7 @@
    TYPE_CHECKING,
    Any,
    Callable,
+    cast,
)

import polars._reexport as pl
@@ -1115,7 +1116,7 @@ def pandas_to_pydf(
                data[col], nan_to_null=nan_to_null, length=length
            )

-    arrow_table = pa.table(arrow_dict)
+    arrow_table = pa.table(cast(dict[str, list[Any] | pa.Array[Any]], arrow_dict))
    return arrow_to_pydf(
        arrow_table,
        schema=schema,
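For context, typing.cast only narrows the static type for the checker and is a no-op at runtime. A minimal sketch of the pattern used above, assuming pyarrow is installed (column names and values here are made up, not from the commit):

from typing import Any, cast

import pyarrow as pa

# cast() narrows the dict's static type to what the pa.table stubs expect;
# it does not convert or copy anything at runtime.
arrow_dict = {"a": [1, 2, 3], "b": pa.array(["x", "y", "z"])}
arrow_table = pa.table(cast("dict[str, list[Any] | pa.Array[Any]]", arrow_dict))
print(arrow_table.schema)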
32 changes: 15 additions & 17 deletions py-polars/polars/_utils/construction/other.py
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast

from polars._utils.construction.utils import get_first_non_none
from polars.dependencies import pyarrow as pa
@@ -14,7 +14,7 @@ def pandas_series_to_arrow(
*,
length: int | None = None,
nan_to_null: bool = True,
-) -> pa.Array:
+) -> pa.Array[Any]:
"""
Convert a pandas Series to an Arrow Array.
@@ -52,21 +52,19 @@
)


-def coerce_arrow(array: pa.Array) -> pa.Array:
+def coerce_arrow(array: pa.Array[Any] | pa.ChunkedArray[Any]) -> pa.Array[Any]:
    """..."""
    import pyarrow.compute as pc

-    if hasattr(array, "num_chunks") and array.num_chunks > 1:
-        # small integer keys can often not be combined, so let's already cast
-        # to the uint32 used by polars
-        if pa.types.is_dictionary(array.type) and (
-            pa.types.is_int8(array.type.index_type)
-            or pa.types.is_uint8(array.type.index_type)
-            or pa.types.is_int16(array.type.index_type)
-            or pa.types.is_uint16(array.type.index_type)
-            or pa.types.is_int32(array.type.index_type)
+    if isinstance(array, pa.ChunkedArray):
+        # TODO: [pyarrow] remove explicit cast when combine_chunks is fixed
+        array = cast(pa.Array[Any], array.combine_chunks())
+    if pa.types.is_dictionary(array.type):
+        array_type = cast(pa.DictionaryType[Any, Any], array.type)
+        if (
+            pa.types.is_int8(array_type.index_type)
+            or pa.types.is_uint8(array_type.index_type)
+            or pa.types.is_int16(array_type.index_type)
+            or pa.types.is_uint16(array_type.index_type)
+            or pa.types.is_int32(array_type.index_type)
        ):
-            array = pc.cast(
-                array, pa.dictionary(pa.uint32(), pa.large_string())
-            ).combine_chunks()
+            array = array.cast(pa.dictionary(pa.uint32(), pa.large_string()))
    return array
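A small, hedged illustration of the behavior the rewritten coerce_arrow targets: combine a chunked dictionary array, then re-key small integer indices to uint32 / large_string. The sample data below is made up; coerce_arrow itself is internal to polars and not called here.

import pyarrow as pa

# A chunked, dictionary-encoded column with int8 indices, the case that gets
# normalized to uint32 keys and large_string values.
chunk = pa.DictionaryArray.from_arrays(
    pa.array([0, 1, 0], type=pa.int8()), pa.array(["a", "b"])
)
chunked = pa.chunked_array([chunk, chunk])

array = chunked.combine_chunks()  # ChunkedArray -> Array at runtime
if pa.types.is_dictionary(array.type) and pa.types.is_int8(array.type.index_type):
    array = array.cast(pa.dictionary(pa.uint32(), pa.large_string()))
print(array.type)  # dictionary<values=large_string, indices=uint32, ...>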
5 changes: 3 additions & 2 deletions py-polars/polars/_utils/construction/series.py
@@ -8,6 +8,7 @@
    TYPE_CHECKING,
    Any,
    Callable,
+    cast,
)

import polars._reexport as pl
@@ -391,7 +392,7 @@ def pandas_to_pyseries(

def arrow_to_pyseries(
name: str,
-    values: pa.Array,
+    values: pa.Array[Any],
dtype: PolarsDataType | None = None,
*,
strict: bool = True,
@@ -404,7 +405,7 @@ def arrow_to_pyseries(
    if (
        len(array) == 0
        and isinstance(array.type, pa.DictionaryType)
-        and array.type.value_type
+        and cast(pa.DictionaryType, array.type).value_type
        in (
            pa.utf8(),
            pa.large_utf8(),
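For reference, a tiny self-contained sketch of the DictionaryType value_type check that the cast above is typing; the array construction is illustrative only, not the commit's code.

import pyarrow as pa

# An empty dictionary-encoded string column, the case the special-cased
# branch in arrow_to_pyseries is looking for.
empty = pa.DictionaryArray.from_arrays(
    pa.array([], type=pa.uint32()), pa.array([], type=pa.large_string())
)
if isinstance(empty.type, pa.DictionaryType) and empty.type.value_type in (
    pa.utf8(),
    pa.large_utf8(),
):
    print("empty categorical column with string values")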
7 changes: 6 additions & 1 deletion py-polars/polars/dependencies.py
@@ -163,14 +163,15 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
import pyarrow
import pydantic
import pyiceberg
+    from pyarrow import compute as pyarrow_compute
+    from pyarrow import dataset as pyarrow_dataset
else:
# infrequently-used builtins
dataclasses, _ = _lazy_import("dataclasses")
html, _ = _lazy_import("html")
json, _ = _lazy_import("json")
pickle, _ = _lazy_import("pickle")
subprocess, _ = _lazy_import("subprocess")

# heavy/optional third party libs
altair, _ALTAIR_AVAILABLE = _lazy_import("altair")
deltalake, _DELTALAKE_AVAILABLE = _lazy_import("deltalake")
@@ -180,6 +181,8 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
pyarrow, _PYARROW_AVAILABLE = _lazy_import("pyarrow")
+    pyarrow_compute, _ = _lazy_import("pyarrow.compute")
+    pyarrow_dataset, _ = _lazy_import("pyarrow.dataset")
pydantic, _PYDANTIC_AVAILABLE = _lazy_import("pydantic")
pyiceberg, _PYICEBERG_AVAILABLE = _lazy_import("pyiceberg")
zoneinfo, _ZONEINFO_AVAILABLE = (
@@ -308,6 +311,8 @@ def import_optional(
"pydantic",
"pyiceberg",
"pyarrow",
"pyarrow_dataset",
"pyarrow_compute",
"zoneinfo",
# lazy utilities
"_check_for_numpy",
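Roughly, the point of registering pyarrow.compute and pyarrow.dataset here is that submodules of an optional dependency need their own import handles. A simplified, eager stand-in for the idea (polars' real _lazy_import defers the import until first attribute access, so this is only a sketch):

import importlib
from types import ModuleType


def _lazy_import_sketch(module_name: str) -> tuple[ModuleType, bool]:
    # Simplified stand-in: try the import and report availability; the real
    # helper instead returns a proxy module that imports on first use.
    try:
        return importlib.import_module(module_name), True
    except ImportError:
        return ModuleType(module_name), False


pyarrow_dataset, _ = _lazy_import_sketch("pyarrow.dataset")
pyarrow_compute, _ = _lazy_import_sketch("pyarrow.compute")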
37 changes: 26 additions & 11 deletions py-polars/polars/io/pyarrow_dataset/anonymous_scan.py
@@ -1,17 +1,27 @@
from __future__ import annotations

from functools import partial
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypedDict

import polars._reexport as pl
from polars.dependencies import pyarrow as pa

if TYPE_CHECKING:
+    import sys

+    from polars.dependencies import pyarrow_dataset as ds

+    if sys.version_info >= (3, 11):
+        from typing import NotRequired
+    else:
+        from typing_extensions import NotRequired
+    from pyarrow.compute import Expression

    from polars import DataFrame, LazyFrame


def _scan_pyarrow_dataset(
-    ds: pa.dataset.Dataset,
+    dataset: ds.Dataset,
*,
allow_pyarrow_filter: bool = True,
batch_size: int | None = None,
@@ -24,7 +34,7 @@ def _scan_pyarrow_dataset(
Parameters
----------
-    ds
+    dataset
        pyarrow dataset
allow_pyarrow_filter
Allow predicates to be pushed down to pyarrow. This can lead to different
@@ -33,14 +43,14 @@
batch_size
The maximum row count for scanned pyarrow record batches.
"""
-    func = partial(_scan_pyarrow_dataset_impl, ds, batch_size=batch_size)
+    func = partial(_scan_pyarrow_dataset_impl, dataset, batch_size=batch_size)
    return pl.LazyFrame._scan_python_function(
-        ds.schema, func, pyarrow=allow_pyarrow_filter
+        dataset.schema, func, pyarrow=allow_pyarrow_filter
)


def _scan_pyarrow_dataset_impl(
-    ds: pa.dataset.Dataset,
+    dataset: ds.Dataset,
with_columns: list[str] | None,
predicate: str | None,
n_rows: int | None,
@@ -51,7 +61,7 @@ def _scan_pyarrow_dataset_impl(
Parameters
----------
-    ds
+    dataset
        pyarrow dataset
with_columns
Columns that are projected
@@ -93,11 +103,16 @@
},
)

-    common_params = {"columns": with_columns, "filter": _filter}
+    class Common_params(TypedDict):
+        columns: list[str] | None
+        filter: Expression | None
+        batch_size: NotRequired[int]

+    common_params: Common_params = {"columns": with_columns, "filter": _filter}
    if batch_size is not None:
        common_params["batch_size"] = batch_size

    if n_rows:
-        return from_arrow(ds.head(n_rows, **common_params))  # type: ignore[return-value]

-    return from_arrow(ds.to_table(**common_params))  # type: ignore[return-value]
+        return from_arrow(dataset.head(n_rows, **common_params))  # type: ignore[return-value]
+    # TODO: [pyarrow] remove ignore when from_arrow has annotations
+    return from_arrow(dataset.to_table(**common_params))  # type: ignore[return-value]
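The TypedDict above is what lets mypy type the **common_params expansion: the always-present keys are declared normally, while batch_size is only sometimes set, so it is NotRequired rather than Optional. A self-contained sketch of that pattern (the class name and values below are illustrative, not the commit's code):

from __future__ import annotations

import sys
from typing import TypedDict

if sys.version_info >= (3, 11):
    from typing import NotRequired
else:
    from typing_extensions import NotRequired


class ScanParams(TypedDict):
    # Required keys are declared normally; batch_size may be omitted entirely,
    # so it is marked NotRequired instead of being given an Optional value.
    columns: list[str] | None
    batch_size: NotRequired[int]


params: ScanParams = {"columns": ["a", "b"]}
params["batch_size"] = 4096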
