Skip to content

Commit

Permalink
refactor api
Browse files Browse the repository at this point in the history
format code

add options for rate limiting, delay and more for API
  • Loading branch information
dmdhrumilmistry committed Nov 14, 2023
1 parent 7a3d253 commit df32e84
Show file tree
Hide file tree
Showing 11 changed files with 429 additions and 388 deletions.
102 changes: 44 additions & 58 deletions src/offat/api/app.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,87 @@
from fastapi import status, Request, Response
from json import loads as json_loads
from yaml import SafeLoader, load as yaml_loads
from .config import app, task_queue, task_timeout, auth_secret_key
from .jobs import scan_api
from .models import CreateScanModel
from ..logger import create_logger
from offat.api.config import app, task_queue, task_timeout, auth_secret_key
from offat.api.jobs import scan_api
from offat.api.models import CreateScanModel
from offat.logger import create_logger
from os import uname, environ


logger = create_logger(__name__)
logger.info(f'Secret Key: {auth_secret_key}')


if uname().sysname == 'Darwin' and environ.get('OBJC_DISABLE_INITIALIZE_FORK_SAFETY') != 'YES':
logger.warning('Mac Users might need to configure OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES in env\nVisit StackOverFlow link for more info: https://stackoverflow.com/questions/50168647/multiprocessing-causes-python-to-crash-and-gives-an-error-may-have-been-in-progr')


@app.get('/', status_code=status.HTTP_200_OK)
async def root():
return {
"name":"OFFAT API",
"project":"https://github.com/dmdhrumilmistry/offat",
"license":"https://github.com/dmdhrumilmistry/offat/blob/main/LICENSE",
"name": "OFFAT API",
"project": "https://github.com/dmdhrumilmistry/offat",
"license": "https://github.com/dmdhrumilmistry/offat/blob/main/LICENSE",
}


@app.post('/api/v1/scan', status_code=status.HTTP_201_CREATED)
async def add_scan_task(scan_data: CreateScanModel, request:Request ,response: Response):
async def add_scan_task(scan_data: CreateScanModel, request: Request, response: Response):
# for auth
client_ip = request.client.host
secret_key = request.headers.get('SECRET-KEY', None)
if secret_key != auth_secret_key:
# return 404 for better endpoint security
response.status_code = status.HTTP_401_UNAUTHORIZED
logger.warning(f'INTRUSION: {client_ip} tried to create a new scan job')
return {"message":"Unauthorized"}

openapi_doc = scan_data.openAPI
file_data_type = scan_data.type
logger.warning(
f'INTRUSION: {client_ip} tried to create a new scan job')
return {"message": "Unauthorized"}

msg = {
"msg":"Scan Task Created",
"msg": "Scan Task Created",
"job_id": None
}
create_task = True

match file_data_type:
case 'json':
openapi_doc = json_loads(openapi_doc)
case 'yaml':
openapi_doc = yaml_loads(openapi_doc, SafeLoader)
case _:
response.status_code = status.HTTP_400_BAD_REQUEST
msg = {
"msg":"Invalid Request Data"
}
create_task = False

if create_task:
job = task_queue.enqueue(scan_api, openapi_doc, job_timeout=task_timeout)
msg['job_id'] = job.id

logger.info(f'SUCCESS: {client_ip} created new scan job - {job.id}')
else:
logger.error(f'FAILED: {client_ip} tried creating new scan job but it failed due to unknown file data type')

job = task_queue.enqueue(scan_api, scan_data, job_timeout=task_timeout)
msg['job_id'] = job.id

logger.info(f'SUCCESS: {client_ip} created new scan job - {job.id}')

return msg


