-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9d9391a
commit 53ad916
Showing
2 changed files
with
288 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import os | ||
from abc import ABC | ||
from datetime import datetime, timezone, timedelta | ||
from typing import Any, Dict, Optional | ||
|
||
import jwt | ||
|
||
from src.exceptions import BadJWTError | ||
from src.model import User | ||
|
||
PRIVATE_KEY = os.environ.get("PRIVATE_KEY", "shh") | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Token(ABC): | ||
|
||
def verify(self): | ||
try: | ||
self._payload = jwt.decode( | ||
self.jwt, | ||
PRIVATE_KEY, | ||
algorithms=["HS256"], | ||
options={"verify_signature": True, "require": ["exp"], "verify_exp": True}, | ||
) | ||
return | ||
except jwt.InvalidSignatureError: | ||
logger.warning("token has bad signature - %s", self.jwt) | ||
except jwt.ExpiredSignatureError: | ||
logger.warning("token signature is expired - %s", self.jwt) | ||
except jwt.InvalidTokenError: | ||
logger.warning("Issue decoding token - %s", self.jwt) | ||
except Exception as e: | ||
logger.exception("Oh Dear", e) | ||
raise BadJWTError("oh dear") | ||
|
||
def _encode(self) -> None: | ||
bytes_key = bytes(PRIVATE_KEY, encoding="utf8") | ||
|
||
self.jwt = jwt.encode(self._payload, bytes_key, algorithm="HS256") | ||
|
||
|
||
class AccessToken(Token): | ||
def __init__(self, jwt_token: Optional[str] = None, payload: Optional[Dict[str, Any]] = None) -> None: | ||
if payload and not jwt_token: | ||
self._payload = payload | ||
self._payload["exp"] = datetime.now(timezone.utc) + timedelta(minutes=1) | ||
self._encode() | ||
elif jwt_token and not payload: | ||
try: | ||
self._payload = jwt.decode( | ||
jwt_token, | ||
PRIVATE_KEY, | ||
algorithms=["HS256"], | ||
options={"verify_signature": True, "require": ["exp"], "verify_exp": True}, | ||
) | ||
self.jwt = jwt_token | ||
except jwt.DecodeError as e: | ||
raise BadJWTError("Token could not be decoded") from e | ||
else: | ||
raise BadJWTError("Access token creation requires jwt_token string XOR a payload") | ||
|
||
def refresh(self): | ||
self.verify() | ||
self._payload["exp"] = datetime.now(timezone.utc) + timedelta(minutes=1) | ||
self._encode() | ||
|
||
|
||
class RefreshToken(Token): | ||
def __init__(self, jwt_token: Optional[str] = None) -> None: | ||
|
||
if jwt_token is None: | ||
self._payload = {"exp": datetime.now(timezone.utc) + timedelta(hours=12)} | ||
self._encode() | ||
else: | ||
self.jwt = jwt_token | ||
try: | ||
self._payload = jwt.decode( | ||
self.jwt, | ||
PRIVATE_KEY, | ||
algorithms=["HS256"], | ||
options={"verify_signature": True, "require": ["exp"], "verify_exp": True}, | ||
) | ||
except jwt.DecodeError as e: | ||
raise BadJWTError("Badly formed JWT given") from e | ||
except jwt.ExpiredSignatureError as e: | ||
raise BadJWTError("Token signature has expired") from e | ||
except Exception: | ||
raise BadJWTError("Problem decoding JWT") | ||
|
||
|
||
def generate_access_token(user: User) -> AccessToken: | ||
payload = {"usernumber": user.user_number, "role": user.role.value, "username": "foo"} | ||
return AccessToken(payload=payload) | ||
|
||
|
||
def load_access_token(token: str) -> AccessToken: | ||
return AccessToken(jwt_token=token) | ||
|
||
|
||
def load_refresh_token(token: str) -> RefreshToken: | ||
if token is None: | ||
raise BadJWTError("Token is None") | ||
return RefreshToken(jwt_token=token) | ||
|
||
|
||
def generate_refresh_token() -> RefreshToken: | ||
return RefreshToken() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from datetime import datetime, timezone, timedelta | ||
from unittest.mock import patch | ||
|
||
import jwt | ||
import pytest | ||
|
||
from src.exceptions import BadJWTError | ||
from src.model import User | ||
from src.tokens import Token, AccessToken, RefreshToken, generate_access_token | ||
|
||
|
||
@patch("jwt.decode") | ||
@patch("src.tokens.logger") | ||
def test_verify_success(mock_logger, mock_decode): | ||
token_instance = Token() | ||
token_instance.jwt = "valid_jwt_token" | ||
mock_decode.return_value = {"some": "payload"} | ||
|
||
token_instance.verify() | ||
|
||
mock_decode.assert_called_once_with( | ||
"valid_jwt_token", | ||
"shh", | ||
algorithms=["HS256"], | ||
options={"verify_signature": True, "require": ["exp"], "verify_exp": True}, | ||
) | ||
mock_logger.warning.assert_not_called() | ||
|
||
|
||
@patch("jwt.decode") | ||
@patch("src.tokens.logger") | ||
def test_verify_invalid_signature(mock_logger, mock_decode): | ||
token_instance = Token() | ||
token_instance.jwt = "bad_signature_jwt" | ||
mock_decode.side_effect = jwt.InvalidSignatureError() | ||
|
||
with pytest.raises(BadJWTError): | ||
token_instance.verify() | ||
mock_logger.warning.assert_called_once_with("token has bad signature - %s", "bad_signature_jwt") | ||
|
||
|
||
@patch("jwt.decode") | ||
@patch("src.tokens.logger") | ||
def test_verify_expired_signature(mock_logger, mock_decode): | ||
token_instance = Token() | ||
token_instance.jwt = "expired_jwt_token" | ||
mock_decode.side_effect = jwt.ExpiredSignatureError() | ||
|
||
with pytest.raises(BadJWTError): | ||
token_instance.verify() | ||
mock_logger.warning.assert_called_once_with("token signature is expired - %s", "expired_jwt_token") | ||
|
||
|
||
@patch("jwt.decode") | ||
@patch("src.tokens.logger") | ||
def test_verify_invalid_token(mock_logger, mock_decode): | ||
token_instance = Token() | ||
token_instance.jwt = "invalid_jwt_token" | ||
mock_decode.side_effect = jwt.InvalidTokenError() | ||
|
||
with pytest.raises(BadJWTError): | ||
token_instance.verify() | ||
mock_logger.warning.assert_called_once_with("Issue decoding token - %s", "invalid_jwt_token") | ||
|
||
|
||
@patch("jwt.decode") | ||
@patch("src.tokens.logger") | ||
def test_verify_general_exception(mock_logger, mock_decode): | ||
token_instance = Token() | ||
token_instance.jwt = "jwt_with_general_issue" | ||
exception = Exception("Unexpected error") | ||
mock_decode.side_effect = exception | ||
|
||
with pytest.raises(BadJWTError): | ||
token_instance.verify() | ||
mock_logger.exception.assert_called_once_with("Oh Dear", exception) | ||
|
||
|
||
@patch("src.tokens.datetime") | ||
@patch("jwt.encode") | ||
@patch("jwt.decode") | ||
def test_access_token_with_payload(_, mock_encode, mock_datetime): | ||
fixed_time = datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc) | ||
mock_datetime.now.return_value = fixed_time | ||
payload = {"user": "test_user"} | ||
|
||
AccessToken(payload=payload) | ||
|
||
mock_encode.assert_called_once_with( | ||
{"user": "test_user", "exp": fixed_time + timedelta(minutes=1)}, | ||
b"shh", | ||
algorithm="HS256", | ||
) | ||
|
||
|
||
@patch("jwt.decode") | ||
def test_access_token_with_jwt_token(mock_decode): | ||
jwt_token = "encoded.jwt.token" | ||
expected_payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=1)} | ||
mock_decode.return_value = expected_payload | ||
|
||
token = AccessToken(jwt_token=jwt_token) | ||
|
||
assert token.jwt == jwt_token | ||
assert token._payload == expected_payload | ||
mock_decode.assert_called_once_with( | ||
jwt_token, | ||
"shh", | ||
algorithms=["HS256"], | ||
options={"verify_signature": True, "require": ["exp"], "verify_exp": True}, | ||
) | ||
|
||
|
||
def test_access_token_with_both_none(): | ||
with pytest.raises(BadJWTError): | ||
AccessToken() | ||
|
||
|
||
@patch("jwt.encode") | ||
@patch("jwt.decode") | ||
def test_access_token_refresh(mock_decode, mock_encode): | ||
jwt_token = "valid.jwt.token" | ||
mock_decode.return_value = {"user": "test_user", "exp": datetime.now(timezone.utc)} | ||
token = AccessToken(jwt_token=jwt_token) | ||
|
||
token.refresh() | ||
|
||
args, kwargs = mock_encode.call_args | ||
assert args[0]["exp"] > datetime.now(timezone.utc) # checks if the expiration time is extended | ||
|
||
|
||
@patch("src.tokens.datetime") | ||
@patch("jwt.encode") | ||
@patch("jwt.decode") | ||
def test_refresh_token_creation_no_jwt(_, mock_encode, mock_datetime): | ||
fixed_time = datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc) | ||
mock_datetime.now.return_value = fixed_time | ||
|
||
RefreshToken() | ||
|
||
mock_encode.assert_called_once_with( | ||
{"exp": fixed_time + timedelta(hours=12)}, | ||
b"shh", | ||
algorithm="HS256", | ||
) | ||
|
||
|
||
@patch("jwt.decode") | ||
def test_refresh_token_creation_with_jwt(mock_decode): | ||
jwt_token = "encoded.jwt.token" | ||
expected_payload = {"exp": datetime.now(timezone.utc) + timedelta(hours=12)} | ||
mock_decode.return_value = expected_payload | ||
|
||
token = RefreshToken(jwt_token=jwt_token) | ||
|
||
assert token._payload == expected_payload | ||
mock_decode.assert_called_once_with( | ||
jwt_token, | ||
"shh", | ||
algorithms=["HS256"], | ||
options={"verify_signature": True, "require": ["exp"], "verify_exp": True}, | ||
) | ||
|
||
|
||
@patch("jwt.decode", side_effect=BadJWTError) | ||
def test_refresh_token_with_invalid_jwt(_): | ||
with pytest.raises(BadJWTError): | ||
RefreshToken(jwt_token="invalid.jwt.token") | ||
|
||
|
||
def test_generate_access_token(): | ||
user = User(user_number=12345) | ||
|
||
access_token = generate_access_token(user) | ||
|
||
expected_payload = {"usernumber": 12345, "role": "user", "username": "foo"} | ||
|
||
assert access_token._payload == expected_payload |