fix(python): read_database now properly handles empty result sets from `arrow-odbc` (#14916)
alexander-beedie authored Mar 8, 2024
1 parent 6188cbf commit b2d7e77
Showing 2 changed files with 76 additions and 30 deletions.
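
In practice, the fix means a query that matches no rows now round-trips cleanly through the `arrow-odbc` path, returning a zero-row DataFrame that still carries the result schema. A minimal sketch of the fixed behaviour, assuming `arrow-odbc` is installed and the DSN below (which is hypothetical) points at a reachable database:

    import polars as pl

    # hypothetical ODBC connection string; any valid one works here
    conn_str = "Driver={PostgreSQL};Server=localhost;Database=test;"

    df = pl.read_database(
        query="SELECT id, name FROM users WHERE 1 = 0",  # matches no rows
        connection=conn_str,
    )
    # previously the empty stream produced no batches at all; now we get
    # a 0-row frame whose columns and dtypes match the query's schema
    assert df.shape == (0, 2)
    assert df.columns == ["id", "name"]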
81 changes: 51 additions & 30 deletions py-polars/polars/io/database.py
@@ -127,13 +127,22 @@ def fetch_record_batches(
     ) -> Iterable[pa.RecordBatch]:
         """Fetch results in batches."""
         from arrow_odbc import read_arrow_batches_from_odbc
+        from pyarrow import RecordBatch
 
-        yield from read_arrow_batches_from_odbc(
+        n_batches = 0
+        batch_reader = read_arrow_batches_from_odbc(
             query=self.query,
             batch_size=batch_size,
             connection_string=self.connection_string,
             **self.execute_options,
         )
+        for batch in batch_reader:
+            yield batch
+            n_batches += 1
+
+        if n_batches == 0:
+            # empty result set; return empty batch with accurate schema
+            yield RecordBatch.from_pylist([], schema=batch_reader.schema)
 
     # internally arrow-odbc always reads batches
     fetchall = fetchmany = fetch_record_batches
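
The heart of the change is the fallback after the loop: `arrow-odbc`'s batch reader exposes the result schema even when it yields no batches, so a correctly-typed empty `RecordBatch` can be synthesised from it. The same pattern in isolation, with an illustrative schema standing in for `batch_reader.schema`:

    import pyarrow as pa

    # illustrative schema; in the fix it comes from batch_reader.schema
    schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

    # an empty row list still produces a zero-row batch with full type info
    empty_batch = pa.RecordBatch.from_pylist([], schema=schema)

    assert empty_batch.num_rows == 0
    assert empty_batch.schema == schema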
@@ -172,14 +181,14 @@ def __exit__(
     def __repr__(self) -> str:
         return f"<{type(self).__name__} module={self.driver_name!r}>"
 
-    def _arrow_batches(
+    def _fetch_arrow(
         self,
         driver_properties: _ArrowDriverProperties_,
         *,
         batch_size: int | None,
         iter_batches: bool,
     ) -> Iterable[pa.RecordBatch]:
-        """Yield Arrow data in batches, or as a single 'fetchall' batch."""
+        """Yield Arrow data as a generator of one or more RecordBatches or Tables."""
         fetch_batches = driver_properties["fetch_batches"]
         if not iter_batches or fetch_batches is None:
             fetch_method, sz = driver_properties["fetch_all"], []
@@ -200,31 +209,6 @@ def _arrow_batches(
                     break
                 yield arrow
 
-    def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
-        """Normalise a connection object such that we have the query executor."""
-        if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine":
-            self.can_close_cursor = True
-            if conn.driver == "databricks-sql-python":  # type: ignore[union-attr]
-                # take advantage of the raw connection to get arrow integration
-                self.driver_name = "databricks"
-                return conn.raw_connection().cursor()  # type: ignore[union-attr, return-value]
-            else:
-                # sqlalchemy engine; direct use is deprecated, so prefer the connection
-                return conn.connect()  # type: ignore[union-attr, return-value]
-
-        elif hasattr(conn, "cursor"):
-            # connection has a dedicated cursor; prefer over direct execute
-            cursor = cursor() if callable(cursor := conn.cursor) else cursor
-            self.can_close_cursor = True
-            return cursor
-
-        elif hasattr(conn, "execute"):
-            # can execute directly (given cursor, sqlalchemy connection, etc)
-            return conn  # type: ignore[return-value]
-
-        msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method"
-        raise TypeError(msg)
-
     @staticmethod
     def _fetchall_rows(result: Cursor) -> Iterable[Sequence[Any]]:
         """Fetch row data in a single call, returning the complete result set."""
@@ -265,7 +249,7 @@ def _from_arrow(
         self.can_close_cursor = fetch_batches is None or not iter_batches
         frames = (
             from_arrow(batch, schema_overrides=schema_overrides)
-            for batch in self._arrow_batches(
+            for batch in self._fetch_arrow(
                 driver_properties,
                 iter_batches=iter_batches,
                 batch_size=batch_size,
@@ -326,6 +310,31 @@ def _from_rows(
             return frames if iter_batches else next(frames)  # type: ignore[arg-type]
         return None
 
+    def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
+        """Normalise a connection object such that we have the query executor."""
+        if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine":
+            self.can_close_cursor = True
+            if conn.driver == "databricks-sql-python":  # type: ignore[union-attr]
+                # take advantage of the raw connection to get arrow integration
+                self.driver_name = "databricks"
+                return conn.raw_connection().cursor()  # type: ignore[union-attr, return-value]
+            else:
+                # sqlalchemy engine; direct use is deprecated, so prefer the connection
+                return conn.connect()  # type: ignore[union-attr, return-value]
+
+        elif hasattr(conn, "cursor"):
+            # connection has a dedicated cursor; prefer over direct execute
+            cursor = cursor() if callable(cursor := conn.cursor) else cursor
+            self.can_close_cursor = True
+            return cursor
+
+        elif hasattr(conn, "execute"):
+            # can execute directly (given cursor, sqlalchemy connection, etc)
+            return conn  # type: ignore[return-value]
+
+        msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method"
+        raise TypeError(msg)
+
     def execute(
         self,
         query: str | Selectable,
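
`_normalise_cursor` is unchanged in behaviour; it has simply moved below `_from_rows`. It is the piece that lets `read_database` accept an engine, a connection, or a bare cursor interchangeably. A small usage sketch of those three shapes (SQLite-based and illustrative, not from this commit):

    import sqlite3

    import polars as pl
    from sqlalchemy import create_engine

    query = "SELECT 1 AS x"  # illustrative

    # 1) a SQLAlchemy engine: polars opens a connection on your behalf
    engine = create_engine("sqlite:///:memory:")
    df1 = pl.read_database(query, connection=engine)

    # 2) a DBAPI2 connection: polars creates (and closes) a cursor for it
    conn = sqlite3.connect(":memory:")
    df2 = pl.read_database(query, connection=conn)

    # 3) a bare cursor: used directly via .execute(); never closed by polars
    df3 = pl.read_database(query, connection=conn.cursor())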
@@ -532,7 +541,11 @@ def read_database(  # noqa: D417
     * If polars has to create a cursor from your connection in order to execute the
       query then that cursor will be automatically closed when the query completes;
-      however, polars will *never* close any other connection or cursor.
+      however, polars will *never* close any other open connection or cursor.
+    * We are able to support more than just relational databases and SQL queries
+      through this function. For example, we can load graph database results from
+      a `KùzuDB` connection in conjunction with a Cypher query.
 
     See Also
     --------
@@ -577,6 +590,14 @@ def read_database(  # noqa: D417
     ...     batch_size=1000,
     ... ):
     ...     do_something(df)  # doctest: +SKIP
 
+    Load graph data query results from a `KùzuDB` connection and a Cypher query:
+
+    >>> df = pl.read_database(
+    ...     query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name",
+    ...     connection=kuzu_db_conn,
+    ... )  # doctest: +SKIP
     """  # noqa: W505
     if isinstance(connection, str):
         # check for odbc connection string
25 changes: 25 additions & 0 deletions py-polars/tests/unit/io/test_database_read.py
@@ -326,6 +326,12 @@ def test_read_database(
             engine=str(connect_using),  # type: ignore[arg-type]
             schema_overrides=schema_overrides,
         )
+        df_empty = pl.read_database_uri(
+            uri=f"sqlite:///{tmp_sqlite_db}",
+            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
+            engine=str(connect_using),  # type: ignore[arg-type]
+            schema_overrides=schema_overrides,
+        )
     elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
         # externally instantiated adbc connections
         with connect_using(tmp_sqlite_db) as conn, conn.cursor():
@@ -335,6 +341,12 @@
                 schema_overrides=schema_overrides,
                 batch_size=batch_size,
             )
+            df_empty = pl.read_database(
+                connection=conn,
+                query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
+                schema_overrides=schema_overrides,
+                batch_size=batch_size,
+            )
     else:
         # other user-supplied connections
         df = pl.read_database(
@@ -343,11 +355,24 @@
             schema_overrides=schema_overrides,
             batch_size=batch_size,
         )
+        df_empty = pl.read_database(
+            connection=connect_using(tmp_sqlite_db),
+            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
+            schema_overrides=schema_overrides,
+            batch_size=batch_size,
+        )
 
     # validate the expected query return (data and schema)
     assert df.schema == expected_dtypes
     assert df.shape == (2, 4)
     assert df["date"].to_list() == expected_dates
 
+    # note: 'cursor.description' is not reliable when no query
+    # data is returned, so no point comparing expected dtypes
+    assert df_empty.columns == ["id", "name", "value", "date"]
+    assert df_empty.shape == (0, 4)
+    assert df_empty["date"].to_list() == []
 
 
 def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
     # various flavours of alchemy connection
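The new assertions pin down the empty-result contract end to end: same columns, zero rows. A self-contained variant of that check against an in-memory SQLite database, assuming SQLAlchemy is available (table and data are illustrative):

    import polars as pl
    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE test_data (id INTEGER, name TEXT)"))
        conn.execute(text("INSERT INTO test_data VALUES (1, 'misc')"))
        conn.commit()

        # a predicate that matches nothing should still yield the full schema
        df_empty = pl.read_database(
            connection=conn,
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
        )
        assert df_empty.shape == (0, 2)
        assert df_empty.columns == ["id", "name"]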