@app.get('/api/v1/scan/{job_id}/result')
async def get_scan_task_result(job_id:str, request: Request, response:Response):
async def get_scan_task_result(job_id: str, request: Request, response: Response):
# for auth
client_ip = request.client.host
secret_key = request.headers.get('SECRET-KEY', None)
if secret_key != auth_secret_key:
# return 404 for better endpoint security
response.status_code = status.HTTP_401_UNAUTHORIZED
logger.warning(f'INTRUSION: {client_ip} tried to access {job_id} job scan results')
return {"message":"Unauthorized"}

scan_results = task_queue.fetch_job(job_id=job_id)
logger.warning(
f'INTRUSION: {client_ip} tried to access {job_id} job scan results')
return {"message": "Unauthorized"}

scan_results_job = task_queue.fetch_job(job_id=job_id)

logger.info(f'SUCCESS: {client_ip} accessed {job_id} job scan results')

msg = {
'msg':'Task Remaining or Invalid Job Id',
'results': None,
}
msg = 'Task Remaining or Invalid Job Id'
results = None
response.status_code = status.HTTP_202_ACCEPTED

if scan_results and scan_results.is_finished:
msg = {
'msg':'Task Completed',
'results': scan_results.result,
}
response.status_code = status.HTTP_200_OK
if scan_results_job and scan_results_job.is_started:
msg = 'Job In Progress'

elif scan_results_job and scan_results_job.is_finished:
msg = 'Task Completed'
results = scan_results_job.result
response.status_code = status.HTTP_200_OK

elif scan_results and scan_results.is_failed:
msg = {
'msg':'Task Failed. Try Creating Task Again.',
'results': None,
}
elif scan_results_job and scan_results_job.is_failed:
msg = 'Task Failed. Try Creating Task Again.'
response.status_code = status.HTTP_200_OK

return msg
msg = {
'msg': msg,
'results': results,
}
return msg
2 changes: 1 addition & 1 deletion src/offat/api/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ def generate_random_secret_key_string(length=128):
# Generate a random string of the specified length
random_string = ''.join(secrets.choice(characters) for _ in range(length))

return random_string
return random_string
8 changes: 5 additions & 3 deletions src/offat/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
title='OFFAT - API'
)

auth_secret_key = environ.get('AUTH_SECRET_KEY', generate_random_secret_key_string())
redis_con = Redis(host=environ.get('REDIS_HOST','localhost'), port=int(environ.get('REDIS_PORT',6379)))
auth_secret_key = environ.get(
'AUTH_SECRET_KEY', generate_random_secret_key_string())
redis_con = Redis(host=environ.get('REDIS_HOST', 'localhost'),
port=int(environ.get('REDIS_PORT', 6379)))
task_queue = Queue(name='offat_task_queue', connection=redis_con)
task_timeout = 60 * 60 # 3600 s = 1 hour
task_timeout = 60 * 60 # 3600 s = 1 hour
41 changes: 25 additions & 16 deletions src/offat/api/jobs.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from ..tester.tester_utils import generate_and_run_tests
from ..openapi import OpenAPIParser
from traceback import print_exception
from offat.api.models import CreateScanModel
from offat.tester.tester_utils import generate_and_run_tests
from offat.openapi import OpenAPIParser
from offat.logger import create_logger


def scan_api(open_api:dict):
# TODO: validate `open_api` str against openapi specs.
api_parser = OpenAPIParser(fpath=None,spec=open_api)
logger = create_logger(__name__)

# TODO: accept commented options from API
results = generate_and_run_tests(
api_parser=api_parser,
# regex_pattern=args.path_regex_pattern,
# output_file=args.output_file,
# req_headers=headers_dict,
# rate_limit=rate_limit,
# delay=delay_rate,
# test_data_config=test_data_config,
)
return results

def scan_api(body_data: CreateScanModel):
try:
logger.info('test')
api_parser = OpenAPIParser(fpath_or_url=None, spec=body_data.openAPI)

