From 8b876864a7729a879f2566396505806fc7fc2ffc Mon Sep 17 00:00:00 2001 From: Jiting Xu Date: Thu, 22 Aug 2024 21:06:25 -0700 Subject: [PATCH 01/11] add read_csv --- ibis/backends/__init__.py | 93 ++++++++++++++++++++++++++++ ibis/backends/tests/test_register.py | 13 +--- 2 files changed, 96 insertions(+), 10 deletions(-) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index 265d3ab76561..569272e4f11b 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -9,6 +9,7 @@ import urllib.parse from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar +import glob import ibis import ibis.common.exceptions as exc @@ -1236,6 +1237,98 @@ def has_operation(cls, operation: type[ops.Value]) -> bool: f"{cls.name} backend has not implemented `has_operation` API" ) + + def read_csv( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Register a CSV file as a table in the current backend. + + Parameters + ---------- + path + The data source. A string or Path to the CSV file. + table_name + An optional name to use for the created table. This defaults to + a sequentially generated name. + **kwargs + Additional keyword arguments passed to the backend loading function. + + Returns + ------- + ir.Table + The just-registered table + + Examples + -------- + Connect to a SQLite database: + + >>> con = ibis.sqlite.connect() + + Read a single csv file: + + >>> table = con.read_csv("path/to/file.csv") + + Read all csv files in a directory: + + >>> table = con.read_parquet("path/to/csv_directory/*") + + Read all csv files with a glob pattern: + + >>> table = con.read_csv("path/to/csv_directory/test_*.csv") + + Read csv file from s3: + + >>> table = con.read_csv("s3://bucket/path/to/file.csv") + + """ + pa = self._import_pyarrow() + import pyarrow.csv as pcsv + import pyarrow.fs as fs + + read_options_args = {} + parse_options_args = {} + convert_options_args = {} + memory_pool = None + + for key, value in kwargs.items(): + if hasattr(pcsv.ReadOptions, key): + read_options_args[key] = value + elif hasattr(pcsv.ParseOptions, key): + parse_options_args[key] = value + elif hasattr(pcsv.ConvertOptions, key): + convert_options_args[key] = value + elif key == "memory_pool": + memory_pool = value + else: + raise ValueError(f"Invalid args: {key!r}") + + read_options = pcsv.ReadOptions(**read_options_args) + parse_options = pcsv.ParseOptions(**parse_options_args) + convert_options = pcsv.ConvertOptions(**convert_options_args) + if memory_pool: + memory_pool = pa.default_memory_pool() + + path = str(path) + file_system, path = fs.FileSystem.from_uri(path) + + if isinstance(file_system, fs.LocalFileSystem): + paths = glob.glob(path) + if not paths: + raise FileNotFoundError(f"No files found at {path!r}") + else: + paths = [path] + + pyarrow_tables = [] + for path in paths: + with file_system.open_input_file(path) as f: + pyarrow_table = pcsv.read_csv(f, read_options=read_options, parse_options=parse_options, convert_options=convert_options, memory_pool=memory_pool) + pyarrow_tables.append(pyarrow_table) + + pyarrow_table = pa.concat_tables(pyarrow_tables) + table_name = table_name or util.gen_name("read_csv") + self.create_table(table_name, pyarrow_table) + return self.table(table_name) + def _cached(self, expr: ir.Table): """Cache the provided expression. 
diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index cdfa1683743f..a8b909cc8b07 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -488,14 +488,7 @@ def test_read_parquet_glob(con, tmp_path, ft_data): @pytest.mark.notyet( [ "flink", - "impala", - "mssql", - "mysql", "pandas", - "postgres", - "risingwave", - "sqlite", - "trino", ] ) def test_read_csv_glob(con, tmp_path, ft_data): @@ -578,13 +571,13 @@ def num_diamonds(data_dir): [param(None, id="default"), param("fancy_stones", id="file_name")], ) @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] + ["flink"] ) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" with pushd(data_dir / "csv"): - if con.name == "pyspark": - # pyspark doesn't respect CWD + if con.name in ("pyspark", "sqlite"): + # pyspark and sqlite doesn't respect CWD fname = str(Path(fname).absolute()) table = con.read_csv(fname, table_name=in_table_name) From fedd4deb6da3b3d01171cacadaa4bc6df0884dfe Mon Sep 17 00:00:00 2001 From: Jiting Xu Date: Thu, 22 Aug 2024 21:12:30 -0700 Subject: [PATCH 02/11] lint --- ibis/backends/__init__.py | 185 ++++++++++++++------------- ibis/backends/tests/test_register.py | 4 +- 2 files changed, 96 insertions(+), 93 deletions(-) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index 569272e4f11b..d09c30876c40 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -3,13 +3,13 @@ import abc import collections.abc import functools +import glob import importlib.metadata import keyword import re import urllib.parse from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar -import glob import ibis import ibis.common.exceptions as exc @@ -1237,97 +1237,102 @@ def has_operation(cls, operation: type[ops.Value]) -> bool: f"{cls.name} backend has not implemented `has_operation` API" ) - def read_csv( - self, path: str | Path, table_name: str | None = None, **kwargs: Any - ) -> ir.Table: - """Register a CSV file as a table in the current backend. - - Parameters - ---------- - path - The data source. A string or Path to the CSV file. - table_name - An optional name to use for the created table. This defaults to - a sequentially generated name. - **kwargs - Additional keyword arguments passed to the backend loading function. 
- - Returns - ------- - ir.Table - The just-registered table - - Examples - -------- - Connect to a SQLite database: - - >>> con = ibis.sqlite.connect() - - Read a single csv file: - - >>> table = con.read_csv("path/to/file.csv") - - Read all csv files in a directory: - - >>> table = con.read_parquet("path/to/csv_directory/*") - - Read all csv files with a glob pattern: - - >>> table = con.read_csv("path/to/csv_directory/test_*.csv") - - Read csv file from s3: - - >>> table = con.read_csv("s3://bucket/path/to/file.csv") - - """ - pa = self._import_pyarrow() - import pyarrow.csv as pcsv - import pyarrow.fs as fs - - read_options_args = {} - parse_options_args = {} - convert_options_args = {} - memory_pool = None - - for key, value in kwargs.items(): - if hasattr(pcsv.ReadOptions, key): - read_options_args[key] = value - elif hasattr(pcsv.ParseOptions, key): - parse_options_args[key] = value - elif hasattr(pcsv.ConvertOptions, key): - convert_options_args[key] = value - elif key == "memory_pool": - memory_pool = value - else: - raise ValueError(f"Invalid args: {key!r}") - - read_options = pcsv.ReadOptions(**read_options_args) - parse_options = pcsv.ParseOptions(**parse_options_args) - convert_options = pcsv.ConvertOptions(**convert_options_args) - if memory_pool: - memory_pool = pa.default_memory_pool() - - path = str(path) - file_system, path = fs.FileSystem.from_uri(path) - - if isinstance(file_system, fs.LocalFileSystem): - paths = glob.glob(path) - if not paths: - raise FileNotFoundError(f"No files found at {path!r}") + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Register a CSV file as a table in the current backend. + + Parameters + ---------- + path + The data source. A string or Path to the CSV file. + table_name + An optional name to use for the created table. This defaults to + a sequentially generated name. + **kwargs + Additional keyword arguments passed to the backend loading function. 
+ + Returns + ------- + ir.Table + The just-registered table + + Examples + -------- + Connect to a SQLite database: + + >>> con = ibis.sqlite.connect() + + Read a single csv file: + + >>> table = con.read_csv("path/to/file.csv") + + Read all csv files in a directory: + + >>> table = con.read_parquet("path/to/csv_directory/*") + + Read all csv files with a glob pattern: + + >>> table = con.read_csv("path/to/csv_directory/test_*.csv") + + Read csv file from s3: + + >>> table = con.read_csv("s3://bucket/path/to/file.csv") + + """ + pa = self._import_pyarrow() + import pyarrow.csv as pcsv + from pyarrow import fs + + read_options_args = {} + parse_options_args = {} + convert_options_args = {} + memory_pool = None + + for key, value in kwargs.items(): + if hasattr(pcsv.ReadOptions, key): + read_options_args[key] = value + elif hasattr(pcsv.ParseOptions, key): + parse_options_args[key] = value + elif hasattr(pcsv.ConvertOptions, key): + convert_options_args[key] = value + elif key == "memory_pool": + memory_pool = value else: - paths = [path] - - pyarrow_tables = [] - for path in paths: - with file_system.open_input_file(path) as f: - pyarrow_table = pcsv.read_csv(f, read_options=read_options, parse_options=parse_options, convert_options=convert_options, memory_pool=memory_pool) - pyarrow_tables.append(pyarrow_table) - - pyarrow_table = pa.concat_tables(pyarrow_tables) - table_name = table_name or util.gen_name("read_csv") - self.create_table(table_name, pyarrow_table) - return self.table(table_name) + raise ValueError(f"Invalid args: {key!r}") + + read_options = pcsv.ReadOptions(**read_options_args) + parse_options = pcsv.ParseOptions(**parse_options_args) + convert_options = pcsv.ConvertOptions(**convert_options_args) + if memory_pool: + memory_pool = pa.default_memory_pool() + + path = str(path) + file_system, path = fs.FileSystem.from_uri(path) + + if isinstance(file_system, fs.LocalFileSystem): + paths = glob.glob(path) + if not paths: + raise FileNotFoundError(f"No files found at {path!r}") + else: + paths = [path] + + pyarrow_tables = [] + for path in paths: + with file_system.open_input_file(path) as f: + pyarrow_table = pcsv.read_csv( + f, + read_options=read_options, + parse_options=parse_options, + convert_options=convert_options, + memory_pool=memory_pool, + ) + pyarrow_tables.append(pyarrow_table) + + pyarrow_table = pa.concat_tables(pyarrow_tables) + table_name = table_name or util.gen_name("read_csv") + self.create_table(table_name, pyarrow_table) + return self.table(table_name) def _cached(self, expr: ir.Table): """Cache the provided expression. 
diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index a8b909cc8b07..d2e50c6d6338 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -570,9 +570,7 @@ def num_diamonds(data_dir): "in_table_name", [param(None, id="default"), param("fancy_stones", id="file_name")], ) -@pytest.mark.notyet( - ["flink"] -) +@pytest.mark.notyet(["flink"]) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" with pushd(data_dir / "csv"): From 38f91dd45a74bf9823334a07980f8efcaa51a443 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Thu, 22 Aug 2024 22:01:29 -0700 Subject: [PATCH 03/11] resolve tests --- ibis/backends/tests/test_register.py | 38 ++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index d2e50c6d6338..50fc9c97ca14 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -12,14 +12,13 @@ import ibis from ibis.backends.conftest import TEST_TABLES +from ibis.backends.tests.errors import MySQLOperationalError, PyODBCProgrammingError if TYPE_CHECKING: from collections.abc import Iterator import pyarrow as pa -pytestmark = pytest.mark.notimpl(["druid", "exasol", "oracle"]) - @contextlib.contextmanager def pushd(new_dir): @@ -98,6 +97,7 @@ def gzip_csv(data_dir, tmp_path): "trino", ] ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): with pushd(data_dir / "csv"): with pytest.warns(FutureWarning, match="v9.1"): @@ -109,7 +109,7 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): # TODO: rewrite or delete test when register api is removed -@pytest.mark.notimpl(["datafusion"]) +@pytest.mark.notimpl(["datafusion", "druid", "exasol", "oracle"]) @pytest.mark.notyet( [ "bigquery", @@ -153,6 +153,7 @@ def test_register_csv_gz(con, data_dir, gzip_csv): "trino", ] ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_register_with_dotted_name(con, data_dir, tmp_path): basename = "foo.bar.baz/diamonds.csv" f = tmp_path.joinpath(basename) @@ -212,6 +213,7 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: "trino", ] ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_register_parquet( con, tmp_path, data_dir, fname, in_table_name, out_table_name ): @@ -252,6 +254,7 @@ def test_register_parquet( "trino", ] ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_register_iterator_parquet( con, tmp_path, @@ -280,7 +283,7 @@ def test_register_iterator_parquet( # TODO: remove entirely when `register` is removed # This same functionality is implemented across all backends # via `create_table` and tested in `test_client.py` -@pytest.mark.notimpl(["datafusion"]) +@pytest.mark.notimpl(["datafusion", "druid", "exasol", "oracle"]) @pytest.mark.notyet( [ "bigquery", @@ -316,7 +319,7 @@ def test_register_pandas(con): # TODO: remove entirely when `register` is removed # This same functionality is implemented across all backends # via `create_table` and tested in `test_client.py` -@pytest.mark.notimpl(["datafusion", "polars"]) +@pytest.mark.notimpl(["datafusion", "polars", "druid", "exasol", "oracle"]) @pytest.mark.notyet( [ "bigquery", @@ -361,6 +364,7 @@ def test_register_pyarrow_tables(con): "trino", ] ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_csv_reregister_schema(con, tmp_path): foo = 
tmp_path.joinpath("foo.csv") with foo.open("w", newline="") as csvfile: @@ -390,10 +394,12 @@ def test_csv_reregister_schema(con, tmp_path): "clickhouse", "dask", "datafusion", - "flink", + "druid", + "exasol" "flink", "impala", "mysql", "mssql", + "oracle", "pandas", "polars", "postgres", @@ -428,6 +434,7 @@ def test_register_garbage(con, monkeypatch): @pytest.mark.notyet( ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): pq = pytest.importorskip("pyarrow.parquet") @@ -469,6 +476,7 @@ def ft_data(data_dir): "trino", ] ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_read_parquet_glob(con, tmp_path, ft_data): pq = pytest.importorskip("pyarrow.parquet") @@ -491,6 +499,9 @@ def test_read_parquet_glob(con, tmp_path, ft_data): "pandas", ] ) +@pytest.mark.notimpl(["druid"]) +@pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) +@pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError) def test_read_csv_glob(con, tmp_path, ft_data): pc = pytest.importorskip("pyarrow.csv") @@ -527,6 +538,7 @@ def test_read_csv_glob(con, tmp_path, ft_data): raises=ValueError, reason="read_json() missing required argument: 'schema'", ) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_read_json_glob(con, tmp_path, ft_data): nrows = len(ft_data) ntables = 2 @@ -571,11 +583,21 @@ def num_diamonds(data_dir): [param(None, id="default"), param("fancy_stones", id="file_name")], ) @pytest.mark.notyet(["flink"]) +@pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" with pushd(data_dir / "csv"): - if con.name in ("pyspark", "sqlite"): - # pyspark and sqlite doesn't respect CWD + if con.name in ( + "pyspark", + "sqlite", + "mysql", + "postgres", + "risingwave", + "impala", + "mssql", + "trino", + ): + # backend doesn't respect CWD fname = str(Path(fname).absolute()) table = con.read_csv(fname, table_name=in_table_name) From 773cfb57347bffa68dbbe40c679e922e57da6827 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Mon, 26 Aug 2024 16:09:30 -0700 Subject: [PATCH 04/11] resolve typo --- ibis/backends/tests/test_register.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 50fc9c97ca14..d4cf4cf3143b 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -395,7 +395,8 @@ def test_csv_reregister_schema(con, tmp_path): "dask", "datafusion", "druid", - "exasol" "flink", + "exasol", + "flink", "impala", "mysql", "mssql", @@ -589,7 +590,7 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds): with pushd(data_dir / "csv"): if con.name in ( "pyspark", - "sqlite", + #"sqlite", "mysql", "postgres", "risingwave", @@ -597,7 +598,8 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds): "mssql", "trino", ): - # backend doesn't respect CWD + # pyspark backend doesn't respect CWD + # backends using pyarrow implementation need absolute path fname = str(Path(fname).absolute()) table = con.read_csv(fname, table_name=in_table_name) From c7aea6e2c58692e71a03134dbea0e2d7c06cbd68 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Mon, 26 Aug 2024 16:11:02 -0700 Subject: [PATCH 05/11] resolve typo --- ibis/backends/tests/test_register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index d4cf4cf3143b..38a76c012b48 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -590,7 +590,7 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds): with pushd(data_dir / "csv"): if con.name in ( "pyspark", - #"sqlite", + "sqlite", "mysql", "postgres", "risingwave", From 6152533428eed1ee622f6924f62f267594d5d1c8 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Sun, 15 Sep 2024 12:27:09 -0700 Subject: [PATCH 06/11] test --- ibis/backends/tests/test_register.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index be4d12928efd..dd2f6bc9a161 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -479,11 +479,6 @@ def test_read_parquet_glob(con, tmp_path, ft_data): [ "flink", "impala", - "mssql", - "mysql", - "postgres", - "risingwave", - "sqlite", "trino", ] ) From e0230255039af25a7ac2d196ae4c68e712790720 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Thu, 19 Sep 2024 10:02:24 -0700 Subject: [PATCH 07/11] remove trino from notyet --- ibis/backends/tests/test_register.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index dd2f6bc9a161..deb2530ce2ac 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -479,7 +479,6 @@ def test_read_parquet_glob(con, tmp_path, ft_data): [ "flink", "impala", - "trino", ] ) @pytest.mark.notimpl(["druid"]) From 902bb4737e1edc594d1ab3ba57749010717c6bb8 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Fri, 20 Sep 2024 10:46:17 -0700 Subject: [PATCH 08/11] skip trino and impala test --- ibis/backends/__init__.py | 5 +++++ ibis/backends/tests/test_register.py | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index f826f4222cb9..b203351ae5ff 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -1308,11 +1308,16 @@ def has_operation(cls, operation: type[ops.Value]) -> bool: f"{cls.name} backend has not implemented `has_operation` API" ) + @util.experimental def read_csv( self, path: str | Path, table_name: str | None = None, **kwargs: Any ) -> ir.Table: """Register a CSV file as a table in the current backend. + This function reads a CSV file and registers it as a table in the current + backend. Note that for Impala and Trino backends, CSV read performance + may be suboptimal. 
+ Parameters ---------- path diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index deb2530ce2ac..6af0ae3e3fde 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -565,6 +565,9 @@ def num_diamonds(data_dir): @pytest.mark.notyet(["flink"]) @pytest.mark.notimpl(["druid", "exasol", "oracle"]) def test_read_csv(con, data_dir, in_table_name, num_diamonds): + if con.name in ("trino", "impala"): + # TODO: remove after trino and impala have efficient insertion + pytest.skip("Both Impala and Trino lack efficient data insertion methods from Python.") fname = "diamonds.csv" with pushd(data_dir / "csv"): if con.name in ( @@ -573,9 +576,7 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds): "mysql", "postgres", "risingwave", - "impala", "mssql", - "trino", ): # pyspark backend doesn't respect CWD # backends using pyarrow implementation need absolute path From 79885a9721225b934f03d0b9525ba86342e0371b Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Fri, 20 Sep 2024 10:47:13 -0700 Subject: [PATCH 09/11] lint --- ibis/backends/tests/test_register.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 6af0ae3e3fde..32955b4c239b 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -567,7 +567,9 @@ def num_diamonds(data_dir): def test_read_csv(con, data_dir, in_table_name, num_diamonds): if con.name in ("trino", "impala"): # TODO: remove after trino and impala have efficient insertion - pytest.skip("Both Impala and Trino lack efficient data insertion methods from Python.") + pytest.skip( + "Both Impala and Trino lack efficient data insertion methods from Python." + ) fname = "diamonds.csv" with pushd(data_dir / "csv"): if con.name in ( From 9acda5cad4d326a96d242f482b15ff1eb5f47338 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Fri, 20 Sep 2024 11:12:47 -0700 Subject: [PATCH 10/11] enable impala in test --- ibis/backends/tests/test_register.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 32955b4c239b..13708c9268fb 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -475,12 +475,7 @@ def test_read_parquet_glob(con, tmp_path, ft_data): assert table.count().execute() == nrows * ntables -@pytest.mark.notyet( - [ - "flink", - "impala", - ] -) +@pytest.mark.notyet(["flink"]) @pytest.mark.notimpl(["druid"]) @pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) @pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError) From 96ff7018a833836477f09052c3b47daf69a06361 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Tue, 24 Sep 2024 10:57:18 -0700 Subject: [PATCH 11/11] test: add unit test and documentation --- ibis/backends/__init__.py | 15 ++++-- ibis/backends/tests/test_register.py | 70 ++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index 33e2f8d0c8eb..4cfa06f539d0 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -1277,8 +1277,7 @@ def read_csv( """Register a CSV file as a table in the current backend. This function reads a CSV file and registers it as a table in the current - backend. Note that for Impala and Trino backends, CSV read performance - may be suboptimal. + backend. 
Note that for Impala and Trino backends, the performance may be suboptimal. Parameters ---------- @@ -1289,6 +1288,11 @@ def read_csv( a sequentially generated name. **kwargs Additional keyword arguments passed to the backend loading function. + Common options are skip_rows, column_names, delimiter, and include_columns. + More details could be found: + https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html + https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html + https://arrow.apache.org/docs/python/generated/pyarrow.csv.ConvertOptions.html Returns ------- @@ -1317,6 +1321,11 @@ def read_csv( >>> table = con.read_csv("s3://bucket/path/to/file.csv") + Read csv file with custom pyarrow options: + + >>> table = con.read_csv( + ... "path/to/file.csv", delimiter=",", include_columns=["col1", "col3"] + ... ) """ pa = self._import_pyarrow() import pyarrow.csv as pcsv @@ -1342,7 +1351,7 @@ def read_csv( read_options = pcsv.ReadOptions(**read_options_args) parse_options = pcsv.ParseOptions(**parse_options_args) convert_options = pcsv.ConvertOptions(**convert_options_args) - if memory_pool: + if not memory_pool: memory_pool = pa.default_memory_pool() path = str(path) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 1cdd60bd2893..405a1c562298 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -614,3 +614,73 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds): } ) assert table.count().execute() == num_diamonds + + +@pytest.mark.parametrize( + ("skip_rows", "new_column_names", "delimiter", "include_columns"), + [ + param(True, True, False, False, id="skip_rows_with_column_names"), + param(False, False, False, True, id="include_columns"), + param(False, False, True, False, id="delimiter"), + ], +) +@pytest.mark.notyet(["flink"]) +@pytest.mark.notimpl(["druid"]) +@pytest.mark.never( + [ + "duckdb", + "polars", + "bigquery", + "clickhouse", + "datafusion", + "snowflake", + "pyspark", + ], + reason="backend implements its own read_csv", +) +@pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) +@pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError) +def test_read_csv_pyarrow_options( + con, tmp_path, ft_data, skip_rows, new_column_names, delimiter, include_columns +): + pc = pytest.importorskip("pyarrow.csv") + + if con.name in ( + "duckdb", + "polars", + "bigquery", + "clickhouse", + "datafusion", + "snowflake", + "pyspark", + ): + pytest.skip(f"{con.name} implements its own `read_parquet`") + + column_names = ft_data.column_names + num_rows = ft_data.num_rows + fname = "tmp.csv" + pc.write_csv(ft_data, tmp_path / fname) + + options = {} + if skip_rows: + options["skip_rows"] = 2 + num_rows = num_rows - options["skip_rows"] + 1 + if new_column_names: + column_names = [f"col_{i}" for i in range(ft_data.num_columns)] + options["column_names"] = column_names + if delimiter: + new_delimiter = "*" + options["delimiter"] = new_delimiter + pc.write_csv( + ft_data, tmp_path / fname, pc.WriteOptions(delimiter=new_delimiter) + ) + if include_columns: + # try to include all types here + # pick the first 12 columns + column_names = column_names[:12] + options["include_columns"] = column_names + + table = con.read_csv(tmp_path / fname, **options) + + assert set(table.columns) == set(column_names) + assert table.count().execute() == num_rows
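
---

A minimal usage sketch of the API this series adds, written in the same doctest style as the docstring examples above. It is not part of the patches themselves. It assumes a backend that falls back to the default pyarrow-based implementation (the SQLite connection used in the docstring is reused here); the file name "measurements.csv" and its column names are hypothetical placeholders. The keyword arguments shown are the ones the final docstring lists as common options, and they are dispatched as implemented in patch 01: skip_rows and column_names go to pyarrow.csv.ReadOptions, delimiter to pyarrow.csv.ParseOptions, and include_columns to pyarrow.csv.ConvertOptions.

>>> import ibis
>>> con = ibis.sqlite.connect()
>>> table = con.read_csv(
...     "measurements.csv",  # hypothetical ";"-delimited file with a header row
...     table_name="measurements",
...     skip_rows=1,  # ReadOptions: skip the original header row
...     column_names=["city", "temp_c", "station"],  # ReadOptions: supply replacement names
...     delimiter=";",  # ParseOptions
...     include_columns=["city", "temp_c"],  # ConvertOptions: load only a subset of columns
... )
>>> table.count().execute()  # doctest: +SKIP

Because every keyword is matched against the pyarrow option classes and anything unrecognized raises ValueError, a typo such as delimeter=";" fails fast instead of being silently ignored.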