Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ PECO-2065 ] Create the async execution flow for the PySQL Connector #463

Merged
merged 9 commits into from
Nov 26, 2024
44 changes: 43 additions & 1 deletion src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def execute(
self,
operation: str,
parameters: Optional[TParameterCollection] = None,
async_op=False,
) -> "Cursor":
"""
Execute a query and wait for execution to complete.
jprakash-db marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -796,13 +797,15 @@ def execute(
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=async_op,
)
self.active_result_set = ResultSet(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result set is not ready yet when async_op is True, so why do you set this? It should be set in the get_execution_result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result set that is returned over here is empty and does not have any data.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, but this will make the code confusing and I do not think it is necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this to keep the same logical flow for both execute_async and execute. In execute the active_result_set has data, while in execute_async it is None since there is no data yet. Once data is available the active_result_set will have data again, so logically I felt it made sense.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return result comes from the value returned by execute_command in the sync flow, which means it is not ready until the sync call completes. This is why I said it is confusing: the result should be set upon completion of the async operation, which is standard practice for most async code (on_complete = (result) => { setResult(result) }).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it what Jacky meant here. This is about confusing the users regarding the interface. We can keep the logic internally same, but don't need to keep the interface same for async and sync. For async what matters is the operationHandle. We can have different interface for both, but internally can reuse the code if possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gopalldb @jackyhu-db I have changed the code, based on these suggestions.

self.connection,
execute_response,
self.thrift_backend,
self.buffer_size_bytes,
self.arraysize,
async_op,
)

if execute_response.is_staging_operation:
Expand All @@ -812,6 +815,44 @@ def execute(

return self

def execute_async(
    self,
    operation: str,
    parameters: Optional[TParameterCollection] = None,
) -> "Cursor":
    """
    Start executing a query without waiting for it to complete.

    This delegates to execute() with the async flag set; the returned
    cursor does not yet hold a result set. Poll get_query_state() until
    the operation finishes, then call get_execution_result() to populate
    the cursor's results.

    :param operation: The SQL statement to execute.
    :param parameters: Optional parameters for the statement.
    :returns: this Cursor (results not yet available).
    """
    # Pass async_op by keyword rather than positionally so the delegation
    # stays correct even if execute() gains parameters before async_op.
    return self.execute(operation, parameters, async_op=True)

def get_query_state(self):
    """Return the server-side state of the cursor's active operation.

    Raises an error (via _check_not_closed) if the cursor is closed.
    """
    self._check_not_closed()
    handle = self.active_op_handle
    return self.thrift_backend.get_query_state(handle)

def get_execution_result(self):
    """Fetch the results of a previously started asynchronous execution.

    Must be called after the operation has reached FINISHED state (use
    get_query_state() to poll). Populates self.active_result_set and
    handles staging operations, mirroring the tail of execute().

    :returns: this Cursor, with the result set attached.
    :raises Error: if the operation is not in FINISHED state.
    """
    self._check_not_closed()

    operation_state = self.get_query_state()
    # Guard clause: results can only be fetched once the server reports
    # the operation as FINISHED.
    if operation_state != ttypes.TOperationState.FINISHED_STATE:
        raise Error(
            f"get_execution_result failed with Operation status {operation_state}"
        )

    execute_response = self.thrift_backend.get_execution_result(
        self.active_op_handle, self
    )
    self.active_result_set = ResultSet(
        self.connection,
        execute_response,
        self.thrift_backend,
        self.buffer_size_bytes,
        self.arraysize,
    )

    if execute_response.is_staging_operation:
        self._handle_staging_operation(
            staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
        )

    return self

def executemany(self, operation, seq_of_parameters):
"""
Execute the operation once for every set of passed in parameters.
Expand Down Expand Up @@ -1097,6 +1138,7 @@ def __init__(
thrift_backend: ThriftBackend,
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
arraysize: int = 10000,
async_op=False,
):
"""
A ResultSet manages the results of a single command.
Expand All @@ -1119,7 +1161,7 @@ def __init__(
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
self._next_row_index = 0

if execute_response.arrow_queue:
if execute_response.arrow_queue or async_op:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why async_op depends on arrow? What if pyarrow is not installed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not depend on arrow, currently in our codebase all data is named as arrow_queue be it arrow queue or column queue. So in this statement I am checking if data is already present or if it is an async operation don't do anything.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is ResultSet should not couple with async_op as I do not see it has something related to async. If you want to force to use arrow_queue, please use force_arrow_queue or similar parameter instead of async_op in the constructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed this

# In this case the server has taken the fast path and returned an initial batch of
# results
self.results = execute_response.arrow_queue
Expand Down
87 changes: 86 additions & 1 deletion src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,63 @@ def _results_message_to_execute_response(self, resp, operation_state):
arrow_schema_bytes=schema_bytes,
)

def get_execution_result(self, op_handle, cursor):
    """Fetch the first batch of results for a completed (async) operation.

    Issues a FetchResults RPC for the given operation handle and converts
    the response into an ExecuteResponse, building the result queue from
    the returned row set and metadata.

    :param op_handle: Thrift operation handle of the completed statement.
    :param cursor: the Cursor whose arraysize / buffer_size_bytes bound
        the fetch.
    :returns: ExecuteResponse carrying the result queue and metadata.
    """
    # The caller must have started an operation (e.g. via execute_async).
    assert op_handle is not None

    req = ttypes.TFetchResultsReq(
        operationHandle=ttypes.TOperationHandle(
            op_handle.operationId,
            op_handle.operationType,
            False,
            op_handle.modifiedRowCount,
        ),
        maxRows=cursor.arraysize,
        maxBytes=cursor.buffer_size_bytes,
        orientation=ttypes.TFetchOrientation.FETCH_NEXT,
        includeResultSetMetadata=True,
    )

    # NOTE(review): the operation state is checked by the caller in
    # client.py before this is invoked; no extra status check here.
    resp = self.make_request(self._client.FetchResults, req)

    t_result_set_metadata_resp = resp.resultSetMetadata

    lz4_compressed = t_result_set_metadata_resp.lz4Compressed
    is_staging_operation = t_result_set_metadata_resp.isStagingOperation
    has_more_rows = resp.hasMoreRows
    description = self._hive_schema_to_description(
        t_result_set_metadata_resp.schema
    )

    # Prefer the server-provided Arrow schema; otherwise derive one from
    # the Hive schema.
    schema_bytes = (
        t_result_set_metadata_resp.arrowSchema
        or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
        .serialize()
        .to_pybytes()
    )

    queue = ResultSetQueueFactory.build_queue(
        # Consistency: use the already-bound metadata object instead of
        # reaching back through resp.resultSetMetadata.
        row_set_type=t_result_set_metadata_resp.resultFormat,
        t_row_set=resp.results,
        arrow_schema_bytes=schema_bytes,
        max_download_threads=self.max_download_threads,
        lz4_compressed=lz4_compressed,
        description=description,
        ssl_options=self._ssl_options,
    )

    return ExecuteResponse(
        arrow_queue=queue,
        status=resp.status,
        has_been_closed_server_side=False,
        has_more_rows=has_more_rows,
        lz4_compressed=lz4_compressed,
        is_staging_operation=is_staging_operation,
        command_handle=op_handle,
        description=description,
        arrow_schema_bytes=schema_bytes,
    )

def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
if initial_operation_status_resp:
self._check_command_not_in_error_or_closed_state(
Expand All @@ -787,6 +844,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
return operation_state

def get_query_state(self, op_handle):
    """Poll the server once and return the operation's current state.

    Raises (via the error/closed-state check) if the operation has
    failed or been closed.
    """
    poll_resp = self._poll_for_status(op_handle)
    state = poll_resp.operationState
    self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
    return state

@staticmethod
def _check_direct_results_for_error(t_spark_direct_results):
if t_spark_direct_results:
Expand Down Expand Up @@ -817,6 +880,7 @@ def execute_command(
cursor,
use_cloud_fetch=True,
parameters=[],
async_op=False,
):
assert session_handle is not None

Expand Down Expand Up @@ -846,7 +910,11 @@ def execute_command(
parameters=parameters,
)
resp = self.make_request(self._client.ExecuteStatement, req)
return self._handle_execute_response(resp, cursor)

if async_op:
return self._handle_execute_response_async(resp, cursor)
else:
return self._handle_execute_response(resp, cursor)

def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
assert session_handle is not None
Expand Down Expand Up @@ -945,6 +1013,23 @@ def _handle_execute_response(self, resp, cursor):

return self._results_message_to_execute_response(resp, final_operation_state)

def _handle_execute_response_async(self, resp, cursor):
    """Record the operation handle on the cursor and build a placeholder
    ExecuteResponse for a statement started asynchronously.

    No result data is available yet, so every result-related field is
    None; only the status and the operation handle are meaningful.
    """
    op_handle = resp.operationHandle
    cursor.active_op_handle = op_handle
    self._check_direct_results_for_error(resp.directResults)

    return ExecuteResponse(
        arrow_queue=None,
        status=resp.status.statusCode,
        has_been_closed_server_side=None,
        has_more_rows=None,
        lz4_compressed=None,
        is_staging_operation=None,
        command_handle=op_handle,
        description=None,
        arrow_schema_bytes=None,
    )

def fetch_results(
self,
op_handle,
Expand Down
21 changes: 21 additions & 0 deletions tests/e2e/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
compare_dbr_versions,
is_thrift_v5_plus,
)
from databricks.sql.thrift_api.TCLIService import ttypes
from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
from tests.e2e.common.large_queries_mixin import LargeQueriesMixin
from tests.e2e.common.timestamp_tests import TimestampTestsMixin
Expand Down Expand Up @@ -175,6 +176,26 @@ def test_cloud_fetch(self):
for i in range(len(cf_result)):
assert cf_result[i] == noop_result[i]

def test_execute_async(self):
    """End-to-end check of the async execution flow: start a long-running
    query, poll until it leaves a running/pending state, then fetch and
    verify the result."""
    import time  # local import: only used for the polling back-off below

    def is_executing(operation_state):
        # No state yet (None) or RUNNING/PENDING means still executing.
        return not operation_state or operation_state in (
            ttypes.TOperationState.RUNNING_STATE,
            ttypes.TOperationState.PENDING_STATE,
        )

    long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'"
    with self.cursor() as cursor:
        cursor.execute_async(long_running_query)

        # Poll with a short sleep instead of busy-waiting, so the test
        # does not hammer the server with back-to-back status requests.
        while is_executing(cursor.get_query_state()):
            log.info("Polling the status in test_execute_async")
            time.sleep(1)

        cursor.get_execution_result()
        result = cursor.fetchall()

        assert result[0].asDict() == {"count(1)": 0}


# Exclude Retry tests because they require specific setups, and LargeQueries too slow for core
# tests
Expand Down
Loading