results = generate_and_run_tests(
api_parser=api_parser,
regex_pattern=body_data.regex_pattern,
req_headers=body_data.req_headers,
rate_limit=body_data.rate_limit,
delay=body_data.delay,
test_data_config=body_data.test_data_config,
)
return results
except Exception as e:
logger.error(f'Error occurred while creating a job: {e}')
print_exception(e)
return [{'error': str(e)}]
8 changes: 7 additions & 1 deletion src/offat/api/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from typing import Optional
from pydantic import BaseModel


class CreateScanModel(BaseModel):
openAPI: str
type: str
regex_pattern: Optional[str] = None
req_headers: Optional[dict] = None
rate_limit: Optional[int] = None
delay: Optional[float] = None
test_data_config: Optional[dict] = None
3 changes: 2 additions & 1 deletion src/offat/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from aiohttp import ClientSession, ClientResponse, TCPConnector
from os import name as os_name
from typing import Optional


import asyncio
Expand All @@ -14,7 +15,7 @@ class AsyncRequests:
AsyncRequests class helps to send HTTP requests with rate limiting options.
'''

def __init__(self, rate_limit:int=None, delay:float=None, headers: dict = None, proxy:str = None, ssl:bool=True, allow_redirects: bool=True) -> None:
def __init__(self, rate_limit:Optional[int]=None, delay:Optional[float]=None, headers: Optional[dict] = None, proxy:Optional[str] = None, ssl:Optional[bool]=True, allow_redirects: Optional[bool]=True) -> None:
'''AsyncRequests class constructor
Args:
Expand Down
37 changes: 20 additions & 17 deletions src/offat/report/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,48 @@

class ReportGenerator:
@staticmethod
def generate_html_report(results:list[dict]):
def generate_html_report(results: list[dict]):
html_report_template_file_name = 'report.html'
html_report_file_path = path_join(dirname(templates.__file__),html_report_template_file_name)
html_report_file_path = path_join(
dirname(templates.__file__), html_report_template_file_name)

with open(html_report_file_path, 'r') as f:
report_file_content = f.read()

# TODO: validate report path to avoid injection attacks.
if not isinstance(results, list):
raise ValueError('results arg expects a list[dict].')

report_file_content = report_file_content.replace('{ results }', json_dumps(results))

report_file_content = report_file_content.replace(
'{ results }', json_dumps(results))

return report_file_content

@staticmethod
def handle_report_format(results:list[dict], report_format:str) -> str:
def handle_report_format(results: list[dict], report_format: str) -> str:
result = None

match report_format:
case 'html':
logger.warning('HTML output format displays only basic data.')
result = ReportGenerator.generate_html_report(results=results)
case 'yaml':
logger.warning('YAML output format needs to be sanitized before using it further.')
logger.warning(
'YAML output format needs to be sanitized before using it further.')
result = yaml_dump({
'results':results,
'results': results,
})
case _: # default json format
case _: # default json format
report_format = 'json'
result = json_dumps({
'results':results,
'results': results,
})

logger.info(f'Generated {report_format.upper()} format report.')
return result


@staticmethod
def save_report(report_path:str, report_file_content:str):
def save_report(report_path: str, report_file_content: str):
if report_path != '/':
dir_name = dirname(report_path)
makedirs(dir_name, exist_ok=True)
Expand All @@ -60,8 +62,9 @@ def save_report(report_path:str, report_file_content:str):
logger.info(f'Writing report to file: {report_path}')
f.write(report_file_content)


@staticmethod
def generate_report(results:list[dict], report_format:str, report_path:str):
formatted_results = ReportGenerator.handle_report_format(results=results, report_format=report_format)
ReportGenerator.save_report(report_path=report_path, report_file_content=formatted_results)
def generate_report(results: list[dict], report_format: str, report_path: str):
formatted_results = ReportGenerator.handle_report_format(
results=results, report_format=report_format)
ReportGenerator.save_report(
report_path=report_path, report_file_content=formatted_results)
Loading

0 comments on commit df32e84

Please sign in to comment.