Skip to content

Commit

Permalink
Merge pull request #15 from twosigmajab/consume-request
Browse files Browse the repository at this point in the history
Limit num of bytes read in _consume_request.
  • Loading branch information
propertone authored Dec 25, 2020
2 parents 4ea5cdd + 013f725 commit 714ba41
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 15 deletions.
21 changes: 19 additions & 2 deletions test_wsgi_kerberos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from wsgi_kerberos import KerberosAuthMiddleware, ensure_bytestring
from webtest import TestApp
from wsgi_kerberos import KerberosAuthMiddleware, ensure_bytestring, _DEFAULT_READ_MAX
from webtest import TestApp, TestRequest
import kerberos
import mock
import unittest
Expand Down Expand Up @@ -120,6 +120,23 @@ def test_unauthorized(self):
self.assertEqual(r.headers['www-authenticate'], 'Negotiate')
self.assertEqual(r.headers['content-type'], 'text/plain')

def test_read_max_on_auth_fail(self):
'''
KerberosAuthMiddleware's ``read_max_on_auth_fail`` should allow
customizing reading of request bodies of unauthenticated requests.
'''
body = b'body of unauthenticated request'
for read_max in (0, 5, 100, _DEFAULT_READ_MAX, float('inf')):
# When we drop Py2, we can use `with self.subTest(read_max=read_max):` here.
app = TestApp(KerberosAuthMiddleware(index, read_max_on_auth_fail=read_max))
req = TestRequest.blank('/', method='POST', body=body)
resp = app.do_request(req, status=401)
if read_max < len(body):
expect_read = 0
else:
expect_read = min(read_max, len(body))
self.assertEqual(req.body_file.input.tell(), expect_read)

def test_unauthorized_when_missing_negotiate(self):
'''
Ensure that when the client sends an Authorization header that does
Expand Down
57 changes: 44 additions & 13 deletions wsgi_kerberos.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ def ensure_bytestring(s):
return s.encode('utf-8') if isinstance(s, unicode) else s


def _consume_request(environ):
# 10 << 20 (10 MB) is the maxFormSize value that go uses. Ref:
# https://github.com/golang/go/blob/bc7e4d9/src/net/http/request.go#L1204
_DEFAULT_READ_MAX = 10 << 20


def _consume_request(environ, read_max):
'''
Consume and discard all of the data on the request.
Consume and discard up to *read_max* bytes of the request.
This avoids problems that some clients have when they get an unexpected
and premature close from the server.
Expand All @@ -40,25 +45,41 @@ def _consume_request(environ):
defending itself against denial-of-service attacks, or from badly broken
client implementations.
'''
if read_max == 0: # Short-circuit early when user opts out of this.
return
if environ.get("HTTP_EXPECT") == "100-continue":
return
try:
sock = environ.get('wsgi.input')
if hasattr(sock, 'closed') and sock.closed:
return
# Figure out how much content is available for us to consume.
expected = int(environ.get('CONTENT_LENGTH', '0'))

content_length = environ.get('CONTENT_LENGTH', '')
if not content_length:
remaining = read_max
else:
content_length = int(content_length)
if content_length > read_max:
# User is not willing to read such a large request body, but
# reading anything less does not help naively-written clients.
return
remaining = content_length

# Try to receive all of the data. Keep retrying until we get an error
# which indicates that we can't retry. Eat errors. The client will just
# have to deal with a possible Broken Pipe -- we tried.
received = 0
while received < expected:
while remaining > 0:
try:
received += len(sock.read(expected - received))
delta = len(sock.read(remaining))
except socket.error as err:
if err.errno != errno.EAGAIN:
break
except (KeyError, ValueError):
pass
else:
if delta == 0:
break
remaining -= delta
except Exception as exc:
LOG.debug("_consume_request suppressed: %s", exc)


class KerberosAuthMiddleware(object):
Expand All @@ -77,10 +98,19 @@ class KerberosAuthMiddleware(object):
:param auth_required_callback: predicate accepting the WSGI environ
for a request returning whether the request should be authenticated
:type auth_required_callback: callable
:param read_max_on_auth_fail: When a request could not be authenticated,
read and discard up to this many bytes of the request. This may help
naively- written clients that send large request bodies which they
expect to be consumed before first confirming that the request was
authenticated successfully. Pass 0 to disable this if you don't want to
waste resources to potentially accommodate such clients. Pass math.inf
to read an unlimited number of bytes. Beware that the more the server
is willing to read, the more vulnerable it becomes to denial-of-service
attacks.
:type read_max_on_auth_fail: int
'''

def __init__(self, app, hostname='', unauthorized=None, forbidden=None,
auth_required_callback=None):
auth_required_callback=None, read_max_on_auth_fail=_DEFAULT_READ_MAX):
if hostname:
self._check_hostname(hostname)
self.service = 'HTTP@%s' % hostname
Expand All @@ -105,6 +135,7 @@ def __init__(self, app, hostname='', unauthorized=None, forbidden=None,
self.unauthorized = unauthorized # 401 response text/content-type
self.forbidden = forbidden # 403 response text/content-type
self.auth_required_callback = auth_required_callback
self.read_max_on_auth_fail = read_max_on_auth_fail

@staticmethod
def _check_hostname(hostname):
Expand All @@ -124,7 +155,7 @@ def _unauthorized(self, environ, start_response, token=None):
headers.append(('WWW-Authenticate', token))
else:
headers.append(('WWW-Authenticate', 'Negotiate'))
_consume_request(environ)
_consume_request(environ, self.read_max_on_auth_fail)
start_response('401 Unauthorized', headers)
return [self.unauthorized[0]]

Expand All @@ -133,7 +164,7 @@ def _forbidden(self, environ, start_response):
Send a 403 Forbidden response
'''
headers = [('content-type', self.forbidden[1])]
_consume_request(environ)
_consume_request(environ, self.read_max_on_auth_fail)
start_response('403 Forbidden', headers)
return [self.forbidden[0]]

Expand Down

0 comments on commit 714ba41

Please sign in to comment.