Skip to content

Commit

Permalink
adding JWT security to all endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
PhillipsOwen committed Jul 13, 2023
1 parent 7f492b6 commit 1f563fe
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 11 deletions.
53 changes: 53 additions & 0 deletions src/common/bearer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved.
# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved.
#
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-License-Identifier: LicenseRef-RENCI
# SPDX-License-Identifier: MIT

"""
JWT bearer utilities.
Author: Phil Owen, 6/27/2023
"""
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from src.common.security import Security


class JWTBearer(HTTPBearer):
"""
class to handle JWT operations
"""
def __init__(self, sec: Security, auto_error: bool = True):
# save the security object
self.sec = sec

# call the superclass to init
super().__init__(auto_error=auto_error)

async def __call__(self, request: Request):
"""
called by fastapi to authenticate the request
:param request:
:return:
"""
# get the JWT Bearer token from the request
auth: HTTPAuthorizationCredentials = await super().__call__(request)

# if we got the bearer creds
if auth:
# make sure that the request has an auth bearer
if not auth.scheme == "Bearer":
# raise error if no bearer specified
raise HTTPException(status_code=403, detail="Invalid authentication scheme.")

# decode and validate the JWT auth token
if not self.sec.decode_jwt(auth.credentials):
raise HTTPException(status_code=403, detail="Invalid authentication token.")

# return the JWT creds
return auth.credentials
85 changes: 85 additions & 0 deletions src/common/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved.
# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved.
#
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-License-Identifier: LicenseRef-RENCI
# SPDX-License-Identifier: MIT

"""
Security utilities.
Author: Phil Owen, 6/27/2023
"""
import os
import jwt

from pydantic import BaseModel, Field


class BearerSchema(BaseModel):
"""
declare a data model for the Bearer details
"""
bearer_name: str = Field(...)
bearer_secret: str = Field(...)

class Config:
"""
an example usage of the model
"""
schema_extra = {"bearer_name": "SomeBearerName", "bearer_secret": "SomeBearerSecret"}


class Security:
"""
Methods to handle security
"""

def __init__(self):
"""
Init this class with the JWT params
"""
self.bearer_name = os.environ.get('BEARER_NAME')
self.bearer_secret = os.environ.get('BEARER_SECRET')
self.jwt_algorithm = os.environ.get('JWT_ALGORITHM')
self.jwt_secret = os.environ.get('JWT_SECRET')

def sign_jwt(self, token_def: dict):
"""
creates and returns a signed token
:return:
"""
# create the jwt token
jwt_token = jwt.encode(token_def, self.jwt_secret, algorithm=self.jwt_algorithm)

# return the new token
return {"access_token": jwt_token}

def decode_jwt(self, token: str) -> bool:
"""
decodes and validates the JWT token
:param token:
:return:
"""
# init the return
ret_val: bool = False

try:
# try to decode the token passed
decoded_token = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])

# verify that the token is legit
if 'bearer_name' in decoded_token and decoded_token['bearer_name'] == self.bearer_name and 'bearer_secret' in decoded_token and \
decoded_token['bearer_secret'] == self.bearer_secret:
ret_val = True

except Exception:
# trap a decode error
ret_val = False

# return to the caller
return ret_val
28 changes: 17 additions & 11 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@

from pathlib import Path

from fastapi import FastAPI, Query, Request
from fastapi import FastAPI, Query, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse

from src.common.logger import LoggingUtil
from src.common.pg_impl import PGImplementation
from src.common.utils import GenUtils, WorkflowTypeName, ImageRepo, RunStatus, JobTypeName, NextJobTypeName
from src.common.security import Security
from src.common.bearer import JWTBearer

# set the app version
app_version = os.getenv('APP_VERSION', 'Version number not set')
Expand All @@ -48,8 +50,11 @@
# create a DB connection object with auto-commit turned off
db_info_no_auto_commit: PGImplementation = PGImplementation(db_names, _logger=logger, _auto_commit=False)

# create a Security object
security = Security()

@APP.get('/get_job_order/{workflow_type_name}', status_code=200, response_model=None)

@APP.get('/get_job_order/{workflow_type_name}', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None)
async def display_job_order(workflow_type_name: WorkflowTypeName) -> json:
"""
Displays the job order for the workflow type selected.
Expand Down Expand Up @@ -77,7 +82,7 @@ async def display_job_order(workflow_type_name: WorkflowTypeName) -> json:
return JSONResponse(content=ret_val, status_code=status_code, media_type="application/json")


@APP.get('/reset_job_order/{workflow_type_name}', status_code=200, response_model=None)
@APP.get('/reset_job_order/{workflow_type_name}', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None)
async def reset_job_order(workflow_type_name: WorkflowTypeName) -> json:
"""
resets the job process order to the default for the workflow selected.
Expand Down Expand Up @@ -128,7 +133,7 @@ async def reset_job_order(workflow_type_name: WorkflowTypeName) -> json:
return JSONResponse(content=ret_val, status_code=status_code, media_type="application/json")


