diff --git a/src/routes/turnilo_dashboard_routes.py b/src/routes/turnilo_dashboard_routes.py index aeb4301..2492520 100644 --- a/src/routes/turnilo_dashboard_routes.py +++ b/src/routes/turnilo_dashboard_routes.py @@ -1,5 +1,6 @@ -from fastapi import APIRouter, Depends, HTTPException -from typing import List +import re +from fastapi import APIRouter, Depends, HTTPException, Query +from typing import List, Optional from sqlalchemy.orm import Session from models.turnilo_dashboard import TurniloDashboard from services import turnilo_dashboards as td @@ -10,6 +11,13 @@ turnilo_router = APIRouter() ### Dashboards ### +def is_valid_query(s: str) -> bool: + if not isinstance(s, str): + return False + if len(s) > 256: + return False + pattern = r'^[a-zA-Z0-9_-]+$' + return bool(re.match(pattern, s)) @turnilo_router.get( @@ -17,8 +25,12 @@ response_model=List[TurniloDashboard], summary="Gets all Turnilo dashboards" ) -def turnilo_get_dashboards(db_session: Session = Depends(db.get_session)): - return td.dashboards_get_all(db_session) +def turnilo_get_dashboards(db_session: Session = Depends(db.get_session), shortName: str=Query(None), dataCube: str=Query(None)): + if shortName and not is_valid_query(shortName): + raise HTTPException(status_code=400, detail=f"Invalid shortName='{shortName}'") + if dataCube and not is_valid_query(dataCube): + raise HTTPException(status_code=400, detail=f"Invalid dataCube='{dataCube}'") + return td.dashboards_get_all(db_session, shortName, dataCube) @turnilo_router.get( diff --git a/src/services/turnilo_dashboards.py b/src/services/turnilo_dashboards.py index b754742..c3ee7ed 100644 --- a/src/services/turnilo_dashboards.py +++ b/src/services/turnilo_dashboards.py @@ -7,8 +7,13 @@ # Turnilo Dashboards -def dashboards_get_all(session: Session) -> List[TurniloDashboard]: - return session.query(TurniloDashboard).all() +def dashboards_get_all(session: Session, shortName: str, dataCube:str) -> List[TurniloDashboard]: + query = session.query(TurniloDashboard) + if shortName: + query = query.filter(TurniloDashboard.shortName == shortName) + if dataCube: + query = query.filter(TurniloDashboard.dataCube == dataCube) + return query.all() def _dashboards_return_single_obj(results: List[TurniloDashboard]): diff --git a/test/Makefile b/test/Makefile index 1c699e2..0f2ffc0 100644 --- a/test/Makefile +++ b/test/Makefile @@ -3,4 +3,4 @@ test: @rm -rf data/ || true @mkdir -p data - @PYTHONPATH=`pwd`/../src/ pytest -v + @PYTHONPATH=`pwd`/../src/ pytest -v -s diff --git a/test/client.py b/test/client.py index af0b584..a4bdf57 100755 --- a/test/client.py +++ b/test/client.py @@ -3,6 +3,7 @@ import json import requests import argparse +from urllib.parse import urlencode DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8080 @@ -32,11 +33,18 @@ def delete_dashboard(host, port, dashboard_id) -> requests.Response: return response -def get_dashboard(host, port, dashboard_id=None) -> requests.Response: +def get_dashboard(host, port, dashboard_id=None, shortName=None, dataCube=None) -> requests.Response: if dashboard_id: url = f"http://{host}:{port}/{DEFAULT_PATH}/{dashboard_id}" else: - url = f"http://{host}:{port}/{DEFAULT_PATH}" + query={} + if shortName: + query["shortName"] = shortName + if dataCube: + query["dataCube"] = dataCube + query_str = urlencode(query) + query_str = "?"+query_str if query_str != "" else "" + url = f"http://{host}:{port}/{DEFAULT_PATH}{query_str}" response = requests.get(url) print_response(response) return response diff --git a/test/unit_test.py b/test/unit_test.py index 6a49ca8..740a2be 100644 --- a/test/unit_test.py +++ b/test/unit_test.py @@ -72,7 +72,6 @@ def test_create_dashboard(sample_dashboard: dict[str, Any], with_optional_fields res = delete_dashboard(HOST, PORT, 1) assert res.status_code == 200 - def test_get_all_dashboards(sample_dashboard: dict[str, Any]) -> None: res = get_dashboard(HOST, PORT) assert res.json() == [] @@ -105,6 +104,103 @@ def test_get_all_dashboards(sample_dashboard: dict[str, Any]) -> None: res = delete_dashboard(HOST, PORT, 2) assert res.status_code == 200 +def test_get_all_dashboards_query_params(sample_dashboard: dict[str, Any]) -> None: + res = get_dashboard(HOST, PORT) + assert res.json() == [] + + # Create two dashboards + dashboard = sample_dashboard.copy() + res = create_dashboard(HOST, PORT, json.dumps(dashboard)) + assert res.json()["id"] == 1 + + dashboard = sample_dashboard.copy() + dashboard["dataCube"] = "myDatacube" + dashboard["shortName"] = "shortName" + res = create_dashboard(HOST, PORT, json.dumps(dashboard)) + assert res.json()["id"] == 2 + + #Try with invalid query params + res = get_dashboard(HOST, PORT, shortName=" ") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, dataCube=" ") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, shortName="name;test") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, dataCube="name;test") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, shortName="name?test") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, dataCube="name?test") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, shortName="name'test") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, dataCube="name'test") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, shortName="name\"test") + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, dataCube="name\"test") + assert res.status_code == 400 + long_name = 's'*280 + res = get_dashboard(HOST, PORT, shortName=long_name) + assert res.status_code == 400 + res = get_dashboard(HOST, PORT, dataCube=long_name) + assert res.status_code == 400 + + #Now validate functionality + + #Name only + res = get_dashboard(HOST, PORT, shortName="simple_dashboard") + assert res.status_code == 200 + assert len(res.json()) == 1 + assert res.json()[0]["id"] == 1 + res = get_dashboard(HOST, PORT, shortName="shortName") + assert res.status_code == 200 + assert len(res.json()) == 1 + assert res.json()[0]["id"] == 2 + res = get_dashboard(HOST, PORT, shortName="shortNameA") + assert res.status_code == 200 + assert len(res.json()) == 0 + res = get_dashboard(HOST, PORT, shortName="shortName2") + assert res.status_code == 200 + assert len(res.json()) == 0 + + #DataCube only + res = get_dashboard(HOST, PORT, dataCube="networkFlows") + assert res.status_code == 200 + assert len(res.json()) == 1 + assert res.json()[0]["id"] == 1 + res = get_dashboard(HOST, PORT, dataCube="myDatacube") + assert res.status_code == 200 + assert len(res.json()) == 1 + assert res.json()[0]["id"] == 2 + res = get_dashboard(HOST, PORT, dataCube="networkFlowsA") + assert res.status_code == 200 + assert len(res.json()) == 0 + res = get_dashboard(HOST, PORT, dataCube="networkFlows2") + assert res.status_code == 200 + assert len(res.json()) == 0 + + #Both + res = get_dashboard(HOST, PORT, dataCube="networkFlows", shortName="simple_dashboard") + assert res.status_code == 200 + assert len(res.json()) == 1 + assert res.json()[0]["id"] == 1 + res = get_dashboard(HOST, PORT, dataCube="myDatacube", shortName="shortName") + assert res.status_code == 200 + assert len(res.json()) == 1 + assert res.json()[0]["id"] == 2 + res = get_dashboard(HOST, PORT, dataCube="networkFlowsA", shortName="shortNameA") + assert res.status_code == 200 + assert len(res.json()) == 0 + res = get_dashboard(HOST, PORT, dataCube="networkFlows2", shortName="shortName2") + assert res.status_code == 200 + assert len(res.json()) == 0 + + # Cleanup + res = delete_dashboard(HOST, PORT, 1) + assert res.status_code == 200 + res = delete_dashboard(HOST, PORT, 2) + assert res.status_code == 200 def test_update_dashboard(sample_dashboard: dict[str, Any]) -> None: dashboard = sample_dashboard.copy()