@@ -206,7 +206,8 @@ def __init__(
206206 max_attempts : int = MAX_ATTEMPTS ,
207207 request_timeout : Union [float , Tuple [float , float ]] = constants .DEFAULT_REQUEST_TIMEOUT ,
208208 handle_retry = exceptions .RetryWithExponentialBackoff (),
209- verify : bool = True
209+ verify : bool = True ,
210+ presto_headers = False ,
210211 ) -> None :
211212 self ._client_session = ClientSession (
212213 catalog ,
@@ -227,6 +228,7 @@ def __init__(
227228 else :
228229 self ._http_session = self .http .Session ()
229230 self ._http_session .verify = verify
231+ self ._headers_name = constants .get_headers (presto_headers = presto_headers )
230232 self ._http_session .headers .update (self .http_headers )
231233 self ._exceptions = self .HTTP_EXCEPTIONS
232234 self ._auth = auth
@@ -252,28 +254,20 @@ def transaction_id(self, value):
252254
253255 @property
254256 def http_headers (self ) -> Dict [str , str ]:
255- headers = {}
256-
257- headers [constants .HEADER_CATALOG ] = self ._client_session .catalog
258- headers [constants .HEADER_SCHEMA ] = self ._client_session .schema
259- headers [constants .HEADER_SOURCE ] = self ._client_session .source
260- headers [constants .HEADER_USER ] = self ._client_session .user
261-
262- headers [constants .HEADER_SESSION ] = "," .join (
263- # ``name`` must not contain ``=``
264- "{}={}" .format (name , value )
265- for name , value in self ._client_session .properties .items ()
266- )
257+ headers = {
258+ self ._headers_name .catalog : self ._client_session .catalog ,
259+ self ._headers_name .schema : self ._client_session .schema ,
260+ self ._headers_name .source : self ._client_session .source ,
261+ self ._headers_name .user : self ._client_session .user ,
262+ self ._headers_name .session : "," .join ("{}={}" .format (name , value ) for name , value in self ._client_session .properties .items ()),
263+ }
267264
268265 # merge custom http headers
269266 for key in self ._client_session .headers :
270267 if key in headers .keys ():
271268 raise ValueError ("cannot override reserved HTTP header {}" .format (key ))
272269 headers .update (self ._client_session .headers )
273-
274- transaction_id = self ._client_session .transaction_id
275- headers [constants .HEADER_TRANSACTION ] = transaction_id
276-
270+ headers [self ._headers_name .transaction ] = self ._client_session .transaction_id
277271 return headers
278272
279273 @property
@@ -386,18 +380,19 @@ def process(self, http_response) -> TrinoStatus:
386380 http_response .encoding = "utf-8"
387381 response = http_response .json ()
388382 logger .debug ("HTTP %s: %s" , http_response .status_code , response )
383+
389384 if "error" in response :
390385 raise self ._process_error (response ["error" ], response .get ("id" ))
391386
392- if constants . HEADER_CLEAR_SESSION in http_response .headers :
387+ if self . _headers_name . clear_session in http_response .headers :
393388 for prop in get_header_values (
394- http_response .headers , constants . HEADER_CLEAR_SESSION
389+ http_response .headers , self . _headers_name . clear_session
395390 ):
396391 self ._client_session .properties .pop (prop , None )
397392
398- if constants . HEADER_SET_SESSION in http_response .headers :
393+ if self . _headers_name . set_session in http_response .headers :
399394 for key , value in get_session_property_values (
400- http_response .headers , constants . HEADER_SET_SESSION
395+ http_response .headers , self . _headers_name . set_session
401396 ):
402397 self ._client_session .properties [key ] = value
403398
0 commit comments