Skip to content

Commit

Permalink
sdk/python: Manually handle redirects to close initial session
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Wilson <[email protected]>
  • Loading branch information
aaronnw committed Oct 28, 2024
1 parent 125e8a4 commit 06c0881
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 25 deletions.
7 changes: 7 additions & 0 deletions python/aistore/sdk/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
HEADER_USER_AGENT = "User-Agent"
HEADER_CONTENT_TYPE = "Content-Type"
HEADER_CONTENT_LENGTH = "Content-Length"
# Response header the request client reads to find the target URL when it
# follows a proxy redirect manually
HEADER_LOCATION = "Location"
# Standard Header Values
USER_AGENT_BASE = "ais/python"
JSON_CONTENT_TYPE = "application/json"
Expand Down Expand Up @@ -133,6 +134,12 @@
STATUS_OK = 200
STATUS_BAD_REQUEST = 400
STATUS_PARTIAL_CONTENT = 206
# Redirect status codes the request client treats as "follow manually"
STATUS_REDIRECT_TMP = 307  # HTTP 307 Temporary Redirect
STATUS_REDIRECT_PERM = 301  # HTTP 301 Moved Permanently

# Protocol
# URL scheme prefixes; used for mounting session adapters and for detecting
# HTTPS endpoints via str.startswith
HTTP = "http://"
HTTPS = "https://"

# Environment Variables
AIS_CLIENT_CA = "AIS_CLIENT_CA"
Expand Down
63 changes: 56 additions & 7 deletions python/aistore/sdk/request_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
USER_AGENT_BASE,
HEADER_CONTENT_TYPE,
HEADER_AUTHORIZATION,
HTTPS,
HEADER_LOCATION,
STATUS_REDIRECT_PERM,
STATUS_REDIRECT_TMP,
)
from aistore.sdk.session_manager import SessionManager
from aistore.sdk.utils import raise_ais_error, handle_errors, decode_response
Expand Down Expand Up @@ -145,17 +149,62 @@ def request(
if self.token:
headers[HEADER_AUTHORIZATION] = f"Bearer {self.token}"

resp = self.session_manager.session.request(
method,
url,
headers=headers,
timeout=self._timeout,
**kwargs,
)
if url.startswith(HTTPS) and "data" in kwargs:
resp = self._request_with_manual_redirect(method, url, headers, **kwargs)
else:
resp = self._session_request(method, url, headers, **kwargs)

if resp.status_code < 200 or resp.status_code >= 300:
handle_errors(resp, self._error_handler)

return resp

def _request_with_manual_redirect(
    self, method: str, url: str, headers, **kwargs
) -> Response:
    """
    Make a data-less request to the proxy, close the session, and re-issue the full request (including data)
    against the redirect target returned by the proxy.

    This exists because the current implementation of `requests` does not seem to handle a 307 redirect
    properly from a server with TLS enabled with data in the request, and will error with the following on the
    initial connection to the proxy:
        SSLEOFError(8, 'EOF occurred in violation of protocol (_ssl.c:2406)')
    Instead, this implementation will not send the data to the proxy, and only use it to access the proper target.

    The client timeout (when configured) is applied to the proxy request as well as to the target request, so
    neither leg can block indefinitely.

    Args:
        method (str): HTTP method (e.g. POST, GET, PUT, DELETE).
        url (str): Initial AIS url.
        headers (dict): Headers to send with both the proxy request and the redirected request.
        **kwargs (optional): Optional keyword arguments to pass with the call to request. A `data` entry is
            withheld from the proxy request and only sent to the redirect target.

    Returns:
        Final response from the AIS target. If the proxy does not answer with a 301/307, or the redirect
        carries no Location header, the proxy response itself is returned so the caller can handle it.
    """
    # Do not include data payload in the initial request to the proxy
    proxy_request_kwargs = {
        "headers": headers,
        "allow_redirects": False,
        **{k: v for k, v in kwargs.items() if k != "data"},
    }
    # Keep the proxy leg consistent with _session_request: without a timeout,
    # requests waits on an unresponsive proxy forever.
    if self._timeout is not None:
        proxy_request_kwargs["timeout"] = self._timeout

    # Request to proxy, which should redirect
    resp = self.session_manager.session.request(method, url, **proxy_request_kwargs)
    # Drop the connection(s) used for the proxy request before contacting the target
    self.session_manager.session.close()

    if resp.status_code not in (STATUS_REDIRECT_PERM, STATUS_REDIRECT_TMP):
        return resp

    target_url = resp.headers.get(HEADER_LOCATION)
    if not target_url:
        # A redirect without a Location cannot be followed; hand the redirect
        # response back instead of issuing a request against a None URL.
        return resp

    # Redirected request to target, this time including the data payload
    return self._session_request(method, target_url, headers, **kwargs)

def _session_request(self, method, url, headers, **kwargs) -> Response:
    """
    Issue a request on the session manager's session, adding the client timeout when one is set.

    Args:
        method: HTTP method to use.
        url: URL to send the request to.
        headers: Headers to pass with the request.
        **kwargs (optional): Additional keyword arguments forwarded to the session's `request` call.

    Returns:
        The response returned by the session's `request` call.
    """
    call_kwargs = {"headers": headers}
    call_kwargs.update(kwargs)

    # Only pass a timeout when one is configured; a configured timeout
    # replaces any "timeout" entry supplied through kwargs.
    configured_timeout = self._timeout
    if configured_timeout is not None:
        call_kwargs["timeout"] = configured_timeout

    active_session = self.session_manager.session
    return active_session.request(method, url, **call_kwargs)

def get_full_url(self, path: str, params: Dict[str, Any]) -> str:
"""
Get the full URL to the path on the cluster with the given parameters.
Expand Down
4 changes: 2 additions & 2 deletions python/aistore/sdk/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from aistore.sdk.const import AIS_CLIENT_CA
from aistore.sdk.const import AIS_CLIENT_CA, HTTPS, HTTP

DEFAULT_RETRY = Retry(total=6, connect=3, backoff_factor=1)

Expand Down Expand Up @@ -89,6 +89,6 @@ def _create_session(self) -> Session:
"""
request_session = Session()
self._set_session_verification(request_session)
for protocol in ("http://", "https://"):
for protocol in (HTTP, HTTPS):
request_session.mount(protocol, HTTPAdapter(max_retries=self._retry))
return request_session
75 changes: 59 additions & 16 deletions python/tests/unit/sdk/test_request_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import patch, Mock
from unittest.mock import patch, Mock, call

from requests import Response, Session

Expand Down Expand Up @@ -85,7 +85,6 @@ def test_request_deserialize(self, mock_decode):
method,
expected_url,
headers=self.request_headers,
timeout=None,
keyword=custom_kw,
)
mock_decode.assert_called_with(str, self.mock_response)
Expand Down Expand Up @@ -118,13 +117,7 @@ def test_request(self, test_case):
res = self.default_request_client.request(
method, path, headers=extra_headers, keyword=extra_kw_arg
)
self.mock_session.request.assert_called_with(
method,
req_url,
headers=self.request_headers,
timeout=timeout,
keyword=extra_kw_arg,
)
self._request_assert(method, req_url, timeout, extra_kw_arg)
self.assertEqual(self.mock_response, res)

