Resolved merge conflicts
jprakash-db committed Dec 10, 2024
2 parents 87b1251 + 680b3b6 commit e09a880
Showing 10 changed files with 373 additions and 84 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,9 +1,10 @@
# Release History


# 4.0.0

- Split the connector into two separate packages: `databricks-sql-connector` and `databricks-sqlalchemy`. The `databricks-sql-connector` package contains the core functionality of the connector, while the `databricks-sqlalchemy` package contains the SQLAlchemy dialect for the connector.
- The PyArrow dependency is now optional in `databricks-sql-connector`. Users who need Arrow must install `pyarrow` explicitly

# 3.6.0 (2024-10-25)

1 change: 1 addition & 0 deletions docs/parameters.md
@@ -17,6 +17,7 @@ See `examples/parameters.py` in this repository for a runnable demo.

- A query executed with native parameters can contain at most 255 parameter markers
- The combined size of all parameterized values cannot exceed 1 MB
- Native parameters are not supported for volume operations such as PUT (see the sketch below)
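As a quick illustration within these limits, a minimal sketch of a native-parameter query (the connection details are placeholders):

from databricks import sql

with sql.connect(
    server_hostname="<workspace-host>",  # placeholder
    http_path="<http-path>",  # placeholder
    access_token="<token>",  # placeholder
) as connection:
    with connection.cursor() as cursor:
        # One named marker, far below the 255-marker and 1 MB limits above.
        cursor.execute("SELECT :x AS col", {"x": 42})
        print(cursor.fetchall())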

## SQL Syntax

143 changes: 96 additions & 47 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -18,8 +18,8 @@ lz4 = "^4.0.2"
requests = "^2.18.1"
oauthlib = "^3.1.0"
numpy = [
{ version = "^1.16.6", python = ">=3.8,<3.11" },
{ version = "^1.23.4", python = ">=3.11" },
{ version = ">=1.16.6", python = ">=3.8,<3.11" },
{ version = ">=1.23.4", python = ">=3.11" },
]
openpyxl = "^3.0.10"
urllib3 = ">=1.26"
26 changes: 16 additions & 10 deletions src/databricks/sql/auth/retry.py
@@ -1,4 +1,5 @@
import logging
import random
import time
import typing
from enum import Enum
@@ -285,25 +286,30 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
"""
retry_after = self.get_retry_after(response)
if retry_after:
    proposed_wait = retry_after
else:
    proposed_wait = self.get_backoff_time()

proposed_wait = min(proposed_wait, self.delay_max)
self.check_proposed_wait(proposed_wait)
time.sleep(proposed_wait)
return True

def get_backoff_time(self) -> float:
"""Calls urllib3's built-in get_backoff_time.
"""
This method implements the exponential backoff algorithm to calculate the delay between retries.
Never returns a value larger than self.delay_max
A MaxRetryDurationError will be raised if the calculated backoff would exceed self.max_attempts_duration
Note: within urllib3, a backoff is only calculated in cases where a Retry-After header is not present
in the previous unsuccessful request and `self.respect_retry_after_header` is True (which is always true)
:return:
"""

    current_attempt = self.stop_after_attempts_count - int(self.total or 0)
    proposed_backoff = (2**current_attempt) * self.delay_min
    if self.backoff_jitter != 0.0:
        proposed_backoff += random.random() * self.backoff_jitter

    proposed_backoff = min(proposed_backoff, self.delay_max)
    self.check_proposed_wait(proposed_backoff)

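For intuition, the schedule computed by `get_backoff_time` can be reproduced standalone. A minimal sketch with illustrative values (`delay_min=1`, `delay_max=30`, `jitter=0.5` are assumptions for the demo, not necessarily the connector's configured values):

import random

def backoff_delay(attempt: int, delay_min: float = 1.0,
                  delay_max: float = 30.0, jitter: float = 0.5) -> float:
    # Exponential growth from delay_min, additive jitter, hard cap at delay_max.
    delay = (2 ** attempt) * delay_min
    if jitter != 0.0:
        delay += random.random() * jitter
    return min(delay, delay_max)

# Roughly 2, 4, 8, 16, 30, 30 seconds (jitter is added before the cap).
for attempt in range(1, 7):
    print(f"attempt {attempt}: sleep ~{backoff_delay(attempt):.2f}s")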
105 changes: 105 additions & 0 deletions src/databricks/sql/client.py
@@ -1,3 +1,4 @@
import time
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
@@ -47,6 +48,7 @@

from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
TOperationState,
)


@@ -437,6 +439,8 @@ def __init__(
self.escaper = ParamEscaper()
self.lastrowid = None

self.ASYNC_DEFAULT_POLLING_INTERVAL = 2

# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
def __enter__(self) -> "Cursor":
return self
@@ -803,6 +807,7 @@ def execute(
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=False,
)
self.active_result_set = ResultSet(
self.connection,
@@ -819,6 +824,106 @@ def execute(

return self

def execute_async(
    self,
    operation: str,
    parameters: Optional[TParameterCollection] = None,
) -> "Cursor":
    """
    Execute the operation asynchronously and return immediately, without waiting for the query to complete.
    :param operation: SQL statement to execute
    :param parameters: optional parameters to bind into the statement
    :return: self
    """
    param_approach = self._determine_parameter_approach(parameters)
    if param_approach == ParameterApproach.NONE:
        prepared_params = NO_NATIVE_PARAMS
        prepared_operation = operation

    elif param_approach == ParameterApproach.INLINE:
        prepared_operation, prepared_params = self._prepare_inline_parameters(
            operation, parameters
        )
    elif param_approach == ParameterApproach.NATIVE:
        normalized_parameters = self._normalize_tparametercollection(parameters)
        param_structure = self._determine_parameter_structure(normalized_parameters)
        transformed_operation = transform_paramstyle(
            operation, normalized_parameters, param_structure
        )
        prepared_operation, prepared_params = self._prepare_native_parameters(
            transformed_operation, normalized_parameters, param_structure
        )

    self._check_not_closed()
    self._close_and_clear_active_result_set()
    self.thrift_backend.execute_command(
        operation=prepared_operation,
        session_handle=self.connection._session_handle,
        max_rows=self.arraysize,
        max_bytes=self.buffer_size_bytes,
        lz4_compression=self.connection.lz4_compression,
        cursor=self,
        use_cloud_fetch=self.connection.use_cloud_fetch,
        parameters=prepared_params,
        async_op=True,
    )

    return self

def get_query_state(self) -> "TOperationState":
    """
    Poll the server for the current state of the asynchronously executing query.
    :return: the current TOperationState
    """
    self._check_not_closed()
    return self.thrift_backend.get_query_state(self.active_op_handle)

def get_async_execution_result(self):
    """
    Poll the status of the asynchronously executing query until it is no longer pending or running,
    then fetch the result if the query finished successfully; otherwise raise an Error.
    :return: self, with the active result set populated
    """
    self._check_not_closed()

    def is_executing(operation_state) -> "bool":
        return not operation_state or operation_state in [
            ttypes.TOperationState.RUNNING_STATE,
            ttypes.TOperationState.PENDING_STATE,
        ]

    while is_executing(self.get_query_state()):
        # Wait for the default polling interval between status checks
        time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)

    operation_state = self.get_query_state()
    if operation_state == ttypes.TOperationState.FINISHED_STATE:
        execute_response = self.thrift_backend.get_execution_result(
            self.active_op_handle, self
        )
        self.active_result_set = ResultSet(
            self.connection,
            execute_response,
            self.thrift_backend,
            self.buffer_size_bytes,
            self.arraysize,
        )

        if execute_response.is_staging_operation:
            self._handle_staging_operation(
                staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
            )

        return self
    else:
        raise Error(
            f"get_execution_result failed with Operation status {operation_state}"
        )

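Putting the new API together, a minimal usage sketch of the async flow (server_hostname, http_path, access_token, and the table name are placeholders):

from databricks import sql
from databricks.sql.thrift_api.TCLIService import ttypes

with sql.connect(
    server_hostname="<workspace-host>",  # placeholder
    http_path="<http-path>",  # placeholder
    access_token="<token>",  # placeholder
) as connection:
    with connection.cursor() as cursor:
        # Submit the query; execute_async returns immediately.
        cursor.execute_async("SELECT COUNT(*) FROM some_table")  # placeholder table

        # Either poll the state manually...
        state = cursor.get_query_state()
        print(state == ttypes.TOperationState.RUNNING_STATE)

        # ...or block until the query finishes and its results are fetched.
        cursor.get_async_execution_result()
        print(cursor.fetchall())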
def executemany(self, operation, seq_of_parameters):
"""
Execute the operation once for every set of passed in parameters.
80 changes: 77 additions & 3 deletions src/databricks/sql/thrift_backend.py
@@ -7,6 +7,8 @@
import threading
from typing import List, Union

from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState

try:
import pyarrow
except ImportError:
@@ -64,8 +66,8 @@
# - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins)
_retry_policy = { # (type, default, min, max)
"_retry_delay_min": (float, 1, 0.1, 60),
"_retry_delay_max": (float, 60, 5, 3600),
"_retry_stop_after_attempts_count": (int, 30, 1, 60),
"_retry_delay_max": (float, 30, 5, 3600),
"_retry_stop_after_attempts_count": (int, 5, 1, 60),
"_retry_stop_after_attempts_duration": (float, 900, 1, 86400),
"_retry_delay_default": (float, 5, 1, 60),
}
@@ -769,6 +771,63 @@ def _results_message_to_execute_response(self, resp, operation_state):
arrow_schema_bytes=schema_bytes,
)

def get_execution_result(self, op_handle, cursor):

    assert op_handle is not None

    req = ttypes.TFetchResultsReq(
        operationHandle=ttypes.TOperationHandle(
            op_handle.operationId,
            op_handle.operationType,
            False,
            op_handle.modifiedRowCount,
        ),
        maxRows=cursor.arraysize,
        maxBytes=cursor.buffer_size_bytes,
        orientation=ttypes.TFetchOrientation.FETCH_NEXT,
        includeResultSetMetadata=True,
    )

    resp = self.make_request(self._client.FetchResults, req)

    t_result_set_metadata_resp = resp.resultSetMetadata

    lz4_compressed = t_result_set_metadata_resp.lz4Compressed
    is_staging_operation = t_result_set_metadata_resp.isStagingOperation
    has_more_rows = resp.hasMoreRows
    description = self._hive_schema_to_description(
        t_result_set_metadata_resp.schema
    )

    schema_bytes = (
        t_result_set_metadata_resp.arrowSchema
        or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
        .serialize()
        .to_pybytes()
    )

    queue = ResultSetQueueFactory.build_queue(
        row_set_type=resp.resultSetMetadata.resultFormat,
        t_row_set=resp.results,
        arrow_schema_bytes=schema_bytes,
        max_download_threads=self.max_download_threads,
        lz4_compressed=lz4_compressed,
        description=description,
        ssl_options=self._ssl_options,
    )

    return ExecuteResponse(
        arrow_queue=queue,
        status=resp.status,
        has_been_closed_server_side=False,
        has_more_rows=has_more_rows,
        lz4_compressed=lz4_compressed,
        is_staging_operation=is_staging_operation,
        command_handle=op_handle,
        description=description,
        arrow_schema_bytes=schema_bytes,
    )

def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
if initial_operation_status_resp:
self._check_command_not_in_error_or_closed_state(
@@ -787,6 +846,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
return operation_state

def get_query_state(self, op_handle) -> "TOperationState":
    poll_resp = self._poll_for_status(op_handle)
    operation_state = poll_resp.operationState
    self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
    return operation_state

@staticmethod
def _check_direct_results_for_error(t_spark_direct_results):
if t_spark_direct_results:
@@ -817,6 +882,7 @@ def execute_command(
cursor,
use_cloud_fetch=True,
parameters=[],
async_op=False,
):
assert session_handle is not None

@@ -846,7 +912,11 @@
parameters=parameters,
)
resp = self.make_request(self._client.ExecuteStatement, req)

if async_op:
    self._handle_execute_response_async(resp, cursor)
else:
    return self._handle_execute_response(resp, cursor)

def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
assert session_handle is not None
@@ -945,6 +1015,10 @@ def _handle_execute_response(self, resp, cursor):

return self._results_message_to_execute_response(resp, final_operation_state)

def _handle_execute_response_async(self, resp, cursor):
    cursor.active_op_handle = resp.operationHandle
    self._check_direct_results_for_error(resp.directResults)

def fetch_results(
self,
op_handle,
8 changes: 4 additions & 4 deletions tests/e2e/common/retry_test_mixins.py
@@ -174,7 +174,7 @@ def test_retry_max_count_not_exceeded(self):
def test_retry_exponential_backoff(self):
"""GIVEN the retry policy is configured for reasonable exponential backoff
WHEN the server sends nothing but 429 responses with retry-afters
THEN the connector will use those Retry-After values as the delay
"""
retry_policy = self._retry_policy.copy()
retry_policy["_retry_delay_min"] = 1
@@ -191,10 +191,10 @@ def test_retry_exponential_backoff(self):
assert isinstance(cm.value.args[1], MaxRetryDurationError)

# With delay_min set to 1, the expected retry delays should be:
# 3, 3, 3, 3
# The first 3 retries are allowed; the 4th retry puts the total duration over the limit
# of 10 seconds
assert mock_obj.return_value.getresponse.call_count == 4
assert duration > 6

# Should be less than 7, but this is a safe margin for CI/CD slowness