Skip to content

Commit

Permalink
Add endpoint for JWT refresh tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
stveit committed Feb 21, 2025
1 parent 63b9bc5 commit d1bfbc3
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 0 deletions.
1 change: 1 addition & 0 deletions changelog.d/3270.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add endpoint for JWT refresh tokens
1 change: 1 addition & 0 deletions python/nav/web/api/v1/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,5 @@
name="prefix-usage-detail",
),
re_path(r'^', include(router.urls)),
re_path(r'^refresh/$', views.JWTRefreshViewSet.as_view(), name='jwt-refresh'),
]
50 changes: 50 additions & 0 deletions python/nav/web/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,21 @@
from oidc_auth.authentication import JSONWebTokenAuthentication

from nav.models import manage, event, cabling, rack, profiles
from nav.models.api import JWTRefreshToken
from nav.models.fields import INFINITY, UNRESOLVED
from nav.web.servicecheckers import load_checker_classes
from nav.util import auth_token, is_valid_cidr

from nav.buildconf import VERSION
from nav.web.api.v1 import serializers, alert_serializers
from nav.web.status2 import STATELESS_THRESHOLD
from nav.web.jwtgen import (
generate_access_token,
generate_refresh_token,
hash_token,
decode_token,
is_active,
)
from nav.macaddress import MacPrefix
from .auth import (
APIPermission,
Expand Down Expand Up @@ -1153,3 +1161,45 @@ class ModuleViewSet(NAVAPIMixin, viewsets.ReadOnlyModelViewSet):
'device__serial',
)
serializer_class = serializers.ModuleSerializer


class JWTRefreshViewSet(APIView):
"""
Accepts a valid refresh token.
Returns a new refresh token and an access token.
"""

def post(self, request):
incoming_token = request.data.get('refresh_token')
token_hash = hash_token(incoming_token)
try:
# If hash exists in the database, then we know it is a real token
db_token = JWTRefreshToken.objects.get(hash=token_hash)
except JWTRefreshToken.DoesNotExist:
return Response("Invalid token", status=status.HTTP_403_FORBIDDEN)

claims = decode_token(incoming_token)
if not is_active(claims['exp'], claims['nbf']):
return Response("Inactive token", status=status.HTTP_403_FORBIDDEN)

if db_token.revoked:
return Response(
"This token has been revoked", status=status.HTTP_403_FORBIDDEN
)

access_token = generate_access_token(claims)
refresh_token = generate_refresh_token(claims)

new_claims = decode_token(refresh_token)
new_hash = hash_token(refresh_token)
db_token.hash = new_hash
db_token.expires = datetime.fromtimestamp(new_claims['exp'])
db_token.activates = datetime.fromtimestamp(new_claims['nbf'])
db_token.last_used = datetime.now()
db_token.save()

response_data = {
'access_token': access_token,
'refresh_token': refresh_token,
}
return Response(response_data)
12 changes: 12 additions & 0 deletions python/nav/web/jwtgen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
import hashlib

import jwt

Expand Down Expand Up @@ -64,3 +65,14 @@ def is_active(exp: float, nbf: float) -> bool:
expires = datetime.fromtimestamp(exp, tz=timezone.utc)
activates = datetime.fromtimestamp(nbf, tz=timezone.utc)
return now >= activates and now < expires

def hash_token(token: str) -> str:
"""Hashes a token with SHA256"""
hash_object = hashlib.sha256(token.encode('utf-8'))
hex_dig = hash_object.hexdigest()
return hex_dig


def decode_token(token: str) -> dict[str, Any]:
"""Decodes a token in JWT format and returns the data of the decoded token"""
return jwt.decode(token, options={'verify_signature': False})
261 changes: 261 additions & 0 deletions tests/integration/jwt_refresh_endpoint_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
from typing import Generator

import jwt
import pytest
from datetime import datetime, timedelta, timezone

from unittest.mock import Mock, patch

from django.urls import reverse
from nav.models.api import JWTRefreshToken
from nav.web.jwtgen import generate_refresh_token, hash_token, decode_token


def test_token_not_in_database_should_be_rejected(db, api_client, url):
token = generate_refresh_token()
token_hash = hash_token(token)

assert not JWTRefreshToken.objects.filter(hash=token_hash).exists()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 403


def test_inactive_token_should_be_rejected(db, api_client, url, inactive_token):
now = datetime.now()
db_token = JWTRefreshToken(
name="testtoken",
hash=hash_token(inactive_token),
expires=now - timedelta(hours=1),
activates=now - timedelta(hours=2),
)
db_token.save()

