-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding JWT security to all endpoints
- Loading branch information
1 parent
7f492b6
commit 1f563fe
Showing
4 changed files
with
247 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved. | ||
# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
# SPDX-License-Identifier: LicenseRef-RENCI | ||
# SPDX-License-Identifier: MIT | ||
|
||
""" | ||
JWT bearer utilities. | ||
Author: Phil Owen, 6/27/2023 | ||
""" | ||
from fastapi import Request, HTTPException | ||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | ||
|
||
from src.common.security import Security | ||
|
||
|
||
class JWTBearer(HTTPBearer): | ||
""" | ||
class to handle JWT operations | ||
""" | ||
def __init__(self, sec: Security, auto_error: bool = True): | ||
# save the security object | ||
self.sec = sec | ||
|
||
# call the superclass to init | ||
super().__init__(auto_error=auto_error) | ||
|
||
async def __call__(self, request: Request): | ||
""" | ||
called by fastapi to authenticate the request | ||
:param request: | ||
:return: | ||
""" | ||
# get the JWT Bearer token from the request | ||
auth: HTTPAuthorizationCredentials = await super().__call__(request) | ||
|
||
# if we got the bearer creds | ||
if auth: | ||
# make sure that the request has an auth bearer | ||
if not auth.scheme == "Bearer": | ||
# raise error if no bearer specified | ||
raise HTTPException(status_code=403, detail="Invalid authentication scheme.") | ||
|
||
# decode and validate the JWT auth token | ||
if not self.sec.decode_jwt(auth.credentials): | ||
raise HTTPException(status_code=403, detail="Invalid authentication token.") | ||
|
||
# return the JWT creds | ||
return auth.credentials |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved. | ||
# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
# SPDX-License-Identifier: LicenseRef-RENCI | ||
# SPDX-License-Identifier: MIT | ||
|
||
""" | ||
Security utilities. | ||
Author: Phil Owen, 6/27/2023 | ||
""" | ||
import os | ||
import jwt | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class BearerSchema(BaseModel): | ||
""" | ||
declare a data model for the Bearer details | ||
""" | ||
bearer_name: str = Field(...) | ||
bearer_secret: str = Field(...) | ||
|
||
class Config: | ||
""" | ||
an example usage of the model | ||
""" | ||
schema_extra = {"bearer_name": "SomeBearerName", "bearer_secret": "SomeBearerSecret"} | ||
|
||
|
||
class Security: | ||
""" | ||
Methods to handle security | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Init this class with the JWT params | ||
""" | ||
self.bearer_name = os.environ.get('BEARER_NAME') | ||
self.bearer_secret = os.environ.get('BEARER_SECRET') | ||
self.jwt_algorithm = os.environ.get('JWT_ALGORITHM') | ||
self.jwt_secret = os.environ.get('JWT_SECRET') | ||
|
||
def sign_jwt(self, token_def: dict): | ||
""" | ||
creates and returns a signed token | ||
:return: | ||
""" | ||
# create the jwt token | ||
jwt_token = jwt.encode(token_def, self.jwt_secret, algorithm=self.jwt_algorithm) | ||
|
||
# return the new token | ||
return {"access_token": jwt_token} | ||
|
||
def decode_jwt(self, token: str) -> bool: | ||
""" | ||
decodes and validates the JWT token | ||
:param token: | ||
:return: | ||
""" | ||
# init the return | ||
ret_val: bool = False | ||
|
||
try: | ||
# try to decode the token passed | ||
decoded_token = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) | ||
|
||
# verify that the token is legit | ||
if 'bearer_name' in decoded_token and decoded_token['bearer_name'] == self.bearer_name and 'bearer_secret' in decoded_token and \ | ||
decoded_token['bearer_secret'] == self.bearer_secret: | ||
ret_val = True | ||
|
||
except Exception: | ||
# trap a decode error | ||
ret_val = False | ||
|
||
# return to the caller | ||
return ret_val |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved. | ||
# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
# SPDX-License-Identifier: LicenseRef-RENCI | ||
# SPDX-License-Identifier: MIT | ||
|
||
""" | ||
Settings tests. | ||
Author: Phil Owen, 6/27/2023 | ||
""" | ||
import os | ||
|
||
import requests | ||
import pytest | ||
|
||
from src.common.security import Security | ||
|
||
|
||
#@pytest.mark.skip(reason="Local test only") | ||
def test_sign_jwt(): | ||
""" | ||
tests the creation of a JWT token | ||
:return: | ||
""" | ||
# create a security object | ||
sec = Security() | ||
|
||
# create a payload for the token generation | ||
payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")} | ||
|
||
# create a new token | ||
token = sec.sign_jwt(payload) | ||
|
||
# check the result | ||
assert token and 'access_token' in token | ||
|
||
|
||
@pytest.mark.skip(reason="Local test only") | ||
def test_decode_jwt(): | ||
""" | ||
tests the decode and validation of a JWT token | ||
:return: | ||
""" | ||
# create a security object | ||
sec = Security() | ||
|
||
# create a payload for the token generation | ||
payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")} | ||
|
||
# create a new token | ||
token = sec.sign_jwt(payload) | ||
|
||
# decode the jwt token | ||
ret_val = sec.decode_jwt(token['access_token']) | ||
|
||
# validate the result | ||
assert ret_val | ||
|
||
# decode the jwt token | ||
ret_val = sec.decode_jwt(token['access_token'] + 'this-will-fail') | ||
|
||
assert not ret_val | ||
|
||
|
||
#@pytest.mark.skip(reason="Local test only") | ||
def test_access(): | ||
""" | ||
makes a secure request to the app running locally | ||
:return: | ||
""" | ||
# create a security object | ||
sec = Security() | ||
|
||
# create a payload for the token generation | ||
payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")} | ||
|
||
# create a new token | ||
token = sec.sign_jwt(payload) | ||
|
||
# create an auth header | ||
auth_header: dict = {'Content-Type': 'application/json', 'Authorization': f'Bearer {token["access_token"]}'} | ||
|
||
# execute the post | ||
ret_val = requests.get('http://localhost:4000/get_job_order/ASGS', headers=auth_header, timeout=10) | ||
|
||
# was the call unsuccessful | ||
assert ret_val.status_code == 200 |