From 1f563fe9c828650e3e417feed5b1412ee6cd533b Mon Sep 17 00:00:00 2001 From: Phil Owen <19691521+PhillipsOwen@users.noreply.github.com> Date: Thu, 13 Jul 2023 14:53:50 -0400 Subject: [PATCH] adding JWT security to all endpoints --- src/common/bearer.py | 53 ++++++++++++++++++++++ src/common/security.py | 85 ++++++++++++++++++++++++++++++++++++ src/server.py | 28 +++++++----- src/test/test_security.py | 92 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 247 insertions(+), 11 deletions(-) create mode 100644 src/common/bearer.py create mode 100644 src/common/security.py create mode 100644 src/test/test_security.py diff --git a/src/common/bearer.py b/src/common/bearer.py new file mode 100644 index 0000000..b85da0d --- /dev/null +++ b/src/common/bearer.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved. +# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved. +# +# SPDX-License-Identifier: GPL-3.0-or-later +# SPDX-License-Identifier: LicenseRef-RENCI +# SPDX-License-Identifier: MIT + +""" + JWT bearer utilities. + + Author: Phil Owen, 6/27/2023 +""" +from fastapi import Request, HTTPException +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from src.common.security import Security + + +class JWTBearer(HTTPBearer): + """ + class to handle JWT operations + + """ + def __init__(self, sec: Security, auto_error: bool = True): + # save the security object + self.sec = sec + + # call the superclass to init + super().__init__(auto_error=auto_error) + + async def __call__(self, request: Request): + """ + called by fastapi to authenticate the request + + :param request: + :return: + """ + # get the JWT Bearer token from the request + auth: HTTPAuthorizationCredentials = await super().__call__(request) + + # if we got the bearer creds + if auth: + # make sure that the request has an auth bearer + if not auth.scheme == "Bearer": + # raise error if no bearer specified + raise HTTPException(status_code=403, detail="Invalid authentication scheme.") + + # decode and validate the JWT auth token + if not self.sec.decode_jwt(auth.credentials): + raise HTTPException(status_code=403, detail="Invalid authentication token.") + + # return the JWT creds + return auth.credentials diff --git a/src/common/security.py b/src/common/security.py new file mode 100644 index 0000000..93bd0af --- /dev/null +++ b/src/common/security.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved. +# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved. +# +# SPDX-License-Identifier: GPL-3.0-or-later +# SPDX-License-Identifier: LicenseRef-RENCI +# SPDX-License-Identifier: MIT + +""" + Security utilities. + + Author: Phil Owen, 6/27/2023 +""" +import os +import jwt + +from pydantic import BaseModel, Field + + +class BearerSchema(BaseModel): + """ + declare a data model for the Bearer details + """ + bearer_name: str = Field(...) + bearer_secret: str = Field(...) + + class Config: + """ + an example usage of the model + """ + schema_extra = {"bearer_name": "SomeBearerName", "bearer_secret": "SomeBearerSecret"} + + +class Security: + """ + Methods to handle security + + """ + + def __init__(self): + """ + Init this class with the JWT params + + """ + self.bearer_name = os.environ.get('BEARER_NAME') + self.bearer_secret = os.environ.get('BEARER_SECRET') + self.jwt_algorithm = os.environ.get('JWT_ALGORITHM') + self.jwt_secret = os.environ.get('JWT_SECRET') + + def sign_jwt(self, token_def: dict): + """ + creates and returns a signed token + + :return: + """ + # create the jwt token + jwt_token = jwt.encode(token_def, self.jwt_secret, algorithm=self.jwt_algorithm) + + # return the new token + return {"access_token": jwt_token} + + def decode_jwt(self, token: str) -> bool: + """ + decodes and validates the JWT token + + :param token: + :return: + """ + # init the return + ret_val: bool = False + + try: + # try to decode the token passed + decoded_token = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + + # verify that the token is legit + if 'bearer_name' in decoded_token and decoded_token['bearer_name'] == self.bearer_name and 'bearer_secret' in decoded_token and \ + decoded_token['bearer_secret'] == self.bearer_secret: + ret_val = True + + except Exception: + # trap a decode error + ret_val = False + + # return to the caller + return ret_val diff --git a/src/server.py b/src/server.py index 183fd70..aff8e5f 100644 --- a/src/server.py +++ b/src/server.py @@ -15,13 +15,15 @@ from pathlib import Path -from fastapi import FastAPI, Query, Request +from fastapi import FastAPI, Query, Request, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, FileResponse from src.common.logger import LoggingUtil from src.common.pg_impl import PGImplementation from src.common.utils import GenUtils, WorkflowTypeName, ImageRepo, RunStatus, JobTypeName, NextJobTypeName +from src.common.security import Security +from src.common.bearer import JWTBearer # set the app version app_version = os.getenv('APP_VERSION', 'Version number not set') @@ -48,8 +50,11 @@ # create a DB connection object with auto-commit turned off db_info_no_auto_commit: PGImplementation = PGImplementation(db_names, _logger=logger, _auto_commit=False) +# create a Security object +security = Security() -@APP.get('/get_job_order/{workflow_type_name}', status_code=200, response_model=None) + +@APP.get('/get_job_order/{workflow_type_name}', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None) async def display_job_order(workflow_type_name: WorkflowTypeName) -> json: """ Displays the job order for the workflow type selected. @@ -77,7 +82,7 @@ async def display_job_order(workflow_type_name: WorkflowTypeName) -> json: return JSONResponse(content=ret_val, status_code=status_code, media_type="application/json") -@APP.get('/reset_job_order/{workflow_type_name}', status_code=200, response_model=None) +@APP.get('/reset_job_order/{workflow_type_name}', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None) async def reset_job_order(workflow_type_name: WorkflowTypeName) -> json: """ resets the job process order to the default for the workflow selected. @@ -128,7 +133,7 @@ async def reset_job_order(workflow_type_name: WorkflowTypeName) -> json: return JSONResponse(content=ret_val, status_code=status_code, media_type="application/json") -@APP.get('/get_job_defs', status_code=200, response_model=None) +@APP.get('/get_job_defs', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None) async def display_job_definitions() -> json: """ Displays the job definitions for all workflows. Note that this list is in alphabetical order (not in job execute order). @@ -174,7 +179,7 @@ async def display_job_definitions() -> json: return JSONResponse(content=job_config_data, status_code=status_code, media_type="application/json") -@APP.get("/get_log_file_list", response_model=None) +@APP.get("/get_log_file_list", dependencies=[Depends(JWTBearer(security))], response_model=None) async def get_the_log_file_list(request: Request): """ Gets the log file list. each of these entries could be used in the get_log_file endpoint @@ -186,7 +191,7 @@ async def get_the_log_file_list(request: Request): media_type="application/json") -@APP.get("/get_log_file/", response_model=None) +@APP.get("/get_log_file/", dependencies=[Depends(JWTBearer(security))], response_model=None) async def get_the_log_file(log_file: str = Query('log_file')): """ Gets the log file specified. This method only expects a properly named file. @@ -209,7 +214,7 @@ async def get_the_log_file(log_file: str = Query('log_file')): return JSONResponse(content={'Response': 'Error - Log file does not exist.'}, status_code=404, media_type="application/json") -@APP.get("/get_run_list", status_code=200, response_model=None) +@APP.get("/get_run_list", dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None) async def get_the_run_list(): """ Gets the run information for the last 100 runs. @@ -245,7 +250,7 @@ async def get_the_run_list(): # sets the run.properties run status to 'new' for a job -@APP.put('/instance_id/{instance_id}/uid/{uid}/status/{status}', status_code=200, response_model=None) +@APP.put('/instance_id/{instance_id}/uid/{uid}/status/{status}', dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None) async def set_the_run_status(instance_id: int, uid: str, status: RunStatus = RunStatus('new')): """ Updates the run status of a selected job. @@ -288,7 +293,8 @@ async def set_the_run_status(instance_id: int, uid: str, status: RunStatus = Run # Updates the image version for a job -@APP.put('/image_repo/{image_repo}/job_type_name/{job_type_name}/image_version/{version}', status_code=200, response_model=None) +@APP.put('/image_repo/{image_repo}/job_type_name/{job_type_name}/image_version/{version}', dependencies=[Depends(JWTBearer(security))], + status_code=200, response_model=None) async def set_the_supervisor_component_image_version(image_repo: ImageRepo, job_type_name: JobTypeName, version: str): """ Updates a supervisor component image version label in the supervisor job run configuration. @@ -358,8 +364,8 @@ async def set_the_supervisor_component_image_version(image_repo: ImageRepo, job_ # Updates a supervisor component's next process. -@APP.put('/workflow_type_name/{workflow_type_name}/job_type_name/{job_type_name}/next_job_type/{next_job_type_name}', status_code=200, - response_model=None) +@APP.put('/workflow_type_name/{workflow_type_name}/job_type_name/{job_type_name}/next_job_type/{next_job_type_name}', + dependencies=[Depends(JWTBearer(security))], status_code=200, response_model=None) async def set_the_supervisor_job_order(workflow_type_name: WorkflowTypeName, job_type_name: JobTypeName, next_job_type_name: NextJobTypeName): """ Modifies the supervisor component's linked list of jobs. Select the workflow type, then select the job process name and the next job diff --git a/src/test/test_security.py b/src/test/test_security.py new file mode 100644 index 0000000..340b2e4 --- /dev/null +++ b/src/test/test_security.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: 2022 Renaissance Computing Institute. All rights reserved. +# SPDX-FileCopyrightText: 2023 Renaissance Computing Institute. All rights reserved. +# +# SPDX-License-Identifier: GPL-3.0-or-later +# SPDX-License-Identifier: LicenseRef-RENCI +# SPDX-License-Identifier: MIT + +""" + Settings tests. + + Author: Phil Owen, 6/27/2023 +""" +import os + +import requests +import pytest + +from src.common.security import Security + + +#@pytest.mark.skip(reason="Local test only") +def test_sign_jwt(): + """ + tests the creation of a JWT token + + :return: + """ + # create a security object + sec = Security() + + # create a payload for the token generation + payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")} + + # create a new token + token = sec.sign_jwt(payload) + + # check the result + assert token and 'access_token' in token + + +@pytest.mark.skip(reason="Local test only") +def test_decode_jwt(): + """ + tests the decode and validation of a JWT token + + :return: + """ + # create a security object + sec = Security() + + # create a payload for the token generation + payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")} + + # create a new token + token = sec.sign_jwt(payload) + + # decode the jwt token + ret_val = sec.decode_jwt(token['access_token']) + + # validate the result + assert ret_val + + # decode the jwt token + ret_val = sec.decode_jwt(token['access_token'] + 'this-will-fail') + + assert not ret_val + + +#@pytest.mark.skip(reason="Local test only") +def test_access(): + """ + makes a secure request to the app running locally + + :return: + """ + # create a security object + sec = Security() + + # create a payload for the token generation + payload = {'bearer_name': os.environ.get("BEARER_NAME"), 'bearer_secret': os.environ.get("BEARER_SECRET")} + + # create a new token + token = sec.sign_jwt(payload) + + # create an auth header + auth_header: dict = {'Content-Type': 'application/json', 'Authorization': f'Bearer {token["access_token"]}'} + + # execute the post + ret_val = requests.get('http://localhost:4000/get_job_order/ASGS', headers=auth_header, timeout=10) + + # was the call unsuccessful + assert ret_val.status_code == 200