response = api_client.post(
url,
follow=True,
data={
'refresh_token': inactive_token,
},
)

assert response.status_code == 403


def test_valid_token_should_be_accepted(db, api_client, url, active_token):
data = decode_token(active_token)
db_token = JWTRefreshToken(
name="testtoken",
hash=hash_token(active_token),
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': active_token,
},
)
assert response.status_code == 200


def test_valid_token_should_be_replaced_by_new_token_in_db(
db, api_client, url, active_token
):
token_hash = hash_token(active_token)
data = decode_token(active_token)
db_token = JWTRefreshToken(
name="testtoken",
hash=token_hash,
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': active_token,
},
)
assert response.status_code == 200
assert not JWTRefreshToken.objects.filter(hash=token_hash).exists()
new_token = response.data.get("refresh_token")
new_hash = hash_token(new_token)
assert JWTRefreshToken.objects.filter(hash=new_hash).exists()


def test_should_include_access_and_refresh_token_in_response(
db, api_client, url, active_token
):
data = decode_token(active_token)
db_token = JWTRefreshToken(
name="testtoken",
hash=hash_token(active_token),
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': active_token,
},
)
assert response.status_code == 200
assert "access_token" in response.data
assert "refresh_token" in response.data


def test_revoked_token_should_be_rejected(db, api_client, url, active_token):
data = decode_token(active_token)
db_token = JWTRefreshToken(
name="testtoken",
hash=hash_token(active_token),
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
revoked=True,
)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': active_token,
},
)
assert response.status_code == 403


def test_last_used_should_be_updated_after_token_is_used(
db, api_client, url, active_token
):
token_hash = hash_token(active_token)
data = decode_token(active_token)
db_token = JWTRefreshToken(
name="testtoken",
hash=token_hash,
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
)
db_token.save()
assert db_token.last_used is None
response = api_client.post(
url,
follow=True,
data={
'refresh_token': active_token,
},
)
new_token = response.data.get("refresh_token")
new_hash = hash_token(new_token)
assert JWTRefreshToken.objects.get(hash=new_hash).last_used is not None


@pytest.fixture()
def inactive_token(nav_name, private_key) -> str:
now = datetime.now(timezone.utc)
claims = {
'exp': (now - timedelta(hours=1)).timestamp(),
'nbf': (now - timedelta(hours=2)).timestamp(),
'iat': (now - timedelta(hours=2)).timestamp(),
'aud': nav_name,
'iss': nav_name,
'token_type': 'refresh_token',
}
token = jwt.encode(claims, private_key, algorithm="RS256")
return token


@pytest.fixture()
def active_token(nav_name, private_key) -> str:
now = datetime.now(timezone.utc)
claims = {
'exp': (now + timedelta(hours=1)).timestamp(),
'nbf': now.timestamp(),
'iat': now.timestamp(),
'aud': nav_name,
'iss': nav_name,
'token_type': 'refresh_token',
}
token = jwt.encode(claims, private_key, algorithm="RS256")
return token


@pytest.fixture()
def url():
return reverse('api:1:jwt-refresh')


@pytest.fixture(scope="module", autouse=True)
def jwtconf_mock(private_key, nav_name) -> Generator[str, None, None]:
"""Mocks the get_nave_name and get_nav_private_key functions for
the JWTConf class
"""
with patch("nav.web.jwtgen.JWTConf") as _jwtconf_mock:
instance = _jwtconf_mock.return_value
instance.get_nav_name = Mock(return_value=nav_name)
instance.get_nav_private_key = Mock(return_value=private_key)
yield _jwtconf_mock


