Skip to content

Commit

Permalink
refactor(duckdb): align duckdb IO method signatures with base class
Browse files Browse the repository at this point in the history
BREAKING CHANGE: The (positional) arguments to `read_parquet`,
`read_csv`, and `read_delta` have changed to `path`, `path`, and
`source`, respectively.
  • Loading branch information
gforsyth committed Jun 13, 2024
1 parent e04c35d commit a6d0c18
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 23 deletions.
41 changes: 18 additions & 23 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,9 @@ def _register_failure(self):
@util.experimental
def read_json(
self,
source_list: str | list[str] | tuple[str],
path: str | list[str] | tuple[str],
table_name: str | None = None,
**kwargs,
**kwargs: Any,
) -> ir.Table:
"""Read newline-delimited JSON into an ibis table.
Expand All @@ -628,7 +628,7 @@ def read_json(
Parameters
----------
source_list
path
File or list of files
table_name
Optional table name
Expand All @@ -651,25 +651,23 @@ def read_json(
self._create_temp_view(
table_name,
sg.select(STAR).from_(
self.compiler.f.read_json_auto(
util.normalize_filenames(source_list), *options
)
self.compiler.f.read_json_auto(util.normalize_filenames(path), *options)
),
)

return self.table(table_name)

def read_csv(
self,
source_list: str | list[str] | tuple[str],
path: str | list[str] | tuple[str],
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a CSV file as a table in the current database.
Parameters
----------
source_list
path
The data source(s). May be a path to a file or directory of CSV files, or an
iterable of CSV files.
table_name
Expand All @@ -685,17 +683,14 @@ def read_csv(
The just-registered table
"""
source_list = util.normalize_filenames(source_list)
path = util.normalize_filenames(path)

if not table_name:
table_name = util.gen_name("read_csv")

# auto_detect and columns collide, so we set auto_detect=True
# unless COLUMNS has been specified
if any(
source.startswith(("http://", "https://", "s3://"))
for source in source_list
):
if any(source.startswith(("http://", "https://", "s3://")) for source in path):
self._load_extensions(["httpfs"])

kwargs.setdefault("header", True)
Expand Down Expand Up @@ -723,7 +718,7 @@ def read_csv(

self._create_temp_view(
table_name,
sg.select(STAR).from_(self.compiler.f.read_csv(source_list, *options)),
sg.select(STAR).from_(self.compiler.f.read_csv(path, *options)),
)

return self.table(table_name)
Expand Down Expand Up @@ -786,15 +781,15 @@ def read_geo(

def read_parquet(
self,
source_list: str | Iterable[str],
path: str | Iterable[str],
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a parquet file as a table in the current database.
Parameters
----------
source_list
path
The data source(s). May be a path to a file, an iterable of files,
or directory of parquet files.
table_name
Expand All @@ -810,17 +805,17 @@ def read_parquet(
The just-registered table
"""
source_list = util.normalize_filenames(source_list)
path = util.normalize_filenames(path)

table_name = table_name or util.gen_name("read_parquet")

# Default to using the native duckdb parquet reader
# If that fails because of auth issues, fall back to ingesting via
# pyarrow dataset
try:
self._read_parquet_duckdb_native(source_list, table_name, **kwargs)
self._read_parquet_duckdb_native(path, table_name, **kwargs)
except duckdb.IOException:
self._read_parquet_pyarrow_dataset(source_list, table_name, **kwargs)
self._read_parquet_pyarrow_dataset(path, table_name, **kwargs)

return self.table(table_name)

Expand Down Expand Up @@ -892,15 +887,15 @@ def read_in_memory(

def read_delta(
self,
source_table: str,
source: str,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a Delta Lake table as a table in the current database.
Parameters
----------
source_table
source
The data source. Must be a directory
containing a Delta Lake table.
table_name
Expand All @@ -915,7 +910,7 @@ def read_delta(
The just-registered table.
"""
source_table = util.normalize_filenames(source_table)[0]
source = util.normalize_filenames(source)[0]

table_name = table_name or util.gen_name("read_delta")

Expand All @@ -928,7 +923,7 @@ def read_delta(
"pip install 'ibis-framework[deltalake]'\n"
)

delta_table = DeltaTable(source_table, **kwargs)
delta_table = DeltaTable(source, **kwargs)

return self.read_in_memory(
delta_table.to_pyarrow_dataset(), table_name=table_name
Expand Down
26 changes: 26 additions & 0 deletions ibis/backends/tests/test_signatures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import inspect

import pytest

from ibis.backends import _FileIOHandler
from ibis.backends.tests.signature.typecheck import compatible

# Build the (base_class, method_name) parameter list for the
# signature-compatibility test below: one entry per public method of each
# base interface class we want to check backends against.
params = []

for module in [_FileIOHandler]:
    for method in dir(module):
        if not method.startswith("_"):
            # Append the loop variable, not a hard-coded class, so that
            # adding more base classes to the list above works correctly.
            params.append((module, method))


@pytest.mark.parametrize("base_cls, method", params)
def test_signatures(base_cls, method, backend_cls):
    """Assert a backend method's signature is compatible with the base class's.

    Backends that do not implement `method` are skipped rather than failed.
    Annotation differences are ignored; only parameter names/order/kinds count.
    """
    if not hasattr(backend_cls, method):
        pytest.skip(f"Method {method} not present in {backend_cls}, skipping...")

    base_method = getattr(base_cls, method)
    backend_method = getattr(backend_cls, method)

    assert compatible(
        inspect.signature(base_method),
        inspect.signature(backend_method),
        check_annotations=False,
    )

0 comments on commit a6d0c18

Please sign in to comment.