for response_code in [199, 300]:
Expand All @@ -137,16 +130,66 @@ def test_request(self, test_case):
headers=extra_headers,
keyword=extra_kw_arg,
)
self.mock_session.request.assert_called_with(
method,
req_url,
headers=self.request_headers,
timeout=timeout,
keyword=extra_kw_arg,
)
self._request_assert(method, req_url, timeout, extra_kw_arg)
self.assertEqual(self.mock_response, res)
mock_handle_err.assert_called_once()

def _request_assert(self, method, url, timeout, expected_kw):
    """
    Assert that the most recent call to the mocked session's `request` matches the expected arguments.

    Args:
        method: Expected HTTP method positional argument.
        url: Expected URL positional argument.
        timeout: Timeout configured on the client under test; `None` means no `timeout` keyword is expected.
        expected_kw: Expected value of the extra `keyword` argument.
    """
    expected_kwargs = {
        "headers": self.request_headers,
        "keyword": expected_kw,
    }
    # Mirror the client's own rule (`timeout is not None`) rather than
    # truthiness, so a falsy-but-set timeout is still expected in the call.
    if timeout is not None:
        expected_kwargs["timeout"] = timeout

    self.mock_session.request.assert_called_with(method, url, **expected_kwargs)

def test_request_https_data(self):
    """
    A request carrying `data` should first hit the proxy without the payload and with redirects disabled,
    then be re-sent with the payload to the URL from the proxy's Location header.

    NOTE(review): relies on `self.endpoint` being an https URL so the manual-redirect path is taken;
    that fixture is defined outside this view -- confirm.
    """
    http_method = "request_method"
    obj_path = "request_path"
    custom_kw = "arg"
    payload = "my_data"
    proxy_url = "/v1/".join((self.endpoint, obj_path))
    target_url = "/v1/".join(("target", obj_path))

    # First response: the proxy's redirect pointing at the target
    proxy_response = Mock(spec=Response)
    proxy_response.status_code = 307
    proxy_response.headers = {"Location": target_url}
    # Second response: the target's successful reply
    self.mock_response.status_code = 200
    self.mock_session.request.side_effect = [proxy_response, self.mock_response]

    result = self.default_request_client.request(
        http_method, obj_path, data=payload, keyword=custom_kw
    )

    self.assertEqual(self.mock_response, result)

    self.mock_session.request.assert_has_calls(
        [
            # Proxy leg: no data, redirects not followed automatically
            call(
                http_method,
                proxy_url,
                headers=self.request_headers,
                allow_redirects=False,
                keyword=custom_kw,
            ),
            # Target leg: same request, now carrying the data
            call(
                http_method,
                target_url,
                headers=self.request_headers,
                keyword=custom_kw,
                data=payload,
            ),
        ]
    )

def test_get_full_url(self):
path = "/testpath/to_obj"
params = {"p1key": "p1val", "p2key": "p2val"}
Expand Down

0 comments on commit 06c0881

Please sign in to comment.