Skip to content

Commit

Permalink
Add tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
keiranjprice101 committed Jun 10, 2024
1 parent 9d9391a commit 53ad916
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 0 deletions.
110 changes: 110 additions & 0 deletions src/tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

import logging
import os
from abc import ABC
from datetime import datetime, timezone, timedelta
from typing import Any, Dict, Optional

import jwt

from src.exceptions import BadJWTError
from src.model import User

PRIVATE_KEY = os.environ.get("PRIVATE_KEY", "shh")
logger = logging.getLogger(__name__)


class Token(ABC):

def verify(self):
try:
self._payload = jwt.decode(
self.jwt,
PRIVATE_KEY,
algorithms=["HS256"],
options={"verify_signature": True, "require": ["exp"], "verify_exp": True},
)
return
except jwt.InvalidSignatureError:
logger.warning("token has bad signature - %s", self.jwt)
except jwt.ExpiredSignatureError:
logger.warning("token signature is expired - %s", self.jwt)
except jwt.InvalidTokenError:
logger.warning("Issue decoding token - %s", self.jwt)
except Exception as e:
logger.exception("Oh Dear", e)
raise BadJWTError("oh dear")

def _encode(self) -> None:
bytes_key = bytes(PRIVATE_KEY, encoding="utf8")

self.jwt = jwt.encode(self._payload, bytes_key, algorithm="HS256")


class AccessToken(Token):
def __init__(self, jwt_token: Optional[str] = None, payload: Optional[Dict[str, Any]] = None) -> None:
if payload and not jwt_token:
self._payload = payload
self._payload["exp"] = datetime.now(timezone.utc) + timedelta(minutes=1)
self._encode()
elif jwt_token and not payload:
try:
self._payload = jwt.decode(
jwt_token,
PRIVATE_KEY,
algorithms=["HS256"],
options={"verify_signature": True, "require": ["exp"], "verify_exp": True},
)
self.jwt = jwt_token
except jwt.DecodeError as e:
raise BadJWTError("Token could not be decoded") from e
else:
raise BadJWTError("Access token creation requires jwt_token string XOR a payload")

def refresh(self):
self.verify()
self._payload["exp"] = datetime.now(timezone.utc) + timedelta(minutes=1)
self._encode()


class RefreshToken(Token):
def __init__(self, jwt_token: Optional[str] = None) -> None:

if jwt_token is None:
self._payload = {"exp": datetime.now(timezone.utc) + timedelta(hours=12)}
self._encode()
else:
self.jwt = jwt_token
try:
self._payload = jwt.decode(
self.jwt,
PRIVATE_KEY,
algorithms=["HS256"],
options={"verify_signature": True, "require": ["exp"], "verify_exp": True},
)
except jwt.DecodeError as e:
raise BadJWTError("Badly formed JWT given") from e
except jwt.ExpiredSignatureError as e:
raise BadJWTError("Token signature has expired") from e
except Exception:
raise BadJWTError("Problem decoding JWT")


def generate_access_token(user: User) -> AccessToken:
payload = {"usernumber": user.user_number, "role": user.role.value, "username": "foo"}
return AccessToken(payload=payload)


def load_access_token(token: str) -> AccessToken:
return AccessToken(jwt_token=token)


def load_refresh_token(token: str) -> RefreshToken:
if token is None:
raise BadJWTError("Token is None")
return RefreshToken(jwt_token=token)


def generate_refresh_token() -> RefreshToken:
return RefreshToken()
178 changes: 178 additions & 0 deletions test/test_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from datetime import datetime, timezone, timedelta
from unittest.mock import patch

import jwt
import pytest

from src.exceptions import BadJWTError
from src.model import User
from src.tokens import Token, AccessToken, RefreshToken, generate_access_token


@patch("jwt.decode")
@patch("src.tokens.logger")
def test_verify_success(mock_logger, mock_decode):
token_instance = Token()
token_instance.jwt = "valid_jwt_token"
mock_decode.return_value = {"some": "payload"}

token_instance.verify()

mock_decode.assert_called_once_with(
"valid_jwt_token",
"shh",
algorithms=["HS256"],
options={"verify_signature": True, "require": ["exp"], "verify_exp": True},
)
mock_logger.warning.assert_not_called()


@patch("jwt.decode")
@patch("src.tokens.logger")
def test_verify_invalid_signature(mock_logger, mock_decode):
token_instance = Token()
token_instance.jwt = "bad_signature_jwt"
mock_decode.side_effect = jwt.InvalidSignatureError()

