Skip to content

Commit

Permalink
Added sql() and show() convenience methods to QueryBuilder
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Molina <[email protected]>
  • Loading branch information
DevChrisCross authored and rtyler committed Jan 4, 2025
1 parent 6430151 commit a420c33
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 8 deletions.
64 changes: 61 additions & 3 deletions python/deltalake/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import warnings
from typing import List

Expand All @@ -9,6 +10,8 @@
from deltalake.table import DeltaTable
from deltalake.warnings import ExperimentalWarning

logger = logging.getLogger(__name__)


class QueryBuilder:
"""
Expand Down Expand Up @@ -47,9 +50,9 @@ def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
)
return self

def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
def execute(self, sql: str) -> QueryResult:
"""
Execute the query and return a list of record batches
Prepares the sql query to be executed.
For example:
Expand All @@ -59,6 +62,61 @@ def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
>>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
>>> qb = QueryBuilder().register('test', dt)
>>> results = qb.execute('SELECT * FROM test')
>>> assert isinstance(results, QueryResult)
"""
return QueryResult(self._query_builder, sql)

def sql(self, sql: str) -> QueryResult:
    """
    Alias of `execute()`: forwards `sql` unchanged and returns what `execute()` returns.

    For example:

    >>> tmp = getfixture('tmp_path')
    >>> import pyarrow as pa
    >>> from deltalake import DeltaTable, QueryBuilder
    >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
    >>> qb = QueryBuilder().register('test', dt)
    >>> query = 'SELECT * FROM test'
    >>> assert qb.execute(query).fetchall() == qb.sql(query).fetchall()
    """
    prepared = self.execute(sql)
    return prepared


class QueryResult:
    """
    Handle for a SQL query that has been prepared but not yet run.

    Constructing the object only stores the query builder and the SQL text. The
    query is run every time `fetchall()` or `show()` is called.
    """

    def __init__(self, query_builder: PyQueryBuilder, sql: str):
        """
        Store the builder and the query text for later execution.

        Args:
            query_builder: object whose `execute(sql)` method runs the query.
            sql: the SQL text passed to `query_builder.execute()`.
        """
        self._query_builder = query_builder
        self._sql_query = sql

    def show(self) -> None:
        """
        Execute the query and print the output to the console.

        When the query yields no record batches nothing is printed; an INFO
        message is logged instead.

        For example:

        >>> tmp = getfixture('tmp_path')
        >>> import pyarrow as pa
        >>> from deltalake import DeltaTable, QueryBuilder
        >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
        >>> qb = QueryBuilder().register('test', dt)
        >>> results = qb.execute('SELECT * FROM test').show()
        """
        records = self.fetchall()
        if records:
            print(pyarrow.Table.from_batches(records))
        else:
            logger.info("The executed query contains no records.")

    def fetchall(self) -> List[pyarrow.RecordBatch]:
        """
        Execute the query and return a list of record batches.

        For example:

        >>> tmp = getfixture('tmp_path')
        >>> import pyarrow as pa
        >>> from deltalake import DeltaTable, QueryBuilder
        >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
        >>> qb = QueryBuilder().register('test', dt)
        >>> results = qb.execute('SELECT * FROM test').fetchall()
        >>> assert results is not None
        """
        # Use the SQL stored at construction time; this method has no `sql`
        # parameter, so referring to a bare `sql` here would raise NameError.
        return self._query_builder.execute(self._sql_query)
31 changes: 26 additions & 5 deletions python/tests/test_table_read.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import tempfile
from datetime import date, datetime, timezone
Expand Down Expand Up @@ -958,11 +959,14 @@ def test_read_query_builder():
"month": ["4", "12", "12", "12"],
"day": ["5", "4", "20", "20"],
}
actual = pa.Table.from_batches(
QueryBuilder()
.register("tbl", dt)
.execute("SELECT * FROM tbl WHERE year >= 2021 ORDER BY value")
).to_pydict()

qb = QueryBuilder().register("tbl", dt)
query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value"

actual = pa.Table.from_batches(qb.execute(query).fetchall()).to_pydict()
assert expected == actual

actual = pa.Table.from_batches(qb.sql(query).fetchall()).to_pydict()
assert expected == actual


Expand Down Expand Up @@ -998,5 +1002,22 @@ def test_read_query_builder_join_multiple_tables(tmp_path):
ORDER BY tbl1.date
"""
)
.fetchall()
).to_pydict()
assert expected == actual


def test_read_query_builder_show_output(capsys, caplog):
    """`show()` prints to stdout for a non-empty result and only logs for an empty one."""
    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
    dt = DeltaTable(table_path)
    # Capture INFO records from the package logger through caplog rather than
    # calling setLevel() on the logger directly: caplog restores the previous
    # level at teardown, so the change cannot leak into other tests.
    caplog.set_level(logging.INFO, logger="deltalake")

    qb = QueryBuilder().register("tbl", dt)

    # Non-empty result: something must be written to stdout.
    query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value"
    qb.execute(query).show()
    assert capsys.readouterr().out.strip() != ""

    # Empty result: nothing is printed, and the "no records" message is logged.
    query = "SELECT * FROM tbl WHERE year >= 9999"
    qb.execute(query).show()
    assert "query contains no records" in caplog.text
    assert capsys.readouterr().out.strip() == ""

0 comments on commit a420c33

Please sign in to comment.