@pytest.fixture(scope="module")
def private_key() -> str:
"""Yields a private key in PEM format"""
key = """-----BEGIN PRIVATE KEY-----

Check failure

Code scanning / SonarCloud

Cryptographic private keys should not be disclosed High test

Make sure this private key gets revoked, changed, and removed from the code. See more on SonarQube Cloud
MIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCp+4AEZM4uYZKu
/hrKzySMTFFx3/ncWo6XAFpADQHXLOwRB9Xh1/OwigHiqs/wHRAAmnrlkwCCQA8r
xiHBAMjp5ApbkyggQz/DVijrpSba6Tiy1cyBTZC3cvOK2FpJzsakJLhIXD1HaULO
ClyIJB/YrmHmQc8SL3Uzou5mMpdcBC2pzwmEW1cvQURpnvgrDF8V86GrQkjK6nIP
IEeuW6kbD5lWFAPfLf1ohDWex3yxeSFyXNRApJhbF4HrKFemPkOi7acsky38UomQ
jZgAMHPotJNkQvAHcnXHhg0FcWGdohv5bc/Ctt9GwZOzJxwyJLBBsSewbE310TZi
3oLU1TmvAgMBAAECgf8zrhi95+gdMeKRpwV+TnxOK5CXjqvo0vTcnr7Runf/c9On
WeUtRPr83E4LxuMcSGRqdTfoP0loUGb3EsYwZ+IDOnyWWvytfRoQdExSA2RM1PDo
GRiUN4Dy8CrGNqvnb3agG99Ay3Ura6q5T20n9ykM4qKL3yDrO9fmWyMgRJbAOAYm
xzf7H910mDZghXPpq8nzDky0JLNZcaqbxuPQ3+EI4p2dLNXbNqMPs8Y20JKLeOPs
HikRM0zfhHEJSt5IPFQ54/CzscGHGeCleQINWTgvDLMcE5fJMvbLLZixV+YsBfAq
e2JsSubS+9RI2ktMlSKaemr8yeoIpsXfAiJSHkECgYEA0NKU18xK+9w5IXfgNwI4
peu2tWgwyZSp5R2pdLT7O1dJoLYRoAmcXNePB0VXNARqGxTNypJ9zmMawNmf3YRS
BqG8aKz7qpATlx9OwYlk09fsS6MeVmaur8bHGHP6O+gt7Xg+zhiFPvU9P5LB+C0Z
0d4grEmIxNhJCtJRQOThD8ECgYEA0GKRO9SJdnhw1b6LPLd+o/AX7IEzQDHwdtfi
0h7hKHHGBlUMbIBwwjKmyKm6cSe0PYe96LqrVg+cVf84wbLZPAixhOjyplLznBzF
LqOrfFPfI5lQVhslE1H1CdLlk9eyT96jDgmLAg8EGSMV8aLGj++Gi2l/isujHlWF
BI4YpW8CgYEAsyKyhJzABmbYq5lGQmopZkxapCwJDiP1ypIzd+Z5TmKGytLlM8CK
3iocjEQzlm/jBfBGyWv5eD8UCDOoLEMCiqXcFn+uNJb79zvoN6ZBVGl6TzhTIhNb
73Y5/QQguZtnKrtoRSxLwcJnFE41D0zBRYOjy6gZJ6PSpPHeuiid2QECgYACuZc+
mgvmIbMQCHrXo2qjiCs364SZDU4gr7gGmWLGXZ6CTLBp5tASqgjmTNnkSumfeFvy
ZCaDbJbVxQ2f8s/GajKwEz/BDwqievnVH0zJxmr/kyyqw5Ybh5HVvA1GfqaVRssJ
DvTjZQDft0a9Lyy7ix1OS2XgkcMjTWj840LNPwKBgDPXMBgL5h41jd7jCsXzPhyr
V96RzQkPcKsoVvrCoNi8eoEYgRd9jwfiU12rlXv+fgVXrrfMoJBoYT6YtrxEJVdM
RAjRpnE8PMqCUA8Rd7RFK9Vp5Uo8RxTNvk9yPvDv1+lHHV7lEltIk5PXuKPHIrc1
nNUyhzvJs2Qba2L/huNC
-----END PRIVATE KEY-----"""
return key


@pytest.fixture()
def public_key() -> str:
"""Yields a public key in PEM format"""
key = """-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqfuABGTOLmGSrv4ays8k
jExRcd/53FqOlwBaQA0B1yzsEQfV4dfzsIoB4qrP8B0QAJp65ZMAgkAPK8YhwQDI
6eQKW5MoIEM/w1Yo66Um2uk4stXMgU2Qt3LzithaSc7GpCS4SFw9R2lCzgpciCQf
2K5h5kHPEi91M6LuZjKXXAQtqc8JhFtXL0FEaZ74KwxfFfOhq0JIyupyDyBHrlup
Gw+ZVhQD3y39aIQ1nsd8sXkhclzUQKSYWxeB6yhXpj5Dou2nLJMt/FKJkI2YADBz
6LSTZELwB3J1x4YNBXFhnaIb+W3PwrbfRsGTsyccMiSwQbEnsGxN9dE2Yt6C1NU5
rwIDAQAB
-----END PUBLIC KEY-----"""
return key


@pytest.fixture(scope="module")
def nav_name() -> str:
return "nav"

0 comments on commit d1bfbc3

Please sign in to comment.