diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a8f04a05a..75d2c665c 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -158,6 +158,7 @@ def __init__( ) self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -694,7 +695,7 @@ def get_catalogs( max_bytes=max_bytes, lz4_compression=False, cursor=cursor, - use_cloud_fetch=False, + use_cloud_fetch=self.use_cloud_fetch, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, @@ -727,7 +728,7 @@ def get_schemas( max_bytes=max_bytes, lz4_compression=False, cursor=cursor, - use_cloud_fetch=False, + use_cloud_fetch=self.use_cloud_fetch, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, @@ -768,7 +769,7 @@ def get_tables( max_bytes=max_bytes, lz4_compression=False, cursor=cursor, - use_cloud_fetch=False, + use_cloud_fetch=self.use_cloud_fetch, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, @@ -815,7 +816,7 @@ def get_columns( max_bytes=max_bytes, lz4_compression=False, cursor=cursor, - use_cloud_fetch=False, + use_cloud_fetch=self.use_cloud_fetch, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 0bdb23b03..dd119264a 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -6,12 +6,12 @@ from __future__ import annotations +import io import logging from typing import ( List, Optional, Any, - Callable, cast, TYPE_CHECKING, ) @@ -20,6 +20,16 @@ from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.sea.models.base import ResultData +from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.utils import CloudFetchQueue, ArrowQueue + +try: + import pyarrow + import pyarrow.compute as pc +except ImportError: + pyarrow = None + pc = None logger = logging.getLogger(__name__) @@ -30,32 +40,18 @@ class ResultSetFilter: """ @staticmethod - def _filter_sea_result_set( - result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool] - ) -> SeaResultSet: + def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse: """ - Filter a SEA result set using the provided filter function. + Create an ExecuteResponse with parameters from the original result set. 
Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included + result_set: Original result set to copy parameters from Returns: - A filtered SEA result set + ExecuteResponse: New execute response object """ - - # Get all remaining rows - all_rows = result_set.results.remaining_rows() - - # Filter rows - filtered_rows = [row for row in all_rows if filter_func(row)] - - # Reuse the command_id from the original result set - command_id = result_set.command_id - - # Create an ExecuteResponse for the filtered data - execute_response = ExecuteResponse( - command_id=command_id, + return ExecuteResponse( + command_id=result_set.command_id, status=result_set.status, description=result_set.description, has_been_closed_server_side=result_set.has_been_closed_server_side, @@ -64,32 +60,145 @@ def _filter_sea_result_set( is_staging_operation=False, ) - # Create a new ResultData object with filtered data - from databricks.sql.backend.sea.models.base import ResultData + @staticmethod + def _update_manifest(result_set: SeaResultSet, new_row_count: int): + """ + Create a copy of the manifest with updated row count. + + Args: + result_set: Original result set to copy manifest from + new_row_count: New total row count for filtered data - result_data = ResultData(data=filtered_rows, external_links=None) + Returns: + Updated manifest copy + """ + filtered_manifest = result_set.manifest + filtered_manifest.total_row_count = new_row_count + return filtered_manifest - from databricks.sql.backend.sea.backend import SeaDatabricksClient + @staticmethod + def _create_filtered_result_set( + result_set: SeaResultSet, + result_data: ResultData, + row_count: int, + ) -> "SeaResultSet": + """ + Create a new filtered SeaResultSet with the provided data. + + Args: + result_set: Original result set to copy parameters from + result_data: New result data for the filtered set + row_count: Number of rows in the filtered data + + Returns: + New filtered SeaResultSet + """ from databricks.sql.backend.sea.result_set import SeaResultSet - # Create a new SeaResultSet with the filtered data - manifest = result_set.manifest - manifest.total_row_count = len(filtered_rows) + execute_response = ResultSetFilter._create_execute_response(result_set) + filtered_manifest = ResultSetFilter._update_manifest(result_set, row_count) - filtered_result_set = SeaResultSet( + return SeaResultSet( connection=result_set.connection, execute_response=execute_response, sea_client=cast(SeaDatabricksClient, result_set.backend), result_data=result_data, - manifest=manifest, + manifest=filtered_manifest, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, ) - return filtered_result_set + @staticmethod + def _filter_arrow_table( + table: Any, # pyarrow.Table + column_name: str, + allowed_values: List[str], + case_sensitive: bool = True, + ) -> Any: # returns pyarrow.Table + """ + Filter a PyArrow table by column values. 
+ + Args: + table: The PyArrow table to filter + column_name: The name of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered PyArrow table + """ + if not pyarrow: + raise ImportError("PyArrow is required for Arrow table filtering") + + if table.num_rows == 0: + return table + + # Handle case-insensitive filtering by normalizing both column and allowed values + if not case_sensitive: + # Convert allowed values to uppercase + allowed_values = [v.upper() for v in allowed_values] + # Get column values as uppercase + column = pc.utf8_upper(table[column_name]) + else: + # Use column as-is + column = table[column_name] + + # Convert allowed_values to PyArrow Array + allowed_array = pyarrow.array(allowed_values) + + # Construct a boolean mask: True where column is in allowed_list + mask = pc.is_in(column, value_set=allowed_array) + return table.filter(mask) + + @staticmethod + def _filter_arrow_result_set( + result_set: SeaResultSet, + column_index: int, + allowed_values: List[str], + case_sensitive: bool = True, + ) -> SeaResultSet: + """ + Filter a SEA result set that contains Arrow tables. + + Args: + result_set: The SEA result set to filter (containing Arrow data) + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered SEA result set + """ + # Validate column index and get column name + if column_index >= len(result_set.description): + raise ValueError(f"Column index {column_index} is out of bounds") + column_name = result_set.description[column_index][0] + + # Get all remaining rows as Arrow table and filter it + arrow_table = result_set.results.remaining_rows() + filtered_table = ResultSetFilter._filter_arrow_table( + arrow_table, column_name, allowed_values, case_sensitive + ) + + # Convert the filtered table to Arrow stream format for ResultData + sink = io.BytesIO() + with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer: + writer.write_table(filtered_table) + arrow_stream_bytes = sink.getvalue() + + # Create ResultData with attachment containing the filtered data + result_data = ResultData( + data=None, # No JSON data + external_links=None, # No external links + attachment=arrow_stream_bytes, # Arrow data as attachment + ) + + return ResultSetFilter._create_filtered_result_set( + result_set, result_data, filtered_table.num_rows + ) @staticmethod - def filter_by_column_values( + def _filter_json_result_set( result_set: SeaResultSet, column_index: int, allowed_values: List[str], @@ -107,22 +216,35 @@ def filter_by_column_values( Returns: A filtered result set """ + # Validate column index (optional - not in arrow version but good practice) + if column_index >= len(result_set.description): + raise ValueError(f"Column index {column_index} is out of bounds") - # Convert to uppercase for case-insensitive comparison if needed + # Extract rows + all_rows = result_set.results.remaining_rows() + + # Convert allowed values if case-insensitive if not case_sensitive: allowed_values = [v.upper() for v in allowed_values] + # Helper lambda to get column value based on case sensitivity + get_column_value = ( + lambda row: row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + + # Filter rows based on allowed values + filtered_rows = [ + row + for row in all_rows + if len(row) > column_index and 
get_column_value(row) in allowed_values + ] + + # Create filtered result set + result_data = ResultData(data=filtered_rows, external_links=None) - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), + return ResultSetFilter._create_filtered_result_set( + result_set, result_data, len(filtered_rows) ) @staticmethod @@ -143,14 +265,25 @@ def filter_tables_by_type( Returns: A filtered result set containing only tables of the specified types """ - # Default table types if none specified DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) + valid_types = table_types if table_types else DEFAULT_TABLE_TYPES + # Check if we have an Arrow table (cloud fetch) or JSON data # Table type is the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=True - ) + if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)): + # For Arrow tables, we need to handle filtering differently + return ResultSetFilter._filter_arrow_result_set( + result_set, + column_index=5, + allowed_values=valid_types, + case_sensitive=True, + ) + else: + # For JSON data, use the existing filter method + return ResultSetFilter._filter_json_result_set( + result_set, + column_index=5, + allowed_values=valid_types, + case_sensitive=True, + ) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 13dfac006..4efe51f3e 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -68,7 +68,7 @@ def setUp(self): self.mock_sea_result_set.has_been_closed_server_side = False self.mock_sea_result_set._arrow_schema_bytes = None - def test_filter_by_column_values(self): + def test__filter_json_result_set(self): """Test filtering by column values with various options.""" # Case 1: Case-sensitive filtering allowed_values = ["table1", "table3"] @@ -82,8 +82,8 @@ def test_filter_by_column_values(self): mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance - # Call filter_by_column_values on the table_name column (index 2) - result = ResultSetFilter.filter_by_column_values( + # Call _filter_json_result_set on the table_name column (index 2) + result = ResultSetFilter._filter_json_result_set( self.mock_sea_result_set, 2, allowed_values, case_sensitive=True ) @@ -109,8 +109,8 @@ def test_filter_by_column_values(self): mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance - # Call filter_by_column_values with case-insensitive matching - result = ResultSetFilter.filter_by_column_values( + # Call _filter_json_result_set with case-insensitive matching + result = ResultSetFilter._filter_json_result_set( self.mock_sea_result_set, 2, ["TABLE1", "TABLE3"], @@ -123,37 +123,34 @@ def test_filter_tables_by_type(self): # Case 1: Specific table types table_types = ["TABLE", "VIEW"] - with patch( - "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True - ): - with patch.object( - ResultSetFilter, "filter_by_column_values" - ) as mock_filter: - ResultSetFilter.filter_tables_by_type( - self.mock_sea_result_set, table_types - ) - args, kwargs = mock_filter.call_args - self.assertEqual(args[0], self.mock_sea_result_set) - self.assertEqual(args[1], 5) # Table type column index - self.assertEqual(args[2], table_types) - 
self.assertEqual(kwargs.get("case_sensitive"), True) + # Mock results as JsonQueue (not CloudFetchQueue or ArrowQueue) + from databricks.sql.backend.sea.queue import JsonQueue + + self.mock_sea_result_set.results = JsonQueue([]) + + with patch.object(ResultSetFilter, "_filter_json_result_set") as mock_filter: + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, table_types) + args, kwargs = mock_filter.call_args + self.assertEqual(args[0], self.mock_sea_result_set) + self.assertEqual(kwargs.get("column_index"), 5) # Table type column index + self.assertEqual(kwargs.get("allowed_values"), table_types) + self.assertEqual(kwargs.get("case_sensitive"), True) # Case 2: Default table types (None or empty list) - with patch( - "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True - ): - with patch.object( - ResultSetFilter, "filter_by_column_values" - ) as mock_filter: - # Test with None - ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None) - args, kwargs = mock_filter.call_args - self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) - - # Test with empty list - ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, []) - args, kwargs = mock_filter.call_args - self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) + with patch.object(ResultSetFilter, "_filter_json_result_set") as mock_filter: + # Test with None + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None) + args, kwargs = mock_filter.call_args + self.assertEqual( + kwargs.get("allowed_values"), ["TABLE", "VIEW", "SYSTEM TABLE"] + ) + + # Test with empty list + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, []) + args, kwargs = mock_filter.call_args + self.assertEqual( + kwargs.get("allowed_values"), ["TABLE", "VIEW", "SYSTEM TABLE"] + ) if __name__ == "__main__": diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 396ad906f..f604f2874 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -56,6 +56,29 @@ def sea_client(self, mock_http_client): http_headers=http_headers, auth_provider=auth_provider, ssl_options=ssl_options, + use_cloud_fetch=False, + ) + + return client + + @pytest.fixture + def sea_client_cloud_fetch(self, mock_http_client): + """Create a SeaDatabricksClient instance with cloud fetch enabled.""" + server_hostname = "test-server.databricks.com" + port = 443 + http_path = "/sql/warehouses/abc123" + http_headers = [("header1", "value1"), ("header2", "value2")] + auth_provider = AuthProvider() + ssl_options = SSLOptions() + + client = SeaDatabricksClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + use_cloud_fetch=True, ) return client @@ -944,3 +967,74 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): cursor=mock_cursor, ) assert "Catalog name is required for get_columns" in str(excinfo.value) + + def test_get_tables_with_cloud_fetch( + self, sea_client_cloud_fetch, sea_session_id, mock_cursor + ): + """Test the get_tables method with cloud fetch enabled.""" + # Mock the execute_command method and ResultSetFilter + mock_result_set = Mock() + + with patch.object( + sea_client_cloud_fetch, "execute_command", return_value=mock_result_set + ) as mock_execute: + with patch( + "databricks.sql.backend.sea.utils.filters.ResultSetFilter" + ) as mock_filter: + mock_filter.filter_tables_by_type.return_value = mock_result_set + + # Call 
get_tables + result = sea_client_cloud_fetch.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify execute_command was called with use_cloud_fetch=True + mock_execute.assert_called_with( + operation="SHOW TABLES IN CATALOG test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=True, # Should use True since client was created with use_cloud_fetch=True + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result == mock_result_set + + def test_get_schemas_with_cloud_fetch( + self, sea_client_cloud_fetch, sea_session_id, mock_cursor + ): + """Test the get_schemas method with cloud fetch enabled.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client_cloud_fetch, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Test with catalog name + result = sea_client_cloud_fetch.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=True, # Should use True since client was created with use_cloud_fetch=True + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result == mock_result_set
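
Note on the new Arrow filtering path: the heart of this change is `_filter_arrow_table` (mask the table with `pyarrow.compute.is_in`) plus the IPC round-trip in `_filter_arrow_result_set` (serialize the filtered table into Arrow stream bytes and hand them to `ResultData.attachment`). The snippet below is a minimal, standalone sketch of that technique outside the driver, assuming only `pyarrow` is installed; the table contents and column names are invented for illustration, and only the `pyarrow` / `pyarrow.compute` calls mirror what the patch does.

    import io

    import pyarrow
    import pyarrow.compute as pc

    # Toy stand-in for a SHOW TABLES result. In the real result set the table
    # type lives at column index 5; two columns are enough to show the mechanics.
    table = pyarrow.table(
        {
            "tableName": ["t1", "t2", "t3"],
            "tableType": ["TABLE", "MATERIALIZED_VIEW", "view"],
        }
    )

    allowed_values = ["TABLE", "VIEW", "SYSTEM TABLE"]

    # Case-insensitive match: upper-case both the column and the allowed values,
    # build a boolean mask with pc.is_in, and apply it with Table.filter.
    column = pc.utf8_upper(table["tableType"])
    value_set = pyarrow.array([v.upper() for v in allowed_values])
    mask = pc.is_in(column, value_set=value_set)
    filtered = table.filter(mask)  # keeps t1 ("TABLE") and t3 ("view")

    # Serialize the filtered table as Arrow IPC stream bytes -- the same shape
    # the patch stores in ResultData.attachment -- then read it back to confirm.
    sink = io.BytesIO()
    with pyarrow.ipc.new_stream(sink, filtered.schema) as writer:
        writer.write_table(filtered)
    stream_bytes = sink.getvalue()

    roundtripped = pyarrow.ipc.open_stream(stream_bytes).read_all()
    assert roundtripped.num_rows == filtered.num_rows == 2

The design choice this illustrates: instead of materializing rows into Python lists (the JSON path), the cloud-fetch path keeps the data in Arrow end to end, so filtering is a vectorized mask and the filtered result can be re-wrapped as an attachment without a per-row conversion.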