
Commit

fix type checking
jonmmease committed Oct 14, 2024
1 parent f9f7961 commit f27c7f7
Showing 1 changed file with 40 additions and 23 deletions.
63 changes: 40 additions & 23 deletions vegafusion-python/vegafusion/runtime.py
@@ -8,7 +8,6 @@
import narwhals as nw
import psutil
from arro3.core import Table
from narwhals.typing import IntoFrame

from vegafusion.transformer import DataFrameLike
from vegafusion.utils import get_column_usage
@@ -18,10 +17,13 @@
from .local_tz import get_local_tz

if TYPE_CHECKING:
import duckdb # noqa: F401
import pandas as pd
import polars as pl # noqa: F401
import pyarrow as pa
from duckdb import DuckDBPyConnection
from grpc import Channel
from narwhals.typing import IntoFrameT

from vegafusion._vegafusion import PyChartState, PyVegaFusionRuntime

@@ -37,13 +39,14 @@ def _get_common_namespace(inline_datasets: dict[str, Any] | None) -> str | None:
namespaces.add(nw.get_native_namespace(nw.from_native(df)))

if len(namespaces) == 1:
return next(iter(namespaces)).__name__
return str(next(iter(namespaces)).__name__)
else:
return None
except TypeError:
# Types not compatible with Narwhals
return None


def _get_default_namespace() -> ModuleType:
# Returns a default narwhals namespace, based on what is installed
if pd := sys.modules.get("pandas") and sys.modules.get("pyarrow"):
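The str(...) wrapper in this hunk is what satisfies the declared -> str | None return type: mypy sees the namespace object coming out of nw.get_native_namespace as Any, so its __name__ is Any as well. A minimal sketch of the same namespace-detection idea, outside the diff (the sample frames are hypothetical):

import narwhals as nw
import pandas as pd

# Hypothetical inline datasets, both backed by the same library.
inline_datasets = {
    "cars": pd.DataFrame({"mpg": [21.0, 22.8]}),
    "stocks": pd.DataFrame({"price": [101.2, 99.7]}),
}

# Collect the native namespace (the backing module) of each dataset.
namespaces = {
    nw.get_native_namespace(nw.from_native(df)) for df in inline_datasets.values()
}

if len(namespaces) == 1:
    # Prints "pandas"; str() narrows the Any-typed __name__ for mypy.
    print(str(next(iter(namespaces)).__name__))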
@@ -222,7 +225,9 @@ def set_connection(
"""
# Don't import duckdb unless it's already loaded. If it's not loaded,
# then the input connection can't be a duckdb connection.
duckdb = sys.modules.get("duckdb", None)
if not TYPE_CHECKING:
duckdb = sys.modules.get("duckdb", None)

if isinstance(connection, str):
if connection == "datafusion":
# Connection of None uses DataFusion
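This if not TYPE_CHECKING: guard is the heart of the fix: the TYPE_CHECKING block at the top of the module imports duckdb for the type checker only, while at runtime the name is rebound lazily from sys.modules so the optional dependency is never imported here. Because mypy treats the guarded branch as unreachable, it keeps reasoning about duckdb as the imported module rather than as a module-or-None value. A standalone sketch of the pattern (the helper function is hypothetical; DuckDBPyConnection is the real duckdb connection class):

import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by the type checker; never imported at runtime.
    import duckdb


def is_duckdb_connection(connection: object) -> bool:
    # At runtime this branch always executes, so duckdb is either the
    # already-imported module or None; mypy skips it as unreachable.
    if not TYPE_CHECKING:
        duckdb = sys.modules.get("duckdb", None)
    return duckdb is not None and isinstance(connection, duckdb.DuckDBPyConnection)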
@@ -307,7 +312,7 @@ def process_request_bytes(self, request: bytes) -> bytes:

def _import_or_register_inline_datasets(
self,
inline_datasets: dict[str, IntoFrame | pd.DataFrame | SqlDataset] | None = None,
inline_datasets: dict[str, IntoFrameT | SqlDataset] | None = None,
inline_dataset_usage: dict[str, list[str]] | None = None,
) -> dict[str, Table | SqlDataset]:
"""
@@ -319,7 +324,9 @@ def _import_or_register_inline_datasets(
specification using the following url syntax
'vegafusion+dataset://{dataset_name}' or 'table://{dataset_name}'.
"""
pd: pd = sys.modules.get("pandas", None)
if not TYPE_CHECKING:
pd = sys.modules.get("pandas", None)
pa = sys.modules.get("pyarrow", None)

inline_datasets = inline_datasets or {}
inline_dataset_usage = inline_dataset_usage or {}
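The URL syntax in the docstring is how a chart spec refers to these datasets: the data entry's url names an inline dataset instead of a file, and the runtime resolves it against the inline_datasets mapping. A hedged usage sketch from the public API side (the dataset name and spec fragment are made up, and the exact pre_transform_datasets call signature is assumed rather than shown in this hunk):

import pandas as pd
import vegafusion as vf

source = pd.DataFrame({"a": ["x", "y", "x"], "b": [1, 2, 3]})

spec = {
    "$schema": "https://vega.github.io/schema/vega/v5.json",
    "data": [
        # "table://source" points at the inline dataset named "source".
        {"name": "source", "url": "table://source"},
    ],
}

# The runtime looks up "table://source" in inline_datasets instead of fetching a file.
datasets, warnings = vf.runtime.pre_transform_datasets(
    spec, ["source"], inline_datasets={"source": source}
)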
@@ -328,23 +335,26 @@ def _import_or_register_inline_datasets(
columns = inline_dataset_usage.get(name)
if isinstance(value, SqlDataset):
imported_inline_datasets[name] = value
elif pd is not None and isinstance(value, pd.DataFrame):
import pyarrow as pa
# elif pd is not None and isinstance(value, pd.DataFrame):
elif isinstance(value, pd.DataFrame):
# rename to help mypy
inner_value: pd.DataFrame = value
del value

# Project down columns if possible
if columns is not None:
value = value[columns]
inner_value = inner_value[columns]

# Convert problematic object columns to strings
for col, dtype in value.dtypes.items():
for col, dtype in inner_value.dtypes.items():
if dtype.kind == "O":
try:
# See if the Table constructor can handle column by itself
col_tbl = Table(value[[col]])
col_tbl = Table(inner_value[[col]])

# If so, keep the arrow version so that it's more efficient
# to convert as part of the whole table later
value = value.assign(
inner_value = inner_value.assign(
**{
col: pd.arrays.ArrowExtensionArray(
pa.chunked_array(col_tbl.column(0))
@@ -354,18 +364,20 @@ def _import_or_register_inline_datasets(
except TypeError:
# If the Table constructor can't handle the object column,
# convert the column to pyarrow strings
value = value.assign(
**{col: value[col].astype("string[pyarrow]")}
inner_value = inner_value.assign(
**{col: inner_value[col].astype("string[pyarrow]")}
)

if self._connection is not None:
try:
# Try registering DataFrame if supported
self._connection.register_pandas(name, value, temporary=True)
self._connection.register_pandas(
name, inner_value, temporary=True
)
continue
except ValueError:
pass
imported_inline_datasets[name] = Table(value)
imported_inline_datasets[name] = Table(inner_value)
elif isinstance(value, dict):
# Let narwhals import the dict using a default backend
df_nw = nw.from_dict(value, native_namespace=_get_default_namespace())
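The object-dtype handling above exists because Arrow cannot infer a type for a pandas object column holding mixed Python values. The strategy: first see whether the Arrow table constructor can cope with the column on its own, and only if that fails cast it to pyarrow-backed strings. A compact sketch of the same fallback using pyarrow directly (arro3's Table plays the constructor role in the real code):

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"mixed": [1, "text", None]}, dtype="object")

for col, dtype in df.dtypes.items():
    if dtype.kind == "O":
        try:
            # See whether Arrow can infer a type for the column as-is.
            pa.Table.from_pandas(df[[col]])
        except (pa.ArrowInvalid, pa.ArrowTypeError, TypeError):
            # Fall back to pyarrow-backed strings so the later whole-table
            # conversion cannot trip over this column.
            df = df.assign(**{col: df[col].astype("string[pyarrow]")})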
@@ -374,16 +386,17 @@ def _import_or_register_inline_datasets(
# Import through PyCapsule interface on narwhals
try:
df_nw = nw.from_native(value)

# Project down columns if possible
if columns is not None:
# TODO: Nice error message when column is not found
df_nw = df_nw[columns]
df_nw = df_nw[columns] # type: ignore[index]

imported_inline_datasets[name] = Table(df_nw)
imported_inline_datasets[name] = Table(df_nw) # type: ignore[arg-type]
except TypeError:
# Not supported by Narwhals, try pycapsule interface directly
if hasattr(value, "__arrow_c_stream__"):
imported_inline_datasets[name] = Table(value)
imported_inline_datasets[name] = Table(value) # type: ignore[arg-type]
else:
raise

@@ -660,9 +673,10 @@ def pre_transform_datasets(
if self._grpc_channel:
raise ValueError("pre_transform_datasets not yet supported over gRPC")
else:
pl = sys.modules.get("polars", None)
pa = sys.modules.get("pyarrow", None)
pd = sys.modules.get("pandas", None)
if not TYPE_CHECKING:
pl = sys.modules.get("polars", None)
pa = sys.modules.get("pyarrow", None)
pd = sys.modules.get("pandas", None)

local_tz = local_tz or get_local_tz()

@@ -697,14 +711,16 @@ def pre_transform_datasets(
nw_dataframes = [
nw.from_native(pl.DataFrame(value)) for value in values
]

elif namespace == "pyarrow" and pa is not None:
nw_dataframes = [nw.from_native(pa.table(value)) for value in values]
elif namespace == "pandas" and pd is not None and pa is not None:
nw_dataframes = [
nw.from_native(pa.table(value).to_pandas()) for value in values
]
else:
# Either no inline datasets, inline datasets with mixed or unrecognized types
# Either no inline datasets, inline datasets with mixed or
# unrecognized types
if pa is not None and pd is not None:
nw_dataframes = [
nw.from_native(pa.table(value).to_pandas()) for value in values
@@ -717,7 +733,8 @@ def pre_transform_datasets(
# Hopefully narwhals will eventually help us fall back to whatever
# is installed here
raise ValueError(
"Either polars or pandas must be installed to extract transformed data"
"Either polars or pandas must be installed to extract "
"transformed data"
)

# Localize datetime columns to UTC, then extract the native DataFrame
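The comment above describes the final step: before handing back native DataFrames, datetime columns are normalized to UTC. A hedged sketch of what that step can look like with narwhals (the helper name and the default_input_tz parameter are illustrative, and dt.replace_time_zone / dt.convert_time_zone are assumed to be available in narwhals as they are in polars):

import narwhals as nw


def localize_to_utc(nw_df: nw.DataFrame, default_input_tz: str) -> nw.DataFrame:
    # Naive datetimes are first stamped with default_input_tz, then everything
    # is converted to UTC; non-datetime columns are left untouched.
    exprs = []
    for name, dtype in nw_df.schema.items():
        if isinstance(dtype, nw.Datetime):
            col = nw.col(name)
            if dtype.time_zone is None:
                col = col.dt.replace_time_zone(default_input_tz)
            exprs.append(col.dt.convert_time_zone("UTC").alias(name))
    return nw_df.with_columns(*exprs) if exprs else nw_df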