From 51feb765776a3753bf1cfbc54001902aeeb261fc Mon Sep 17 00:00:00 2001 From: Christian Molina Date: Sat, 4 Jan 2025 16:35:45 +0800 Subject: [PATCH] Refactored implementation of show() and sql() Signed-off-by: Christian Molina --- python/deltalake/query.py | 58 ++++++++++++++++++++------------- python/tests/test_table_read.py | 10 +++--- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/python/deltalake/query.py b/python/deltalake/query.py index f99701ab88..050573d025 100644 --- a/python/deltalake/query.py +++ b/python/deltalake/query.py @@ -28,7 +28,6 @@ def __init__(self) -> None: category=ExperimentalWarning, ) self._query_builder = PyQueryBuilder() - self._print_output = False def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder: """ @@ -51,9 +50,9 @@ def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder: ) return self - def execute(self, sql: str) -> List[pyarrow.RecordBatch]: + def execute(self, sql: str) -> QueryResult: """ - Execute the query and return a list of record batches + Prepares the sql query to be executed. For example: @@ -63,20 +62,14 @@ def execute(self, sql: str) -> List[pyarrow.RecordBatch]: >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) >>> qb = QueryBuilder().register('test', dt) >>> results = qb.execute('SELECT * FROM test') - >>> assert results is not None + >>> assert isinstance(results, QueryResult) """ - records = self._query_builder.execute(sql) - if self._print_output: - if len(records) > 0: - print(pyarrow.Table.from_batches(records)) - else: - logger.info("The executed query contains no records.") + return QueryResult(self._query_builder, sql) - return records - def show(self, print_output: bool = True) -> QueryBuilder: + def sql(self, sql: str) -> QueryResult: """ - Controls whether succeeding query outputs would be printed in the console. + Convenience method for `execute()` method. For example: @@ -85,15 +78,20 @@ def show(self, print_output: bool = True) -> QueryBuilder: >>> from deltalake import DeltaTable, QueryBuilder >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) >>> qb = QueryBuilder().register('test', dt) - >>> results = qb.show().execute('SELECT * FROM test') - + >>> query = 'SELECT * FROM test' + >>> assert qb.execute(query).fetchall() == qb.sql(query).fetchall() """ - self._print_output = print_output - return self + return self.execute(sql) - def sql(self, sql: str) -> List[pyarrow.RecordBatch]: + +class QueryResult: + def __init__(self, query_builder: PyQueryBuilder, sql: str): + self._query_builder = query_builder + self._sql_query = sql + + def show(self): """ - Convenience method for `execute()` method + Execute the query and prints the output in the console. For example: @@ -102,8 +100,24 @@ def sql(self, sql: str) -> List[pyarrow.RecordBatch]: >>> from deltalake import DeltaTable, QueryBuilder >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) >>> qb = QueryBuilder().register('test', dt) - >>> query = 'SELECT * FROM test' - >>> assert qb.execute(query) == qb.sql(query) + >>> results = qb.execute('SELECT * FROM test').show() + """ + records = self.fetchall() + if len(records) > 0: + print(pyarrow.Table.from_batches(records)) + else: + logger.info("The executed query contains no records.") + def fetchall(self) -> List[pyarrow.RecordBatch]: """ - return self.execute(sql) + Execute the query and return a list of record batches. + + >>> tmp = getfixture('tmp_path') + >>> import pyarrow as pa + >>> from deltalake import DeltaTable, QueryBuilder + >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) + >>> qb = QueryBuilder().register('test', dt) + >>> results = qb.execute('SELECT * FROM test').fetchall() + >>> assert results is not None + """ + return self._query_builder.execute(self._sql_query) diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 489e942cde..0a806110d4 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -963,10 +963,10 @@ def test_read_query_builder(): qb = QueryBuilder().register("tbl", dt) query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value" - actual = pa.Table.from_batches(qb.execute(query)).to_pydict() + actual = pa.Table.from_batches(qb.execute(query).fetchall()).to_pydict() assert expected == actual - actual = pa.Table.from_batches(qb.sql(query)).to_pydict() + actual = pa.Table.from_batches(qb.sql(query).fetchall()).to_pydict() assert expected == actual @@ -1001,7 +1001,7 @@ def test_read_query_builder_join_multiple_tables(tmp_path): INNER JOIN tbl2 ON tbl1.date = tbl2.date ORDER BY tbl1.date """ - ) + ).fetchall() ).to_pydict() assert expected == actual @@ -1013,10 +1013,10 @@ def test_read_query_builder_show_output(capsys, caplog): qb = QueryBuilder().register("tbl", dt) query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value" - qb.show().execute(query) + qb.execute(query).show() assert capsys.readouterr().out.strip() != "" query = "SELECT * FROM tbl WHERE year >= 9999" - qb.show().execute(query) + qb.execute(query).show() assert "query contains no records" in caplog.text assert capsys.readouterr().out.strip() == ""