Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise HTTPError in case of HTTP 5XX responses. #441

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions requests_oauthlib/oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from oauthlib.oauth2 import LegacyApplicationClient
from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
import requests
from requests.exceptions import HTTPError

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -230,9 +231,9 @@ def fetch_token(
`auth` tuple. If the value is `None`, it will be
omitted from the request, however if the value is
an empty string, an empty string will be sent.
:param cert: Client certificate to send for OAuth 2.0 Mutual-TLS Client
Authentication (draft-ietf-oauth-mtls). Can either be the
path of a file containing the private key and certificate or
:param cert: Client certificate to send for OAuth 2.0 Mutual-TLS Client
Authentication (draft-ietf-oauth-mtls). Can either be the
path of a file containing the private key and certificate or
a tuple of two filenames for certificate and key.
:param kwargs: Extra parameters to include in the token request.
:return: A token dict
Expand Down Expand Up @@ -363,6 +364,8 @@ def fetch_token(
log.debug("Invoking hook %s.", hook)
r = hook(r)

self._raise_for_5xx(response=r)

self._client.parse_request_body_response(r.text, scope=self.scope)
self.token = self._client.token
log.debug("Obtained token %s.", self.token)
Expand Down Expand Up @@ -449,6 +452,8 @@ def refresh_token(
log.debug("Invoking hook %s.", hook)
r = hook(r)

self._raise_for_5xx(response=r)

self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
if not "refresh_token" in self.token:
log.debug("No new refresh token given. Re-using old.")
Expand Down Expand Up @@ -538,3 +543,40 @@ def register_compliance_hook(self, hook_type, hook):
"Hook type %s is not in %s.", hook_type, self.compliance_hook
)
self.compliance_hook[hook_type].add(hook)

def _raise_for_5xx(self, response):
# type: (requests.models.Response) -> None
"""
Raise requests.HTTPError if response is an HTTP 5XX error.

Just like the existing Response.raise_for_status() but ignores 4XX
errors.

:param response: HTTP response object from requests
Raises :class:`requests.exceptions.HTTPError`, if a 5XX error occurred.
"""
http_error_msg = ""
if isinstance(response.reason, bytes):
# We attempt to decode utf-8 first because some servers
# choose to localize their reason strings. If the string
# isn't utf-8, we fall back to iso-8859-1 for all other
# encodings. (See psf/requests PR #3538)
try:
reason = response.reason.decode("utf-8")
except UnicodeDecodeError:
reason = response.reason.decode("iso-8859-1")
else:
reason = response.reason
Comment on lines +559 to +569
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You seem to only need the reason in the elif 500 <= response.status_code < 600 block below, so maybe worth moving it in there? That way the behavior doesn't change for any other response codes (at the moment you try decode the reason even if ultimately you might not need it, thus introducing potential risk of breakage for no reason.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is from requests.Response.raise_for_status(), and I chose to make as few changes as possible since the only desired difference in behavior is to ignore 4XX errors.


if 400 <= response.status_code < 500:
pass # ignored

elif 500 <= response.status_code < 600:
http_error_msg = "%s Server Error: %s for url: %s" % (
response.status_code,
reason,
response.url,
)

if http_error_msg:
raise HTTPError(http_error_msg, response=response)
Comment on lines +581 to +582
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You only define http_error_msg in the elif 500 <= response.status_code < 600 block above, so maybe worth moving it up there (and then not explicitly checking for it nor assigning it to a variable and raising directly)?

67 changes: 65 additions & 2 deletions tests/test_oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@
import requests

from requests.auth import _basic_auth_str
from requests.exceptions import HTTPError


fake_time = time.time()
CODE = "asdf345xdf"


def fake_token(token):
def fake_token(token, status_code=200):
def fake_send(r, **kwargs):
resp = mock.MagicMock()
resp.status_code = status_code
resp.text = json.dumps(token)
return resp

Expand Down Expand Up @@ -70,6 +72,7 @@ def verifier(r, **kwargs):
auth_header = r.headers.get(str("Authorization"), None)
self.assertEqual(auth_header, token)
resp = mock.MagicMock()
resp.status_code = 200
resp.cookes = []
return resp

Expand All @@ -89,6 +92,7 @@ def verifier(r, **kwargs):
self.assertEqual(cert, kwargs["cert"])
self.assertIn("client_id=" + self.client_id, r.body)
resp = mock.MagicMock()
resp.status_code = 200
resp.text = json.dumps(self.token)
return resp

Expand Down Expand Up @@ -130,10 +134,11 @@ def test_refresh_token_request(self):
self.expired_token["expires_in"] = "-1"
del self.expired_token["expires_at"]

def fake_refresh(r, **kwargs):
def fake_refresh(r, status_code=200, **kwargs):
if "/refresh" in r.url:
self.assertNotIn("Authorization", r.headers)
resp = mock.MagicMock()
resp.status_code = status_code
resp.text = json.dumps(self.token)
return resp

Expand Down Expand Up @@ -166,6 +171,19 @@ def token_updater(token):
sess.send = fake_refresh
sess.get("https://i.b")

# test 5xx error handler
for client in self.clients:
sess = OAuth2Session(
client=client,
token=self.expired_token,
auto_refresh_url="https://i.b/refresh",
token_updater=token_updater,
)
sess.send = lambda r, **kwargs: fake_refresh(
r=r, status_code=503, kwargs=kwargs
)
self.assertRaises(HTTPError, sess.get, "https://i.b")

def fake_refresh_with_auth(r, **kwargs):
if "/refresh" in r.url:
self.assertIn("Authorization", r.headers)
Expand All @@ -177,6 +195,7 @@ def fake_refresh_with_auth(r, **kwargs):
content = "Basic {encoded}".format(encoded=encoded.decode("latin1"))
self.assertEqual(r.headers["Authorization"], content)
resp = mock.MagicMock()
resp.status_code = 200
resp.text = json.dumps(self.token)
return resp

Expand Down Expand Up @@ -251,6 +270,23 @@ def test_fetch_token(self):
else:
self.assertRaises(OAuth2Error, sess.fetch_token, url)

# test 5xx error responses
error = {"error": "server error!"}
for client in self.clients:
sess = OAuth2Session(client=client, token=self.token)
sess.send = fake_token(error, status_code=500)
if isinstance(client, LegacyApplicationClient):
# this client requires a username+password
self.assertRaises(
HTTPError,
sess.fetch_token,
url,
username="username1",
password="password1",
)
else:
self.assertRaises(HTTPError, sess.fetch_token, url)

# there are different scenarios in which the `client_id` can be specified
# reference `oauthlib.tests.oauth2.rfc6749.clients.test_web_application.WebApplicationClientTest.test_prepare_request_body`
# this only needs to test WebApplicationClient
Expand All @@ -263,6 +299,7 @@ def test_fetch_token(self):
def fake_token_history(token):
def fake_send(r, **kwargs):
resp = mock.MagicMock()
resp.status_code = 200
resp.text = json.dumps(token)
_fetch_history.append(
(r.url, r.body, r.headers.get("Authorization", None))
Expand Down Expand Up @@ -470,6 +507,7 @@ def test_authorized_true(self):
def fake_token(token):
def fake_send(r, **kwargs):
resp = mock.MagicMock()
resp.status_code = 200
resp.text = json.dumps(token)
return resp

Expand Down Expand Up @@ -497,6 +535,31 @@ def fake_send(r, **kwargs):
sess.fetch_token(url)
self.assertTrue(sess.authorized)

def test_raise_for_5xx(self):
for reason_bytes in [
b"\xa1An error occurred!", # iso-8859-i
b"\xc2\xa1An error occurred!", # utf-8
]:
fake_resp = mock.MagicMock()
fake_resp.status_code = 504
fake_resp.reason = reason_bytes
reason_unicode = "\u00A1An error occurred!"
fake_resp.url = "https://example.com/token"
expected = (
"504 Server Error: " + reason_unicode + " for url: " + fake_resp.url
)

# Make sure our expected unicode string literal is indeed unicode
# in both py2 and py3
self.assertEqual(reason_unicode[0].encode("utf-8"), b"\xc2\xa1")

sess = OAuth2Session("test-id")

with self.assertRaises(HTTPError) as cm:
sess._raise_for_5xx(fake_resp)

self.assertEqual(cm.exception.args[0], expected)


class OAuth2SessionNetrcTest(OAuth2SessionTest):
"""Ensure that there is no magic auth handling.
Expand Down