From a420c3307c8f94be148ddfc62d42ea53e0a42e5a Mon Sep 17 00:00:00 2001 From: Christian Molina Date: Sat, 4 Jan 2025 02:07:44 +0800 Subject: [PATCH] Added sql() and show() convenience method to QueryBuilder Signed-off-by: Christian Molina --- python/deltalake/query.py | 64 +++++++++++++++++++++++++++++++-- python/tests/test_table_read.py | 31 +++++++++++++--- 2 files changed, 87 insertions(+), 8 deletions(-) diff --git a/python/deltalake/query.py b/python/deltalake/query.py index b3e4100d8c..8e1b2b5963 100644 --- a/python/deltalake/query.py +++ b/python/deltalake/query.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import warnings from typing import List @@ -9,6 +10,8 @@ from deltalake.table import DeltaTable from deltalake.warnings import ExperimentalWarning +logger = logging.getLogger(__name__) + class QueryBuilder: """ @@ -47,9 +50,9 @@ def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder: ) return self - def execute(self, sql: str) -> List[pyarrow.RecordBatch]: + def execute(self, sql: str) -> QueryResult: """ - Execute the query and return a list of record batches + Prepares the sql query to be executed. For example: @@ -59,6 +62,61 @@ def execute(self, sql: str) -> List[pyarrow.RecordBatch]: >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) >>> qb = QueryBuilder().register('test', dt) >>> results = qb.execute('SELECT * FROM test') + >>> assert isinstance(results, QueryResult) + """ + return QueryResult(self._query_builder, sql) + + def sql(self, sql: str) -> QueryResult: + """ + Convenience method for `execute()` method. + + For example: + + >>> tmp = getfixture('tmp_path') + >>> import pyarrow as pa + >>> from deltalake import DeltaTable, QueryBuilder + >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) + >>> qb = QueryBuilder().register('test', dt) + >>> query = 'SELECT * FROM test' + >>> assert qb.execute(query).fetchall() == qb.sql(query).fetchall() + """ + return self.execute(sql) + + +class QueryResult: + def __init__(self, query_builder: PyQueryBuilder, sql: str): + self._query_builder = query_builder + self._sql_query = sql + + def show(self) -> None: + """ + Execute the query and prints the output in the console. + + For example: + + >>> tmp = getfixture('tmp_path') + >>> import pyarrow as pa + >>> from deltalake import DeltaTable, QueryBuilder + >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) + >>> qb = QueryBuilder().register('test', dt) + >>> results = qb.execute('SELECT * FROM test').show() + """ + records = self.fetchall() + if len(records) > 0: + print(pyarrow.Table.from_batches(records)) + else: + logger.info("The executed query contains no records.") + + def fetchall(self) -> List[pyarrow.RecordBatch]: + """ + Execute the query and return a list of record batches. + + >>> tmp = getfixture('tmp_path') + >>> import pyarrow as pa + >>> from deltalake import DeltaTable, QueryBuilder + >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) + >>> qb = QueryBuilder().register('test', dt) + >>> results = qb.execute('SELECT * FROM test').fetchall() >>> assert results is not None """ - return self._query_builder.execute(sql) + return self._query_builder.execute(self._sql_query) diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 30d7f21d7f..5fe0e413b3 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -1,3 +1,4 @@ +import logging import os import tempfile from datetime import date, datetime, timezone @@ -958,11 +959,14 @@ def test_read_query_builder(): "month": ["4", "12", "12", "12"], "day": ["5", "4", "20", "20"], } - actual = pa.Table.from_batches( - QueryBuilder() - .register("tbl", dt) - .execute("SELECT * FROM tbl WHERE year >= 2021 ORDER BY value") - ).to_pydict() + + qb = QueryBuilder().register("tbl", dt) + query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value" + + actual = pa.Table.from_batches(qb.execute(query).fetchall()).to_pydict() + assert expected == actual + + actual = pa.Table.from_batches(qb.sql(query).fetchall()).to_pydict() assert expected == actual @@ -998,5 +1002,22 @@ def test_read_query_builder_join_multiple_tables(tmp_path): ORDER BY tbl1.date """ ) + .fetchall() ).to_pydict() assert expected == actual + + +def test_read_query_builder_show_output(capsys, caplog): + table_path = "../crates/test/tests/data/delta-0.8.0-partitioned" + dt = DeltaTable(table_path) + logging.getLogger("deltalake").setLevel(logging.INFO) + + qb = QueryBuilder().register("tbl", dt) + query = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value" + qb.execute(query).show() + assert capsys.readouterr().out.strip() != "" + + query = "SELECT * FROM tbl WHERE year >= 9999" + qb.execute(query).show() + assert "query contains no records" in caplog.text + assert capsys.readouterr().out.strip() == ""