Skip to content

Commit

Permalink
handle Host header that includes port
Browse files Browse the repository at this point in the history
  • Loading branch information
dsimms committed Mar 18, 2020
1 parent 207ca64 commit 083d1ae
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
3 changes: 3 additions & 0 deletions requests_toolbelt/adapters/host_header_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def send(self, request, **kwargs):
connection_pool_kwargs = self.poolmanager.connection_pool_kw

if host_header:
# host header can include port, but we should not include it in the assert hostname
host_header = host_header.split(':')[0]

connection_pool_kwargs["assert_hostname"] = host_header
elif "assert_hostname" in connection_pool_kwargs:
# an assert_hostname from a previous request may have been left
Expand Down
25 changes: 16 additions & 9 deletions tests/test_host_header_ssl_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
@pytest.fixture
def session():
"""Create a session with our adapter mounted."""
session = requests.Session()
session.mount('https://', hhssl.HostHeaderSSLAdapter())
s = requests.Session()
s.mount('https://', hhssl.HostHeaderSSLAdapter())
return s


# Let's not spam example.org:
@pytest.mark.skip
class TestHostHeaderSSLAdapter(object):
"""Tests for our HostHeaderSNIAdapter."""
Expand All @@ -30,14 +32,19 @@ def test_ssladapter(self, session):
headers={'Host': 'example.com'})
assert r.status_code == 200

def test_stream(self):
self.session.get('https://54.175.219.8/stream/20',
headers={'Host': 'httpbin.org'},
stream=True)
def test_stream(self, session):
session.get('https://54.175.219.8/stream/20',
headers={'Host': 'httpbin.org'},
stream=True)

def test_case_insensitive_header(self):
r = self.session.get('https://93.184.216.34',
headers={'hOSt': 'example.org'})
def test_case_insensitive_header(self, session):
r = session.get('https://93.184.216.34',
headers={'hOSt': 'example.org'})
assert r.status_code == 200

def test_case_header_with_port(self, session):
r = session.get('https://93.184.216.34',
headers={'Host': 'example.org:443'})
assert r.status_code == 200

def test_plain_requests(self):
Expand Down

0 comments on commit 083d1ae

Please sign in to comment.