with pytest.raises(BadJWTError):
token_instance.verify()
mock_logger.warning.assert_called_once_with("token has bad signature - %s", "bad_signature_jwt")


@patch("jwt.decode")
@patch("src.tokens.logger")
def test_verify_expired_signature(mock_logger, mock_decode):
token_instance = Token()
token_instance.jwt = "expired_jwt_token"
mock_decode.side_effect = jwt.ExpiredSignatureError()

with pytest.raises(BadJWTError):
token_instance.verify()
mock_logger.warning.assert_called_once_with("token signature is expired - %s", "expired_jwt_token")


@patch("jwt.decode")
@patch("src.tokens.logger")
def test_verify_invalid_token(mock_logger, mock_decode):
token_instance = Token()
token_instance.jwt = "invalid_jwt_token"
mock_decode.side_effect = jwt.InvalidTokenError()

with pytest.raises(BadJWTError):
token_instance.verify()
mock_logger.warning.assert_called_once_with("Issue decoding token - %s", "invalid_jwt_token")


@patch("jwt.decode")
@patch("src.tokens.logger")
def test_verify_general_exception(mock_logger, mock_decode):
token_instance = Token()
token_instance.jwt = "jwt_with_general_issue"
exception = Exception("Unexpected error")
mock_decode.side_effect = exception

with pytest.raises(BadJWTError):
token_instance.verify()
mock_logger.exception.assert_called_once_with("Oh Dear", exception)


@patch("src.tokens.datetime")
@patch("jwt.encode")
@patch("jwt.decode")
def test_access_token_with_payload(_, mock_encode, mock_datetime):
fixed_time = datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
mock_datetime.now.return_value = fixed_time
payload = {"user": "test_user"}

AccessToken(payload=payload)

mock_encode.assert_called_once_with(
{"user": "test_user", "exp": fixed_time + timedelta(minutes=1)},
b"shh",
algorithm="HS256",
)


@patch("jwt.decode")
def test_access_token_with_jwt_token(mock_decode):
jwt_token = "encoded.jwt.token"
expected_payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=1)}
mock_decode.return_value = expected_payload

token = AccessToken(jwt_token=jwt_token)

assert token.jwt == jwt_token
assert token._payload == expected_payload
mock_decode.assert_called_once_with(
jwt_token,
"shh",
algorithms=["HS256"],
options={"verify_signature": True, "require": ["exp"], "verify_exp": True},
)


def test_access_token_with_both_none():
with pytest.raises(BadJWTError):
AccessToken()


@patch("jwt.encode")
@patch("jwt.decode")
def test_access_token_refresh(mock_decode, mock_encode):
jwt_token = "valid.jwt.token"
mock_decode.return_value = {"user": "test_user", "exp": datetime.now(timezone.utc)}
token = AccessToken(jwt_token=jwt_token)

token.refresh()

args, kwargs = mock_encode.call_args
assert args[0]["exp"] > datetime.now(timezone.utc) # checks if the expiration time is extended


@patch("src.tokens.datetime")
@patch("jwt.encode")
@patch("jwt.decode")
def test_refresh_token_creation_no_jwt(_, mock_encode, mock_datetime):
fixed_time = datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
mock_datetime.now.return_value = fixed_time

RefreshToken()

mock_encode.assert_called_once_with(
{"exp": fixed_time + timedelta(hours=12)},
b"shh",
algorithm="HS256",
)


@patch("jwt.decode")
def test_refresh_token_creation_with_jwt(mock_decode):
jwt_token = "encoded.jwt.token"
expected_payload = {"exp": datetime.now(timezone.utc) + timedelta(hours=12)}
mock_decode.return_value = expected_payload

token = RefreshToken(jwt_token=jwt_token)

assert token._payload == expected_payload
mock_decode.assert_called_once_with(
jwt_token,
"shh",
algorithms=["HS256"],
options={"verify_signature": True, "require": ["exp"], "verify_exp": True},
)


@patch("jwt.decode", side_effect=BadJWTError)
def test_refresh_token_with_invalid_jwt(_):
with pytest.raises(BadJWTError):
RefreshToken(jwt_token="invalid.jwt.token")


def test_generate_access_token():
user = User(user_number=12345)

access_token = generate_access_token(user)

expected_payload = {"usernumber": 12345, "role": "user", "username": "foo"}

assert access_token._payload == expected_payload

0 comments on commit 53ad916

Please sign in to comment.