diff --git a/setup.py b/setup.py index ebc484dd..10fa308f 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ 'pytest-asyncio==0.21.0', 'aiohttp>=3.8.4', 'aiofiles>=23.1.0', - 'requests-kerberos>=0.14.0' + 'requests-kerberos>=0.15.0' ] INSTALL_REQUIRES = [ @@ -48,7 +48,7 @@ 'uwsgi': ['uwsgi>=2.0.0'], 'cpphash': ['mmh3cffi==0.2.1'], 'asyncio': ['aiohttp>=3.8.4', 'aiofiles>=23.1.0'], - 'kerberos': ['requests-kerberos>=0.14.0'] + 'kerberos': ['requests-kerberos>=0.15.0'] }, setup_requires=['pytest-runner', 'pluggy==1.0.0;python_version<"3.8"'], classifiers=[ diff --git a/splitio/api/client.py b/splitio/api/client.py index b255baff..02eff8c2 100644 --- a/splitio/api/client.py +++ b/splitio/api/client.py @@ -5,8 +5,10 @@ import abc import logging import json -from splitio.optional.loaders import HTTPKerberosAuth, OPTIONAL +import threading +from urllib3.util import parse_url +from splitio.optional.loaders import HTTPKerberosAuth, OPTIONAL from splitio.client.config import AuthenticateScheme from splitio.optional.loaders import aiohttp from splitio.util.time import get_current_epoch_time_ms @@ -69,6 +71,24 @@ def __init__(self, message): """ Exception.__init__(self, message) +class HTTPAdapterWithProxyKerberosAuth(requests.adapters.HTTPAdapter): + """HTTPAdapter override for Kerberos Proxy auth""" + + def __init__(self, principal=None, password=None): + requests.adapters.HTTPAdapter.__init__(self) + self._principal = principal + self._password = password + + def proxy_headers(self, proxy): + headers = {} + if self._principal is not None: + auth = HTTPKerberosAuth(principal=self._principal, password=self._password) + else: + auth = HTTPKerberosAuth() + negotiate_details = auth.generate_request_header(None, parse_url(proxy).host, is_preemptive=True) + headers['Proxy-Authorization'] = negotiate_details + return headers + class HttpClientBase(object, metaclass=abc.ABCMeta): """HttpClient wrapper template.""" @@ -93,6 +113,11 @@ def set_telemetry_data(self, metric_name, telemetry_runtime_producer): self._telemetry_runtime_producer = telemetry_runtime_producer self._metric_name = metric_name + def _get_headers(self, extra_headers, sdk_key): + headers = _build_basic_headers(sdk_key) + if extra_headers is not None: + headers.update(extra_headers) + return headers class HttpClient(HttpClientBase): """HttpClient wrapper.""" @@ -112,10 +137,12 @@ def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, t :param telemetry_url: Optional alternative telemetry URL. :type telemetry_url: str """ + _LOGGER.debug("Initializing httpclient") self._timeout = timeout/1000 if timeout else None # Convert ms to seconds. + self._urls = _construct_urls(sdk_url, events_url, auth_url, telemetry_url) self._authentication_scheme = authentication_scheme self._authentication_params = authentication_params - self._urls = _construct_urls(sdk_url, events_url, auth_url, telemetry_url) + self._lock = threading.RLock() def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint: disable=too-many-arguments """ @@ -135,25 +162,22 @@ def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint: :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = _build_basic_headers(sdk_key) - if extra_headers is not None: - headers.update(extra_headers) - - authentication = self._get_authentication() - start = get_current_epoch_time_ms() - try: - response = requests.get( - _build_url(server, path, self._urls), - params=query, - headers=headers, - timeout=self._timeout, - auth=authentication - ) - self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) - return HttpResponse(response.status_code, response.text, response.headers) - - except Exception as exc: # pylint: disable=broad-except - raise HttpClientException('requests library is throwing exceptions') from exc + with self._lock: + start = get_current_epoch_time_ms() + with requests.Session() as session: + self._set_authentication(session) + try: + response = session.get( + _build_url(server, path, self._urls), + params=query, + headers=self._get_headers(extra_headers, sdk_key), + timeout=self._timeout + ) + self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) + return HttpResponse(response.status_code, response.text, response.headers) + + except Exception as exc: # pylint: disable=broad-except + raise HttpClientException('requests library is throwing exceptions') from exc def post(self, server, path, sdk_key, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments """ @@ -175,36 +199,37 @@ def post(self, server, path, sdk_key, body, query=None, extra_headers=None): # :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = _build_basic_headers(sdk_key) - - if extra_headers is not None: - headers.update(extra_headers) - - authentication = self._get_authentication() - start = get_current_epoch_time_ms() - try: - response = requests.post( - _build_url(server, path, self._urls), - json=body, - params=query, - headers=headers, - timeout=self._timeout, - auth=authentication - ) - self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) - return HttpResponse(response.status_code, response.text, response.headers) - - except Exception as exc: # pylint: disable=broad-except - raise HttpClientException('requests library is throwing exceptions') from exc - - def _get_authentication(self): - authentication = None - if self._authentication_scheme == AuthenticateScheme.KERBEROS: - if self._authentication_params is not None: - authentication = HTTPKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1], mutual_authentication=OPTIONAL) + with self._lock: + start = get_current_epoch_time_ms() + with requests.Session() as session: + self._set_authentication(session) + try: + response = session.post( + _build_url(server, path, self._urls), + json=body, + params=query, + headers=self._get_headers(extra_headers, sdk_key), + timeout=self._timeout, + ) + self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) + return HttpResponse(response.status_code, response.text, response.headers) + except Exception as exc: # pylint: disable=broad-except + raise HttpClientException('requests library is throwing exceptions') from exc + + def _set_authentication(self, session): + if self._authentication_scheme == AuthenticateScheme.KERBEROS_SPNEGO: + _LOGGER.debug("Using Kerberos Spnego Authentication") + if self._authentication_params != [None, None]: + session.auth = HTTPKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1], mutual_authentication=OPTIONAL) + else: + session.auth = HTTPKerberosAuth(mutual_authentication=OPTIONAL) + elif self._authentication_scheme == AuthenticateScheme.KERBEROS_PROXY: + _LOGGER.debug("Using Kerberos Proxy Authentication") + if self._authentication_params != [None, None]: + session.mount('https://', HTTPAdapterWithProxyKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1])) else: - authentication = HTTPKerberosAuth(mutual_authentication=OPTIONAL) - return authentication + session.mount('https://', HTTPAdapterWithProxyKerberosAuth()) + def _record_telemetry(self, status_code, elapsed): """ @@ -220,8 +245,8 @@ def _record_telemetry(self, status_code, elapsed): if 200 <= status_code < 300: self._telemetry_runtime_producer.record_successful_sync(self._metric_name, get_current_epoch_time_ms()) return - self._telemetry_runtime_producer.record_sync_error(self._metric_name, status_code) + self._telemetry_runtime_producer.record_sync_error(self._metric_name, status_code) class HttpClientAsync(HttpClientBase): """HttpClientAsync wrapper.""" @@ -260,10 +285,8 @@ async def get(self, server, path, apikey, query=None, extra_headers=None): # py :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = _build_basic_headers(apikey) - if extra_headers is not None: - headers.update(extra_headers) start = get_current_epoch_time_ms() + headers = self._get_headers(extra_headers, apikey) try: url = _build_url(server, path, self._urls) _LOGGER.debug("GET request: %s", url) @@ -303,9 +326,7 @@ async def post(self, server, path, apikey, body, query=None, extra_headers=None) :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = _build_basic_headers(apikey) - if extra_headers is not None: - headers.update(extra_headers) + headers = self._get_headers(extra_headers, apikey) start = get_current_epoch_time_ms() try: headers['Accept-Encoding'] = 'gzip' diff --git a/splitio/client/config.py b/splitio/client/config.py index 60643a37..78d08b45 100644 --- a/splitio/client/config.py +++ b/splitio/client/config.py @@ -12,8 +12,8 @@ class AuthenticateScheme(Enum): """Authentication Scheme.""" NONE = 'NONE' - KERBEROS = 'KERBEROS' - + KERBEROS_SPNEGO = 'KERBEROS_SPNEGO' + KERBEROS_PROXY = 'KERBEROS_PROXY' DEFAULT_CONFIG = { 'operationMode': 'standalone', @@ -164,7 +164,7 @@ def sanitize(sdk_key, config): except (ValueError, AttributeError): authenticate_scheme = AuthenticateScheme.NONE _LOGGER.warning('You passed an invalid HttpAuthenticationScheme, HttpAuthenticationScheme should be ' \ - 'one of the following values: `none` or `kerberos`. ' + 'one of the following values: `none`, `kerberos_proxy` or `kerberos_spnego`. ' ' Defaulting to `none` mode.') processed["httpAuthenticateScheme"] = authenticate_scheme diff --git a/splitio/client/factory.py b/splitio/client/factory.py index 27938ecd..fffb0212 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -509,7 +509,7 @@ def _build_in_memory_factory(api_key, cfg, sdk_url=None, events_url=None, # pyl telemetry_init_producer = telemetry_producer.get_telemetry_init_producer() authentication_params = None - if cfg.get("httpAuthenticateScheme") == AuthenticateScheme.KERBEROS: + if cfg.get("httpAuthenticateScheme") in [AuthenticateScheme.KERBEROS_SPNEGO, AuthenticateScheme.KERBEROS_PROXY]: authentication_params = [cfg.get("kerberosPrincipalUser"), cfg.get("kerberosPrincipalPassword")] diff --git a/splitio/version.py b/splitio/version.py index a671925d..642e5ce1 100644 --- a/splitio/version.py +++ b/splitio/version.py @@ -1 +1 @@ -__version__ = '10.1.0rc1' \ No newline at end of file +__version__ = '10.1.0rc2' \ No newline at end of file diff --git a/tests/api/test_httpclient.py b/tests/api/test_httpclient.py index c0530854..d95dcb5f 100644 --- a/tests/api/test_httpclient.py +++ b/tests/api/test_httpclient.py @@ -2,6 +2,7 @@ from requests_kerberos import HTTPKerberosAuth, OPTIONAL import pytest import unittest.mock as mock +import requests from splitio.client.config import AuthenticateScheme from splitio.api import client @@ -19,7 +20,7 @@ def test_get(self, mocker): response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock - mocker.patch('splitio.api.client.requests.get', new=get_mock) + mocker.patch('splitio.api.client.requests.Session.get', new=get_mock) httpclient = client.HttpClient() httpclient.set_telemetry_data("metric", mocker.Mock()) response = httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) @@ -27,8 +28,7 @@ def test_get(self, mocker): client.SDK_URL + '/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert response.status_code == 200 assert response.body == 'ok' @@ -40,8 +40,7 @@ def test_get(self, mocker): client.EVENTS_URL + '/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert get_mock.mock_calls == [call] assert response.status_code == 200 @@ -55,7 +54,7 @@ def test_get_custom_urls(self, mocker): response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock - mocker.patch('splitio.api.client.requests.get', new=get_mock) + mocker.patch('splitio.api.client.requests.Session.get', new=get_mock) httpclient = client.HttpClient(sdk_url='https://sdk.com', events_url='https://events.com') httpclient.set_telemetry_data("metric", mocker.Mock()) response = httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) @@ -63,8 +62,7 @@ def test_get_custom_urls(self, mocker): 'https://sdk.com/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert get_mock.mock_calls == [call] assert response.status_code == 200 @@ -76,8 +74,7 @@ def test_get_custom_urls(self, mocker): 'https://events.com/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert response.status_code == 200 assert response.body == 'ok' @@ -92,7 +89,7 @@ def test_post(self, mocker): response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock - mocker.patch('splitio.api.client.requests.post', new=get_mock) + mocker.patch('splitio.api.client.requests.Session.post', new=get_mock) httpclient = client.HttpClient() httpclient.set_telemetry_data("metric", mocker.Mock()) response = httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) @@ -101,8 +98,7 @@ def test_post(self, mocker): json={'p1': 'a'}, headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert response.status_code == 200 assert response.body == 'ok' @@ -115,8 +111,7 @@ def test_post(self, mocker): json={'p1': 'a'}, headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert response.status_code == 200 assert response.body == 'ok' @@ -130,7 +125,7 @@ def test_post_custom_urls(self, mocker): response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock - mocker.patch('splitio.api.client.requests.post', new=get_mock) + mocker.patch('splitio.api.client.requests.Session.post', new=get_mock) httpclient = client.HttpClient(sdk_url='https://sdk.com', events_url='https://events.com') httpclient.set_telemetry_data("metric", mocker.Mock()) response = httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) @@ -139,8 +134,7 @@ def test_post_custom_urls(self, mocker): json={'p1': 'a'}, headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert response.status_code == 200 assert response.body == 'ok' @@ -153,8 +147,7 @@ def test_post_custom_urls(self, mocker): json={'p1': 'a'}, headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=None + timeout=None ) assert response.status_code == 200 assert response.body == 'ok' @@ -166,21 +159,94 @@ def test_authentication_scheme(self, mocker): response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock - mocker.patch('splitio.api.client.requests.get', new=get_mock) - httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS) + mocker.patch('splitio.api.client.requests.Session.get', new=get_mock) + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=[None, None]) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://sdk.com/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None +# auth=HTTPKerberosAuth(mutual_authentication=OPTIONAL) + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + get_mock.reset_mock() + + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=['bilal', 'split']) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://sdk.com/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None +# auth=HTTPKerberosAuth(principal='bilal', password='split', mutual_authentication=OPTIONAL) + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + get_mock.reset_mock() + + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=[None, None]) httpclient.set_telemetry_data("metric", mocker.Mock()) response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) call = mocker.call( 'https://sdk.com/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, - timeout=None, - auth=HTTPKerberosAuth(mutual_authentication=OPTIONAL) + timeout=None +# auth=HTTPKerberosAuth(mutual_authentication=OPTIONAL) ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + get_mock.reset_mock() - httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS, authentication_params=['bilal', 'split']) + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=['bilal', 'split']) httpclient.set_telemetry_data("metric", mocker.Mock()) response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://sdk.com/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + get_mock.reset_mock() + + # test auth settings + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=['bilal', 'split']) + my_session = requests.Session() + httpclient._set_authentication(my_session) + assert(my_session.auth.principal == 'bilal') + assert(my_session.auth.password == 'split') + assert(isinstance(my_session.auth, HTTPKerberosAuth)) + + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=[None, None]) + my_session2 = requests.Session() + httpclient._set_authentication(my_session2) + assert(my_session2.auth.principal == None) + assert(my_session2.auth.password == None) + assert(isinstance(my_session2.auth, HTTPKerberosAuth)) + + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=['bilal', 'split']) + my_session = requests.Session() + httpclient._set_authentication(my_session) + assert(my_session.adapters['https://']._principal == 'bilal') + assert(my_session.adapters['https://']._password == 'split') + assert(isinstance(my_session.adapters['https://'], client.HTTPAdapterWithProxyKerberosAuth)) + + httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=[None, None]) + my_session2 = requests.Session() + httpclient._set_authentication(my_session2) + assert(my_session2.adapters['https://']._principal == None) + assert(my_session2.adapters['https://']._password == None) + assert(isinstance(my_session2.adapters['https://'], client.HTTPAdapterWithProxyKerberosAuth)) def test_telemetry(self, mocker): telemetry_storage = InMemoryTelemetryStorage() @@ -193,7 +259,7 @@ def test_telemetry(self, mocker): response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock - mocker.patch('splitio.api.client.requests.post', new=get_mock) + mocker.patch('splitio.api.client.requests.Session.post', new=get_mock) httpclient = client.HttpClient(sdk_url='https://sdk.com', events_url='https://events.com') httpclient.set_telemetry_data("metric", telemetry_runtime_producer) @@ -231,7 +297,7 @@ def record_sync_error(metric_name, elapsed): assert (self.status == 400) # testing get call - mocker.patch('splitio.api.client.requests.get', new=get_mock) + mocker.patch('splitio.api.client.requests.Session.get', new=get_mock) self.metric1 = None self.cur_time = 0 self.metric2 = None diff --git a/tests/client/test_config.py b/tests/client/test_config.py index ddfd85b0..028736b3 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -76,8 +76,11 @@ def test_sanitize(self): processed = config.sanitize('some', {'storageType': 'pluggable', 'flagSetsFilter': ['set']}) assert processed['flagSetsFilter'] is None - processed = config.sanitize('some', {'httpAuthenticateScheme': 'KERBEROS'}) - assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.KERBEROS + processed = config.sanitize('some', {'httpAuthenticateScheme': 'KERBEROS_spnego'}) + assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.KERBEROS_SPNEGO + + processed = config.sanitize('some', {'httpAuthenticateScheme': 'kerberos_proxy'}) + assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.KERBEROS_PROXY processed = config.sanitize('some', {'httpAuthenticateScheme': 'anything'}) assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.NONE