Skip to content

SEA: Reduce network calls for synchronous commands #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 19, 2025
113 changes: 59 additions & 54 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

if TYPE_CHECKING:
from databricks.sql.client import Cursor
from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.types import (
Expand Down Expand Up @@ -332,7 +333,7 @@ def _extract_description_from_manifest(
return columns

def _results_message_to_execute_response(
self, response: GetStatementResponse
self, response: Union[ExecuteStatementResponse, GetStatementResponse]
) -> ExecuteResponse:
"""
Convert a SEA response to an ExecuteResponse and extract result data.
Expand Down Expand Up @@ -366,6 +367,27 @@ def _results_message_to_execute_response(

return execute_response

def _response_to_result_set(
self,
response: Union[ExecuteStatementResponse, GetStatementResponse],
cursor: Cursor,
) -> SeaResultSet:
"""
Convert a SEA response to a SeaResultSet.

Works on either an execute-statement or a get-statement response, so a
caller that already holds a response can build the result set from it
directly.

Args:
response: The SEA response to wrap. Its `result` and `manifest`
attributes are forwarded to the result set unchanged.
cursor: The cursor that supplies the connection, `buffer_size_bytes`
and `arraysize` for the result set.

Returns:
SeaResultSet: A result set bound to this client (`sea_client=self`).
"""

# Derive the ExecuteResponse from the same response object; the raw
# result data and manifest are passed separately below.
execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)

def _check_command_not_in_failed_or_closed_state(
self, state: CommandState, command_id: CommandId
) -> None:
Expand All @@ -386,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(

def _wait_until_command_done(
self, response: ExecuteStatementResponse
) -> CommandState:
) -> Union[ExecuteStatementResponse, GetStatementResponse]:
"""
Wait until a command is done.
"""

state = response.status.state
command_id = CommandId.from_sea_statement_id(response.statement_id)
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response

state = final_response.status.state
command_id = CommandId.from_sea_statement_id(final_response.statement_id)

while state in [CommandState.PENDING, CommandState.RUNNING]:
time.sleep(self.POLL_INTERVAL_SECONDS)
state = self.get_query_state(command_id)
final_response = self._poll_query(command_id)
state = final_response.status.state

self._check_command_not_in_failed_or_closed_state(state, command_id)

return state
return final_response

def execute_command(
self,
Expand Down Expand Up @@ -506,8 +531,11 @@ def execute_command(
if async_op:
return None

self._wait_until_command_done(response)
return self.get_execution_result(command_id, cursor)
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
if response.status.state != CommandState.SUCCEEDED:
final_response = self._wait_until_command_done(response)

return self._response_to_result_set(final_response, cursor)

def cancel_command(self, command_id: CommandId) -> None:
"""
Expand Down Expand Up @@ -559,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
data=request.to_dict(),
)

def get_query_state(self, command_id: CommandId) -> CommandState:
def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ValueError: If the command ID is invalid
Poll for the current command info.
"""

if command_id.backend_type != BackendType.SEA:
Expand All @@ -586,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)

# Parse the response
response = GetStatementResponse.from_dict(response_data)

return response

def get_query_state(self, command_id: CommandId) -> CommandState:
"""
Get the state of a running query.

Calls `_poll_query` once and returns only the state carried in the
response's status; the rest of the response is discarded.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
NOTE(review): the validation happens inside `_poll_query`; the
`get_execution_result` docstring documents ValueError for the same
check - confirm which exception type is actually raised.
"""

# All validation and the HTTP request are delegated to _poll_query.
response = self._poll_query(command_id)
return response.status.state

def get_execution_result(
Expand All @@ -610,38 +645,8 @@ def get_execution_result(
ValueError: If the command ID is invalid
"""

if command_id.backend_type != BackendType.SEA:
raise ValueError("Not a valid SEA command ID")

sea_statement_id = command_id.to_sea_statement_id()
if sea_statement_id is None:
raise ValueError("Not a valid SEA command ID")

# Create the request model
request = GetStatementRequest(statement_id=sea_statement_id)

# Get the statement result
response_data = self._http_client._make_request(
method="GET",
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = GetStatementResponse.from_dict(response_data)

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)
response = self._poll_query(command_id)
return self._response_to_result_set(response, cursor)

def get_chunk_links(
self, statement_id: str, chunk_index: int
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging

from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter

Expand All @@ -15,6 +14,7 @@

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.types import Row
from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
from databricks.sql.backend.types import ExecuteResponse
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_command_execution_sync(
mock_http_client._make_request.return_value = execute_response

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
result = sea_client.execute_command(
operation="SELECT 1",
Expand All @@ -242,9 +242,6 @@ def test_command_execution_sync(
enforce_embedded_schema_correctness=False,
)
assert result == "mock_result_set"
cmd_id_arg = mock_get_result.call_args[0][0]
assert isinstance(cmd_id_arg, CommandId)
assert cmd_id_arg.guid == "test-statement-123"

# Test with invalid session ID
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -332,7 +329,7 @@ def test_command_execution_advanced(
mock_http_client._make_request.side_effect = [initial_response, poll_response]

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
with patch("time.sleep"):
result = sea_client.execute_command(
Expand Down Expand Up @@ -360,7 +357,7 @@ def test_command_execution_advanced(
dbsql_param = IntegerParameter(name="param1", value=1)
param = dbsql_param.as_tspark_param(named=True)

with patch.object(sea_client, "get_execution_result"):
with patch.object(sea_client, "_response_to_result_set"):
sea_client.execute_command(
operation="SELECT * FROM table WHERE col = :param1",
session_id=sea_session_id,
Expand Down
Loading