Skip to content

Commit

Permalink
Refactored implementation of show() and sql()
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Molina <[email protected]>
  • Loading branch information
DevChrisCross committed Jan 4, 2025
1 parent c287bf3 commit 51feb76
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
58 changes: 36 additions & 22 deletions python/deltalake/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(self) -> None:
category=ExperimentalWarning,
)
self._query_builder = PyQueryBuilder()
self._print_output = False

def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
"""
Expand All @@ -51,9 +50,9 @@ def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
)
return self

def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
def execute(self, sql: str) -> QueryResult:
"""
Execute the query and return a list of record batches
Prepares the sql query to be executed.
For example:
Expand All @@ -63,20 +62,14 @@ def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
>>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
>>> qb = QueryBuilder().register('test', dt)
>>> results = qb.execute('SELECT * FROM test')
>>> assert results is not None
>>> assert isinstance(results, QueryResult)
"""
records = self._query_builder.execute(sql)
if self._print_output:
if len(records) > 0:
print(pyarrow.Table.from_batches(records))
else:
logger.info("The executed query contains no records.")
return QueryResult(self._query_builder, sql)

return records

def show(self, print_output: bool = True) -> QueryBuilder:
def sql(self, sql: str) -> QueryResult:
"""
Controls whether succeeding query outputs would be printed in the console.
Convenience method for `execute()` method.
For example:
Expand All @@ -85,15 +78,20 @@ def show(self, print_output: bool = True) -> QueryBuilder:
>>> from deltalake import DeltaTable, QueryBuilder
>>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
>>> qb = QueryBuilder().register('test', dt)
>>> results = qb.show().execute('SELECT * FROM test')
>>> query = 'SELECT * FROM test'
>>> assert qb.execute(query).fetchall() == qb.sql(query).fetchall()
"""
self._print_output = print_output
return self
return self.execute(sql)

def sql(self, sql: str) -> List[pyarrow.RecordBatch]:

class QueryResult:
def __init__(self, query_builder: PyQueryBuilder, sql: str):
self._query_builder = query_builder
self._sql_query = sql

def show(self):
"""
Convenience method for `execute()` method
Execute the query and prints the output in the console.
For example:
Expand All @@ -102,8 +100,24 @@ def sql(self, sql: str) -> List[pyarrow.RecordBatch]:
>>> from deltalake import DeltaTable, QueryBuilder
>>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
>>> qb = QueryBuilder().register('test', dt)
>>> query = 'SELECT * FROM test'
>>> assert qb.execute(query) == qb.sql(query)
>>> results = qb.execute('SELECT * FROM test').show()
"""
records = self.fetchall()
if len(records) > 0:
print(pyarrow.Table.from_batches(records))
else:
logger.info("The executed query contains no records.")

def fetchall(self) -> List[pyarrow.RecordBatch]:
"""
return self.execute(sql)
Execute the query and return a list of record batches.
>>> tmp = getfixture('tmp_path')
>>> import pyarrow as pa
>>> from deltalake import DeltaTable, QueryBuilder
>>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())]))
>>> qb = QueryBuilder().register('test', dt)
>>> results = qb.execute('SELECT * FROM test').fetchall()
>>> assert results is not None
"""
return self._query_builder.execute(self._sql_query)
10 changes: 5 additions & 5 deletions python/tests/test_table_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,10 +963,10 @@ def test_read_query_builder():
qb = QueryBuilder().register("tbl", dt)
query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value"

actual = pa.Table.from_batches(qb.execute(query)).to_pydict()
actual = pa.Table.from_batches(qb.execute(query).fetchall()).to_pydict()
assert expected == actual

actual = pa.Table.from_batches(qb.sql(query)).to_pydict()
actual = pa.Table.from_batches(qb.sql(query).fetchall()).to_pydict()
assert expected == actual


Expand Down Expand Up @@ -1001,7 +1001,7 @@ def test_read_query_builder_join_multiple_tables(tmp_path):
INNER JOIN tbl2 ON tbl1.date = tbl2.date
ORDER BY tbl1.date
"""
)
).fetchall()
).to_pydict()
assert expected == actual

Expand All @@ -1013,10 +1013,10 @@ def test_read_query_builder_show_output(capsys, caplog):

qb = QueryBuilder().register("tbl", dt)
query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value"
qb.show().execute(query)
qb.execute(query).show()
assert capsys.readouterr().out.strip() != ""

query = "SELECT * FROM tbl WHERE year >= 9999"
qb.show().execute(query)
qb.execute(query).show()
assert "query contains no records" in caplog.text
assert capsys.readouterr().out.strip() == ""

0 comments on commit 51feb76

Please sign in to comment.