Skip to content

Commit

Permalink
Added sql() and show() convenience methods to QueryBuilder
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Molina <[email protected]>
  • Loading branch information
DevChrisCross committed Jan 3, 2025
1 parent 9e35c06 commit ac92f90
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
47 changes: 46 additions & 1 deletion python/deltalake/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import warnings
from typing import List

Expand All @@ -9,6 +10,8 @@
from deltalake.table import DeltaTable
from deltalake.warnings import ExperimentalWarning

logger = logging.getLogger(__name__)


class QueryBuilder:
"""
Expand All @@ -25,6 +28,7 @@ def __init__(self) -> None:
category=ExperimentalWarning,
)
self._query_builder = PyQueryBuilder()
self._print_output = False

def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
"""
Expand Down Expand Up @@ -61,4 +65,45 @@ def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
>>> results = qb.execute('SELECT * FROM test')
>>> assert results is not None
"""
return self._query_builder.execute(sql)
records = self._query_builder.execute(sql)
if self._print_output:
if len(records) > 0:
print(pyarrow.Table.from_batches(records))
else:
logger.info("The executed query contains no records.")

return records

def show(self, print_output: bool = True) -> QueryBuilder:
    """
    Toggle console printing of the results of subsequent `execute()` calls.

    The flag is stored on this builder and stays in effect until `show()`
    is called again with a different value.

    Args:
        print_output: `True` to print query results, `False` to stop printing.

    Returns:
        This builder, so that calls can be chained.

    For example:

    >>> tmp = getfixture('tmp_path')
    >>> import pyarrow as pa
    >>> from deltalake import DeltaTable, QueryBuilder
    >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
    >>> qb = QueryBuilder().register('test', dt)
    >>> results = qb.show().execute('SELECT * FROM test')
    """
    # `execute()` reads this flag to decide whether to print its result.
    self._print_output = print_output
    return self

def sql(self, sql: str) -> List[pyarrow.RecordBatch]:
    """
    Run a SQL query; a shorter alias for `execute()`.

    Args:
        sql: The SQL text, passed to `execute()` unchanged.

    Returns:
        Whatever `execute()` returns for the same query.

    For example:

    >>> tmp = getfixture('tmp_path')
    >>> import pyarrow as pa
    >>> from deltalake import DeltaTable, QueryBuilder
    >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
    >>> qb = QueryBuilder().register('test', dt)
    >>> query = 'SELECT * FROM test'
    >>> assert qb.execute(query) == qb.sql(query)
    """
    batches = self.execute(sql)
    return batches
30 changes: 25 additions & 5 deletions python/tests/test_table_read.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import tempfile
from datetime import date, datetime, timezone
Expand Down Expand Up @@ -958,11 +959,14 @@ def test_read_query_builder():
"month": ["4", "12", "12", "12"],
"day": ["5", "4", "20", "20"],
}
actual = pa.Table.from_batches(
QueryBuilder()
.register("tbl", dt)
.execute("SELECT * FROM tbl WHERE year >= 2021 ORDER BY value")
).to_pydict()

qb = QueryBuilder().register("tbl", dt)
query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value"

actual = pa.Table.from_batches(qb.execute(query)).to_pydict()
assert expected == actual

actual = pa.Table.from_batches(qb.sql(query)).to_pydict()
assert expected == actual


Expand Down Expand Up @@ -1000,3 +1004,19 @@ def test_read_query_builder_join_multiple_tables(tmp_path):
)
).to_pydict()
assert expected == actual


def test_read_query_builder_show_output(capsys, caplog):
    """Check that `show()` makes `execute()` print results, or log when there are none."""
    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
    dt = DeltaTable(table_path)
    # Use caplog.set_level rather than logging.getLogger(...).setLevel: pytest
    # restores the logger's previous level at teardown, so the INFO setting
    # does not leak into tests that run afterwards.
    caplog.set_level(logging.INFO, logger="deltalake")

    # show() stores the print flag on the builder, so enabling it once is enough.
    qb = QueryBuilder().register("tbl", dt).show()

    # A query that returns rows prints the result table to stdout.
    query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value"
    qb.execute(query)
    assert capsys.readouterr().out.strip() != ""

    # A query that returns no rows logs a message and prints nothing.
    query = "SELECT * FROM tbl WHERE year >= 9999"
    qb.execute(query)
    assert "query contains no records" in caplog.text
    assert capsys.readouterr().out.strip() == ""

0 comments on commit ac92f90

Please sign in to comment.