diff --git a/requests_toolbelt/adapters/host_header_ssl.py b/requests_toolbelt/adapters/host_header_ssl.py index f34ed1aa..8515c635 100644 --- a/requests_toolbelt/adapters/host_header_ssl.py +++ b/requests_toolbelt/adapters/host_header_ssl.py @@ -35,6 +35,10 @@ def send(self, request, **kwargs): connection_pool_kwargs = self.poolmanager.connection_pool_kw if host_header: + # host header can include port, but we should not include it in the + # assert_hostname + host_header = host_header.split(':')[0] + connection_pool_kwargs["assert_hostname"] = host_header elif "assert_hostname" in connection_pool_kwargs: # an assert_hostname from a previous request may have been left diff --git a/tests/test_host_header_ssl_adapter.py b/tests/test_host_header_ssl_adapter.py index d86378e6..d5795ecc 100644 --- a/tests/test_host_header_ssl_adapter.py +++ b/tests/test_host_header_ssl_adapter.py @@ -7,10 +7,12 @@ @pytest.fixture def session(): """Create a session with our adapter mounted.""" - session = requests.Session() - session.mount('https://', hhssl.HostHeaderSSLAdapter()) + s = requests.Session() + s.mount('https://', hhssl.HostHeaderSSLAdapter()) + return s +# Let's not spam example.org: @pytest.mark.skip class TestHostHeaderSSLAdapter(object): """Tests for our HostHeaderSNIAdapter.""" @@ -30,14 +32,19 @@ def test_ssladapter(self, session): headers={'Host': 'example.com'}) assert r.status_code == 200 - def test_stream(self): - self.session.get('https://54.175.219.8/stream/20', - headers={'Host': 'httpbin.org'}, - stream=True) + def test_stream(self, session): + session.get('https://54.175.219.8/stream/20', + headers={'Host': 'httpbin.org'}, + stream=True) - def test_case_insensitive_header(self): - r = self.session.get('https://93.184.216.34', - headers={'hOSt': 'example.org'}) + def test_case_insensitive_header(self, session): + r = session.get('https://93.184.216.34', + headers={'hOSt': 'example.org'}) + assert r.status_code == 200 + + def test_case_header_with_port(self, session): + r = session.get('https://93.184.216.34', + headers={'Host': 'example.org:443'}) assert r.status_code == 200 def test_plain_requests(self):