@APP.get('/get_job_defs', status_code=200, response_model=None)
@APP.get('/get_job_defs', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None)
async def display_job_definitions() -> json:
"""
Displays the job definitions for all workflows. Note that this list is in alphabetical order (not in job execute order).
Expand Down Expand Up @@ -174,7 +179,7 @@ async def display_job_definitions() -> json:
return JSONResponse(content=job_config_data, status_code=status_code, media_type="application/json")


@APP.get("/get_log_file_list", response_model=None)
@APP.get("/get_log_file_list", dependencies=[Depends(JWTBearer(security))], response_model=None)
async def get_the_log_file_list(request: Request):
"""
Gets the log file list. each of these entries could be used in the get_log_file endpoint
Expand All @@ -186,7 +191,7 @@ async def get_the_log_file_list(request: Request):
media_type="application/json")


@APP.get("/get_log_file/", response_model=None)
@APP.get("/get_log_file/", dependencies=[Depends(JWTBearer(security))], response_model=None)
async def get_the_log_file(log_file: str = Query('log_file')):
"""
Gets the log file specified. This method only expects a properly named file.
Expand All @@ -209,7 +214,7 @@ async def get_the_log_file(log_file: str = Query('log_file')):
return JSONResponse(content={'Response': 'Error - Log file does not exist.'}, status_code=404, media_type="application/json")


@APP.get("/get_run_list", status_code=200, response_model=None)
@APP.get("/get_run_list", dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None)
async def get_the_run_list():
"""
Gets the run information for the last 100 runs.
Expand Down Expand Up @@ -245,7 +250,7 @@ async def get_the_run_list():


# sets the run.properties run status to 'new' for a job
@APP.put('/instance_id/{instance_id}/uid/{uid}/status/{status}', status_code=200, response_model=None)
@APP.put('/instance_id/{instance_id}/uid/{uid}/status/{status}', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None)
async def set_the_run_status(instance_id: int, uid: str, status: RunStatus = RunStatus('new')):
"""
Updates the run status of a selected job.
Expand Down Expand Up @@ -288,7 +293,8 @@ async def set_the_run_status(instance_id: int, uid: str, status: RunStatus = Run


# Updates the image version for a job
@APP.put('/image_repo/{image_repo}/job_type_name/{job_type_name}/image_version/{version}', status_code=200, response_model=None)
@APP.put('/image_repo/{image_repo}/job_type_name/{job_type_name}/image_version/{version}', dependencies=[Depends(JWTBearer(security))],
status_code=200, response_model=None)
async def set_the_supervisor_component_image_version(image_repo: ImageRepo, job_type_name: JobTypeName, version: str):
"""
Updates a supervisor component image version label in the supervisor job run configuration.
Expand Down Expand Up @@ -358,8 +364,8 @@ async def set_the_supervisor_component_image_version(image_repo: ImageRepo, job_


# Updates a supervisor component's next process.
@APP.put('/workflow_type_name/{workflow_type_name}/job_type_name/{job_type_name}/next_job_type/{next_job_type_name}', status_code=200,
response_model=None)
@APP.put('/workflow_type_name/{workflow_type_name}/job_type_name/{job_type_name}/next_job_type/{next_job_type_name}',
dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None)
async def set_the_supervisor_job_order(workflow_type_name: WorkflowTypeName, job_type_name: JobTypeName, next_job_type_name: NextJobTypeName):
"""
Modifies the supervisor component's linked list of jobs. Select the workflow type, then select the job process name and the next job
Expand Down
92 changes: 92 additions & 0 deletions src/test/test_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved.
# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved.
#
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-License-Identifier: LicenseRef-RENCI
# SPDX-License-Identifier: MIT

"""
Settings tests.
Author: Phil Owen, 6/27/2023
"""
import os

import requests
import pytest

from src.common.security import Security


#@pytest.mark.skip(reason="Local test only")
def test_sign_jwt():
"""
tests the creation of a JWT token
:return:
"""
# create a security object
sec = Security()

# create a payload for the token generation
payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")}

# create a new token
token = sec.sign_jwt(payload)

# check the result
assert token and 'access_token' in token


@pytest.mark.skip(reason="Local test only")
def test_decode_jwt():
"""
tests the decode and validation of a JWT token
:return:
"""
# create a security object
sec = Security()

# create a payload for the token generation
payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")}

# create a new token
token = sec.sign_jwt(payload)

# decode the jwt token
ret_val = sec.decode_jwt(token['access_token'])

# validate the result
assert ret_val

# decode the jwt token
ret_val = sec.decode_jwt(token['access_token'] + 'this-will-fail')

assert not ret_val


#@pytest.mark.skip(reason="Local test only")
def test_access():
"""
makes a secure request to the app running locally
:return:
"""
# create a security object
sec = Security()

# create a payload for the token generation
payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")}

# create a new token
token = sec.sign_jwt(payload)

# create an auth header
auth_header: dict = {'Content-Type': 'application/json', 'Authorization': f'Bearer {token["access_token"]}'}

# execute the post
ret_val = requests.get('http://localhost:4000/get_job_order/ASGS', headers=auth_header, timeout=10)

# was the call unsuccessful
assert ret_val.status_code == 200

0 comments on commit 1f563fe

Please sign in to comment.