diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 47a51f8852ef7..1a3af336e5347 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -4,9 +4,11 @@ import inspect import typing from collections.abc import Mapping +from pathlib import Path from typing import TYPE_CHECKING, Any import datafusion as df +import pyarrow as pa import pyarrow_hotfix # noqa: F401 import sqlglot as sg import sqlglot.expressions as sge @@ -25,7 +27,7 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowSchema, PyArrowType -from ibis.util import gen_name, normalize_filename, normalize_filenames +from ibis.util import gen_name, normalize_filename, normalize_filenames, warn_deprecated try: from datafusion import ExecutionContext as SessionContext @@ -43,11 +45,8 @@ RuntimeConfig = None if TYPE_CHECKING: - from pathlib import Path - import pandas as pd import polars as pl - import pyarrow as pa def as_nullable(dtype: dt.DataType) -> dt.DataType: @@ -88,37 +87,30 @@ def do_connect( Parameters ---------- config - Mapping of table names to files or a `SessionContext` + Mapping of table names to files (deprecated in 10.0) or a `SessionContext` instance. Examples -------- + >>> from datafusion import SessionContext + >>> ctx = SessionContext() + >>> _ = ctx.from_pydict({"a": [1, 2, 3]}, "mytable") >>> import ibis - >>> config = { - ... "astronauts": "ci/ibis-testing-data/parquet/astronauts.parquet", - ... "diamonds": "ci/ibis-testing-data/csv/diamonds.csv", - ... } - >>> con = ibis.datafusion.connect(config) + >>> con = ibis.datafusion.connect(ctx) >>> con.list_tables() - ['astronauts', 'diamonds'] - >>> con.table("diamonds") - DatabaseTable: diamonds - carat float64 - cut string - color string - clarity string - depth float64 - table float64 - price int64 - x float64 - y float64 - z float64 + ['mytable'] """ if isinstance(config, SessionContext): (self.con, config) = (config, None) else: if config is not None and not isinstance(config, Mapping): raise TypeError("Input to ibis.datafusion.connect must be a mapping") + elif config is not None and config: # warn if dict is not empty + warn_deprecated( + "Passing a mapping of tables names to files", + as_of="10.0", + instead="Please use the explicit `read_*` methods for the files you would like to load instead.", + ) if SessionConfig is not None: df_config = SessionConfig( {"datafusion.sql_parser.dialect": "PostgreSQL"} @@ -178,6 +170,57 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: return PyArrowSchema.to_ibis(df.schema()) + def _register( + self, + source: str | Path | pa.Table | pa.RecordBatch | pa.Dataset | pd.DataFrame, + table_name: str | None = None, + **kwargs: Any, + ) -> ir.Table: + import pandas as pd + import pyarrow.dataset as ds + + if isinstance(source, (str, Path)): + first = str(source) + elif isinstance(source, pa.Table): + self.con.deregister_table(table_name) + self.con.register_record_batches(table_name, [source.to_batches()]) + return self.table(table_name) + elif isinstance(source, pa.RecordBatch): + self.con.deregister_table(table_name) + self.con.register_record_batches(table_name, [[source]]) + return self.table(table_name) + elif isinstance(source, ds.Dataset): + self.con.deregister_table(table_name) + self.con.register_dataset(table_name, source) + return self.table(table_name) + elif isinstance(source, pd.DataFrame): + return self.register(pa.Table.from_pandas(source), table_name, **kwargs) + else: + raise ValueError("`source` must be either a string or a pathlib.Path") + + if first.startswith(("parquet://", "parq://")) or first.endswith( + ("parq", "parquet") + ): + return self.read_parquet(source, table_name=table_name, **kwargs) + elif first.startswith(("csv://", "txt://")) or first.endswith( + ("csv", "tsv", "txt") + ): + return self.read_csv(source, table_name=table_name, **kwargs) + else: + self._register_failure() + return None + + def _register_failure(self): + import inspect + + msg = ", ".join( + m[0] for m in inspect.getmembers(self) if m[0].startswith("read_") + ) + raise ValueError( + f"Cannot infer appropriate read function for input, " + f"please call one of {msg} directly" + ) + def _register_builtin_udfs(self): from ibis.backends.datafusion import udfs