Skip to content

Commit ed4c10d

Browse files
Ronen HofferRonenHoffer
authored andcommitted
enable presto headers if needed
1 parent 443c4e6 commit ed4c10d

File tree

3 files changed

+98
-46
lines changed

3 files changed

+98
-46
lines changed

trino/client.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ def __init__(
206206
max_attempts: int = MAX_ATTEMPTS,
207207
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
208208
handle_retry=exceptions.RetryWithExponentialBackoff(),
209-
verify: bool = True
209+
verify: bool = True,
210+
presto_headers = False,
210211
) -> None:
211212
self._client_session = ClientSession(
212213
catalog,
@@ -227,6 +228,7 @@ def __init__(
227228
else:
228229
self._http_session = self.http.Session()
229230
self._http_session.verify = verify
231+
self._headers_name = constants.get_headers(presto_headers=presto_headers)
230232
self._http_session.headers.update(self.http_headers)
231233
self._exceptions = self.HTTP_EXCEPTIONS
232234
self._auth = auth
@@ -252,28 +254,20 @@ def transaction_id(self, value):
252254

253255
@property
254256
def http_headers(self) -> Dict[str, str]:
255-
headers = {}
256-
257-
headers[constants.HEADER_CATALOG] = self._client_session.catalog
258-
headers[constants.HEADER_SCHEMA] = self._client_session.schema
259-
headers[constants.HEADER_SOURCE] = self._client_session.source
260-
headers[constants.HEADER_USER] = self._client_session.user
261-
262-
headers[constants.HEADER_SESSION] = ",".join(
263-
# ``name`` must not contain ``=``
264-
"{}={}".format(name, value)
265-
for name, value in self._client_session.properties.items()
266-
)
257+
headers = {
258+
self._headers_name.catalog: self._client_session.catalog,
259+
self._headers_name.schema: self._client_session.schema,
260+
self._headers_name.source: self._client_session.source,
261+
self._headers_name.user: self._client_session.user,
262+
self._headers_name.session: ",".join("{}={}".format(name, value) for name, value in self._client_session.properties.items()),
263+
}
267264

268265
# merge custom http headers
269266
for key in self._client_session.headers:
270267
if key in headers.keys():
271268
raise ValueError("cannot override reserved HTTP header {}".format(key))
272269
headers.update(self._client_session.headers)
273-
274-
transaction_id = self._client_session.transaction_id
275-
headers[constants.HEADER_TRANSACTION] = transaction_id
276-
270+
headers[self._headers_name.transaction] = self._client_session.transaction_id
277271
return headers
278272

279273
@property
@@ -386,18 +380,19 @@ def process(self, http_response) -> TrinoStatus:
386380
http_response.encoding = "utf-8"
387381
response = http_response.json()
388382
logger.debug("HTTP %s: %s", http_response.status_code, response)
383+
389384
if "error" in response:
390385
raise self._process_error(response["error"], response.get("id"))
391386

392-
if constants.HEADER_CLEAR_SESSION in http_response.headers:
387+
if self._headers_name.clear_session in http_response.headers:
393388
for prop in get_header_values(
394-
http_response.headers, constants.HEADER_CLEAR_SESSION
389+
http_response.headers, self._headers_name.clear_session
395390
):
396391
self._client_session.properties.pop(prop, None)
397392

398-
if constants.HEADER_SET_SESSION in http_response.headers:
393+
if self._headers_name.set_session in http_response.headers:
399394
for key, value in get_session_property_values(
400-
http_response.headers, constants.HEADER_SET_SESSION
395+
http_response.headers, self._headers_name.set_session
401396
):
402397
self._client_session.properties[key] = value
403398

trino/constants.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,72 @@
2626

2727
URL_STATEMENT_PATH = "/v1/statement"
2828

29-
HEADER_CATALOG = "X-Trino-Catalog"
30-
HEADER_SCHEMA = "X-Trino-Schema"
31-
HEADER_SOURCE = "X-Trino-Source"
32-
HEADER_USER = "X-Trino-User"
33-
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
34-
35-
HEADER_SESSION = "X-Trino-Session"
36-
HEADER_SET_SESSION = "X-Trino-Set-Session"
37-
HEADER_CLEAR_SESSION = "X-Trino-Clear-Session"
38-
39-
HEADER_STARTED_TRANSACTION = "X-Trino-Started-Transaction-Id"
40-
HEADER_TRANSACTION = "X-Trino-Transaction-Id"
41-
42-
HEADER_PREPARED_STATEMENT = 'X-Trino-Prepared-Statement'
43-
HEADER_ADDED_PREPARE = 'X-Trino-Added-Prepare'
44-
HEADER_DEALLOCATED_PREPARE = 'X-Trino-Deallocated-Prepare'
29+
class Headers:
30+
31+
@property
32+
def base(self):
33+
return f'X-{self.NAME}'
34+
35+
@property
36+
def catalog(self):
37+
return f'{self.base}-Catalog'
38+
39+
@property
40+
def schema(self):
41+
return f'{self.base}-Schema'
42+
43+
@property
44+
def source(self):
45+
return f'{self.base}-Source'
46+
47+
@property
48+
def user(self):
49+
return f'{self.base}-User'
50+
51+
@property
52+
def client_info(self):
53+
return f'{self.base}-Client-Info'
54+
55+
@property
56+
def session(self):
57+
return f'{self.base}-Session'
58+
59+
@property
60+
def set_session(self):
61+
return f'{self.base}-Set-Session'
62+
63+
@property
64+
def clear_session(self):
65+
return f'{self.base}-Clear-Session'
66+
67+
@property
68+
def started_transaction_id(self):
69+
return f'{self.base}-Started-Transaction-Id'
70+
71+
@property
72+
def transaction(self):
73+
return f'{self.base}-Transaction-Id'
74+
75+
@property
76+
def prepared_statement(self):
77+
return f'{self.base}-Prepared-Statement'
78+
79+
@property
80+
def added_prepare(self):
81+
return f'{self.base}-Added-Prepare'
82+
83+
@property
84+
def deallocated_prepare(self):
85+
return f'{self.base}-Deallocated-Prepare'
86+
87+
88+
class PrestoHeaders(Headers):
89+
NAME = 'Presto'
90+
91+
92+
class TrinoHeaders(Headers):
93+
NAME = 'Trino'
94+
95+
96+
def get_headers(presto_headers=False):
97+
return PrestoHeaders() if presto_headers else TrinoHeaders()

trino/dbapi.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def __init__(
7676
max_attempts=constants.DEFAULT_MAX_ATTEMPTS,
7777
request_timeout=constants.DEFAULT_REQUEST_TIMEOUT,
7878
isolation_level=IsolationLevel.AUTOCOMMIT,
79-
verify=True
79+
verify=True,
80+
presto_headers=False
8081
):
8182
self.host = host
8283
self.port = port
@@ -98,6 +99,7 @@ def __init__(
9899
self._isolation_level = isolation_level
99100
self._request = None
100101
self._transaction = None
102+
self._presto_headers = presto_headers
101103

102104
@property
103105
def isolation_level(self):
@@ -157,6 +159,7 @@ def _create_request(self):
157159
self.redirect_handler,
158160
self.max_attempts,
159161
self.request_timeout,
162+
presto_headers=self._presto_headers
160163
)
161164

162165
def cursor(self):
@@ -189,6 +192,7 @@ def __init__(self, connection, request):
189192
self.arraysize = 1
190193
self._iterator = None
191194
self._query = None
195+
self._headers_name = constants.get_headers(presto_headers=self._connection._presto_headers)
192196

193197
def __iter__(self):
194198
return self._iterator
@@ -267,9 +271,9 @@ def _prepare_statement(self, operation, statement_name):
267271
# until there are no more results
268272
for _ in result:
269273
response_headers = result.response_headers
270-
271-
if constants.HEADER_ADDED_PREPARE in response_headers:
272-
return response_headers[constants.HEADER_ADDED_PREPARE]
274+
275+
if self._headers_name.added_prepare in response_headers:
276+
return response_headers[self._headers_name.added_prepare]
273277

274278
raise trino.exceptions.FailedToObtainAddedPrepareHeader
275279

@@ -344,17 +348,17 @@ def _deallocate_prepare_statement(self, added_prepare_header, statement_name):
344348
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql)
345349
result = query.execute(
346350
additional_http_headers={
347-
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
351+
self._headers_name.prepared_statement: added_prepare_header
348352
}
349353
)
350354

351355
# Iterate until the 'X-Trino-Deallocated-Prepare' header is found or
352356
# until there are no more results
353357
for _ in result:
354358
response_headers = result.response_headers
355-
356-
if constants.HEADER_DEALLOCATED_PREPARE in response_headers:
357-
return response_headers[constants.HEADER_DEALLOCATED_PREPARE]
359+
360+
if self._headers_name.deallocated_prepare in response_headers:
361+
return response_headers[self._headers_name.deallocated_prepare]
358362

359363
raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader
360364

@@ -382,7 +386,7 @@ def execute(self, operation, params=None):
382386
)
383387
result = self._query.execute(
384388
additional_http_headers={
385-
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
389+
self._headers_name.prepared_statement: added_prepare_header
386390
}
387391
)
388392
finally:

0 commit comments

Comments
 (0)