Skip to content

Commit

Permalink
chore(datafusion): restore _register and deprecate usage in do_connect
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth committed Dec 4, 2024
1 parent 795214d commit 9ff388c
Showing 1 changed file with 66 additions and 23 deletions.
89 changes: 66 additions & 23 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import inspect
import typing
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any

import datafusion as df
import pyarrow as pa
import pyarrow_hotfix # noqa: F401
import sqlglot as sg
import sqlglot.expressions as sge
Expand All @@ -25,7 +27,7 @@
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
from ibis.util import gen_name, normalize_filename, normalize_filenames
from ibis.util import gen_name, normalize_filename, normalize_filenames, warn_deprecated

try:
from datafusion import ExecutionContext as SessionContext
Expand All @@ -43,11 +45,8 @@
RuntimeConfig = None

if TYPE_CHECKING:
from pathlib import Path

import pandas as pd
import polars as pl
import pyarrow as pa


def as_nullable(dtype: dt.DataType) -> dt.DataType:
Expand Down Expand Up @@ -88,37 +87,30 @@ def do_connect(
Parameters
----------
config
Mapping of table names to files or a `SessionContext`
Mapping of table names to files (deprecated in 10.0) or a `SessionContext`
instance.
Examples
--------
>>> from datafusion import SessionContext
>>> ctx = SessionContext()
>>> _ = ctx.from_pydict({"a": [1, 2, 3]}, "mytable")
>>> import ibis
>>> config = {
... "astronauts": "ci/ibis-testing-data/parquet/astronauts.parquet",
... "diamonds": "ci/ibis-testing-data/csv/diamonds.csv",
... }
>>> con = ibis.datafusion.connect(config)
>>> con = ibis.datafusion.connect(ctx)
>>> con.list_tables()
['astronauts', 'diamonds']
>>> con.table("diamonds")
DatabaseTable: diamonds
carat float64
cut string
color string
clarity string
depth float64
table float64
price int64
x float64
y float64
z float64
['mytable']
"""
if isinstance(config, SessionContext):
(self.con, config) = (config, None)
else:
if config is not None and not isinstance(config, Mapping):
raise TypeError("Input to ibis.datafusion.connect must be a mapping")
elif config is not None and config: # warn if dict is not empty
warn_deprecated(
"Passing a mapping of tables names to files",
as_of="10.0",
instead="Please use the explicit `read_*` methods for the files you would like to load instead.",
)
if SessionConfig is not None:
df_config = SessionConfig(
{"datafusion.sql_parser.dialect": "PostgreSQL"}
Expand Down Expand Up @@ -178,6 +170,57 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:

return PyArrowSchema.to_ibis(df.schema())

def _register(
self,
source: str | Path | pa.Table | pa.RecordBatch | pa.Dataset | pd.DataFrame,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
import pandas as pd
import pyarrow.dataset as ds

if isinstance(source, (str, Path)):
first = str(source)
elif isinstance(source, pa.Table):
self.con.deregister_table(table_name)
self.con.register_record_batches(table_name, [source.to_batches()])
return self.table(table_name)
elif isinstance(source, pa.RecordBatch):
self.con.deregister_table(table_name)
self.con.register_record_batches(table_name, [[source]])
return self.table(table_name)
elif isinstance(source, ds.Dataset):
self.con.deregister_table(table_name)
self.con.register_dataset(table_name, source)
return self.table(table_name)
elif isinstance(source, pd.DataFrame):
return self.register(pa.Table.from_pandas(source), table_name, **kwargs)
else:
raise ValueError("`source` must be either a string or a pathlib.Path")

if first.startswith(("parquet://", "parq://")) or first.endswith(
("parq", "parquet")
):
return self.read_parquet(source, table_name=table_name, **kwargs)
elif first.startswith(("csv://", "txt://")) or first.endswith(
("csv", "tsv", "txt")
):
return self.read_csv(source, table_name=table_name, **kwargs)
else:
self._register_failure()
return None

def _register_failure(self):
import inspect

msg = ", ".join(
m[0] for m in inspect.getmembers(self) if m[0].startswith("read_")
)
raise ValueError(
f"Cannot infer appropriate read function for input, "
f"please call one of {msg} directly"
)

def _register_builtin_udfs(self):
from ibis.backends.datafusion import udfs

Expand Down

0 comments on commit 9ff388c

Please sign in to comment.