diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..ecd26af --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,25 @@ +name: Tests +on: + push: + pull_request: + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: 3.9 + + - name: Install poetry + run: pip install poetry + + - name: Install dependencies + run: | + poetry config virtualenvs.create false + poetry install + + - name: Run tests + run: | + poetry run pytest diff --git a/README.md b/README.md index 718e7d1..221081d 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,13 @@ $ pip install profyle ### 1. Implement In order to track all your API requests you must implement the ProfyleMiddleware #### ProfyleMiddleware -* enabled : Default true. You can use an env variable to decide if profyle is enabled. -* pattern: Profyle only will trace those paths that match with pattern (glob pattern) +| Attribute | Required | Default | Description | +| --- | --- | --- | --- | +| `enabled` | No | `True` | Enable or disable Profyle | +| `pattern` | No | `None` | 0nly trace those paths that match with pattern (glob pattern) | +| `max_stack_depth` | No | `-1` | Limit maximum stack trace depth | +| `min_duration` | No | `0` (milisecons) | Only record traces with a greather duration than the limit. | +
FastAPI @@ -51,7 +56,27 @@ from fastapi import FastAPI from profyle.fastapi import ProfyleMiddleware app = FastAPI() -app.add_middleware(ProfyleMiddleware, pattern='*/api/v2/*') +# Trace all requests +app.add_middleware(ProfyleMiddleware) + +@app.get("/items/{item_id}") +async def read_item(item_id: int): + return {"item_id": item_id} +``` + +```Python +from fastapi import FastAPI +from profyle.fastapi import ProfyleMiddleware + +app = FastAPI() +# Trace all requests that match that start with /api/products +# with a minimum duration of 100ms and a maximum stack depth of 20 +app.add_middleware( + ProfyleMiddleware, + pattern="/api/products*", + max_stack_depth=20, + min_duration=100 +) @app.get("/items/{item_id}") async def read_item(item_id: int): @@ -68,7 +93,7 @@ from profyle.flask import ProfyleMiddleware app = Flask(__name__) -app.wsgi_app = ProfyleMiddleware(app.wsgi_app, pattern='*/api/products*') +app.wsgi_app = ProfyleMiddleware(app.wsgi_app, pattern="*/api/products*") @app.route("/") def hello_world(): @@ -120,12 +145,19 @@ INFO: Application startup complete. ## CLI Commands ### start * Start the web server and view profile traces + +| Options | Type | Default | Description | +| --- | --- | --- | --- | +| --port | INTEGER | 0 | web server port | +| --host | TEXT | 127.0.0.1 | web server host | + +
```console -$ profyle start +$ profyle start --port 5432 -INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) +INFO: Uvicorn running on http://127.0.0.1:5432 (Press CTRL+C to quit) INFO: Started reloader process [28720] INFO: Started server process [28722] INFO: Waiting for application startup. diff --git a/profyle/application/profyle.py b/profyle/application/profyle.py index 56e12a8..1dd1d4f 100644 --- a/profyle/application/profyle.py +++ b/profyle/application/profyle.py @@ -3,6 +3,7 @@ import re from typing import Optional from dataclasses import dataclass +from tempfile import NamedTemporaryFile from viztracer import VizTracer @@ -15,33 +16,42 @@ class profyle: name: str repo: TraceRepository + max_stack_depth: int = -1 + min_duration: float = 0 pattern: Optional[str] = None - tracer: VizTracer = VizTracer( - verbose=0, - log_async=True - ) + tracer: Optional[VizTracer] = None - def __enter__(self): + def __enter__(self) -> None: if self.should_trace(): + self.tracer = VizTracer( + log_func_args=True, + log_print=True, + log_func_retval=True, + log_async=True, + file_info=True, + min_duration=self.min_duration, + max_stack_depth=self.max_stack_depth + ) self.tracer.start() def __exit__( self, *args, - ): - if not self.tracer.enable: - return - self.tracer.stop() - self.tracer.parse() - new_trace = TraceCreate( - data=json.dumps(self.tracer.data), - name=self.name - ) - store_trace( - new_trace=new_trace, - repo=self.repo - ) + ) -> None: + if self.tracer and self.tracer.enable: + self.tracer.stop() + temp_file = NamedTemporaryFile(suffix=".json") + self.tracer.save(temp_file.name) + temp_file.close() + new_trace = TraceCreate( + data=json.dumps(self.tracer.data), + name=self.name + ) + store_trace( + new_trace=new_trace, + repo=self.repo + ) def should_trace(self) -> bool: if not self.pattern: @@ -49,4 +59,7 @@ def should_trace(self) -> bool: regex = fnmatch.translate(self.pattern) reobj = re.compile(regex) + method_and_name = self.name.split(' ') + if len(method_and_name) > 1: + return bool(reobj.match(method_and_name[1])) return bool(reobj.match(self.name)) diff --git a/profyle/application/trace/delete.py b/profyle/application/trace/delete.py index 90097d6..84246c6 100644 --- a/profyle/application/trace/delete.py +++ b/profyle/application/trace/delete.py @@ -4,5 +4,5 @@ def delete_all_traces(repo: TraceRepository) -> int: return repo.delete_all_traces() -def delete_trace_by_id(repo: TraceRepository, trace_id: int) -> int: - return repo.delete_trace_by_id(trace_id=trace_id) \ No newline at end of file +def delete_trace_by_id(repo: TraceRepository, trace_id: int): + repo.delete_trace_by_id(trace_id=trace_id) \ No newline at end of file diff --git a/profyle/application/trace/get.py b/profyle/application/trace/get.py index 0a80e41..c3f6ed5 100644 --- a/profyle/application/trace/get.py +++ b/profyle/application/trace/get.py @@ -1,9 +1,9 @@ -from typing import List, Optional +from typing import Optional from profyle.domain.trace import Trace from profyle.domain.trace_repository import TraceRepository -def get_all_traces(repo: TraceRepository) -> List[Trace]: +def get_all_traces(repo: TraceRepository) -> list[Trace]: return repo.get_all_traces() diff --git a/profyle/domain/trace.py b/profyle/domain/trace.py index 7397adf..cbdd7a6 100644 --- a/profyle/domain/trace.py +++ b/profyle/domain/trace.py @@ -6,7 +6,7 @@ class Trace(BaseModel): data: Json[Any] = {} name: str duration: float = 0 - timestamp: str = '' + timestamp: str = "" id: int @@ -17,13 +17,21 @@ class TraceCreate(BaseModel): @computed_field @property def duration(self) -> float: + any_trace_to_analize = any( + True + for trace in self.data.get("traceEvents", []) + if trace.get("ts") + ) + if not any_trace_to_analize: + return 0 + start = min( - trace.get('ts') - for trace in self.data.get('traceEvents', []) - if trace.get('ts') + trace.get("ts",0) + for trace in self.data.get("traceEvents", []) + if trace.get("ts") ) end = max( - trace.get('ts', 0) - for trace in self.data.get('traceEvents', []) + trace.get("ts", 0) + for trace in self.data.get("traceEvents", []) ) return end-start diff --git a/profyle/domain/trace_repository.py b/profyle/domain/trace_repository.py index ad012cc..0196a73 100644 --- a/profyle/domain/trace_repository.py +++ b/profyle/domain/trace_repository.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Optional from profyle.domain.trace import Trace, TraceCreate @@ -31,7 +31,7 @@ def store_trace(self, new_trace: TraceCreate) -> None: ... @abstractmethod - def get_all_traces(self) -> List[Trace]: + def get_all_traces(self) -> list[Trace]: ... @abstractmethod @@ -43,5 +43,5 @@ def get_trace_selected(self) -> Optional[int]: ... @abstractmethod - def delete_trace_by_id(self, trace_id: int) -> int: + def delete_trace_by_id(self, trace_id: int): ... diff --git a/profyle/fastapi.py b/profyle/fastapi.py index 745dd55..fb9910e 100644 --- a/profyle/fastapi.py +++ b/profyle/fastapi.py @@ -1 +1 @@ -from .infrastructure.middleware.fastapi import ProfyleMiddleware +from profyle.infrastructure.middleware.fastapi import ProfyleMiddleware \ No newline at end of file diff --git a/profyle/flask.py b/profyle/flask.py index 2d13ed6..5884c49 100644 --- a/profyle/flask.py +++ b/profyle/flask.py @@ -1 +1 @@ -from .infrastructure.middleware.flask import ProfyleMiddleware +from profyle.infrastructure.middleware.flask import ProfyleMiddleware \ No newline at end of file diff --git a/profyle/infrastructure/http_server.py b/profyle/infrastructure/http_server.py index d1da864..96bd5cd 100644 --- a/profyle/infrastructure/http_server.py +++ b/profyle/infrastructure/http_server.py @@ -16,48 +16,48 @@ app = FastAPI( - title='Profyle', - version='1.0.0' + title="Profyle", + version="1.0.0" ) app.add_middleware( CORSMiddleware, - allow_origins=['*'], - allow_methods=['*'], - allow_headers=['*'], + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], ) -@app.on_event('startup') +@app.on_event("startup") async def startup_event(): db = get_connection() sqlite_trace_repo = SQLiteTraceRepository(db) create_trace_table(repo=sqlite_trace_repo) create_trace_selected_table(repo=sqlite_trace_repo) -STATIC_PATH = ('infrastructure', 'web', 'static') +STATIC_PATH = ("infrastructure", "web", "static") app.mount( - '/static', + "/static", StaticFiles(directory=settings.get_path(*STATIC_PATH)), - name='static' + name="static" ) app.mount( - '/show', + "/show", StaticFiles(directory=settings.get_viztracer_static_files(), html=True), - name='perfetto' + name="perfetto" ) -TEMPLATES_PATH = ('infrastructure', 'web', 'templates') +TEMPLATES_PATH = ("infrastructure", "web", "templates") templates = Jinja2Templates(directory=settings.get_path(*TEMPLATES_PATH)) -@app.get('/vizviewer_info') +@app.get("/vizviewer_info") async def vizviewer_info(): - return {'is_flamegraph': False} + return {"is_flamegraph": False} -@app.get('/file_info') +@app.get("/file_info") async def file_info( db: Connection = Depends(get_connection), ): @@ -71,10 +71,10 @@ async def file_info( ) if not trace: return {} - return trace.data.get('file_info') + return trace.data.get("file_info") -@app.get('/localtrace') +@app.get("/localtrace") async def localtrace( db: Connection = Depends(get_connection), ): @@ -91,12 +91,12 @@ async def localtrace( return trace.data -@app.get('/') +@app.get("/") async def index(): - return RedirectResponse('/traces') + return RedirectResponse("/traces") -@app.get('/traces') +@app.get("/traces") async def traces( request: Request, db: Connection = Depends(get_connection), @@ -105,15 +105,15 @@ async def traces( sqlite_trace_repo = SQLiteTraceRepository(db) traces = get_all_traces(repo=sqlite_trace_repo) return templates.TemplateResponse( - name='traces.html', + name="traces.html", context={ - 'request': request, - 'traces': [trace.dict() for trace in traces] + "request": request, + "traces": [trace.dict() for trace in traces] } ) -@app.get('/traces/{id}') +@app.get("/traces/{id}") async def get_trace( id: int, db: Connection = Depends(get_connection), @@ -123,10 +123,10 @@ async def get_trace( trace_id=id, repo=sqlite_trace_repo ) - return RedirectResponse(url='/show') + return RedirectResponse(url="/show") -@app.delete('/traces/{id}', status_code=204) +@app.delete("/traces/{id}", status_code=204) async def delete_trace( id: int, db: Connection = Depends(get_connection), @@ -136,7 +136,7 @@ async def delete_trace( sqlite_trace_repo.delete_trace_by_id(id) -async def start_server(): - config = uvicorn.Config(app, port=0, log_level='info') +async def start_server(port: int = 0, host: str = "127.0.0.1"): + config = uvicorn.Config(app, port=port, log_level="info", host=host) server = uvicorn.Server(config) await server.serve() diff --git a/profyle/infrastructure/middleware/fastapi.py b/profyle/infrastructure/middleware/fastapi.py index 5eb1b7b..4d368d6 100644 --- a/profyle/infrastructure/middleware/fastapi.py +++ b/profyle/infrastructure/middleware/fastapi.py @@ -2,7 +2,6 @@ from starlette.types import ASGIApp, Scope, Receive, Send from profyle.application.profyle import profyle -from profyle.infrastructure.sqlite3.get_connection import get_connection from profyle.infrastructure.sqlite3.repository import SQLiteTraceRepository @@ -11,20 +10,28 @@ def __init__( self, app: ASGIApp, enabled: bool = True, - pattern: Optional[str] = None + pattern: Optional[str] = None, + max_stack_depth: int = -1, + min_duration: int = 0, + trace_repo: SQLiteTraceRepository = SQLiteTraceRepository() ): self.app = app self.enabled = enabled self.pattern = pattern + self.max_stack_depth = max_stack_depth + self.min_duration = min_duration + self.trace_repo = trace_repo async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if self.enabled and scope['type'] == 'http': - db = get_connection() - sqlite_repo = SQLiteTraceRepository(db) + if self.enabled and scope["type"] == "http": + method = scope.get('method', '').upper() + path = scope.get('raw_path', b'').decode('utf-8') with profyle( - name=scope['raw_path'].decode("utf-8"), + name=f"{method} {path}", pattern=self.pattern, - repo=sqlite_repo + repo=self.trace_repo, + max_stack_depth=self.max_stack_depth, + min_duration=self.min_duration ): await self.app(scope, receive, send) return diff --git a/profyle/infrastructure/middleware/flask.py b/profyle/infrastructure/middleware/flask.py index 2f267a9..6cb18ac 100644 --- a/profyle/infrastructure/middleware/flask.py +++ b/profyle/infrastructure/middleware/flask.py @@ -1,6 +1,7 @@ from typing import Optional + from profyle.application.profyle import profyle -from profyle.infrastructure.sqlite3.get_connection import get_connection +from profyle.domain.trace_repository import TraceRepository from profyle.infrastructure.sqlite3.repository import SQLiteTraceRepository @@ -9,20 +10,28 @@ def __init__( self, app, enabled: bool = True, - pattern: Optional[str] = None + pattern: Optional[str] = None, + max_stack_depth: int = -1, + min_duration: int = 0, + trace_repo: TraceRepository = SQLiteTraceRepository() ): self.app = app self.enabled = enabled self.pattern = pattern + self.max_stack_depth = max_stack_depth + self.min_duration = min_duration + self.trace_repo = trace_repo def __call__(self, environ, start_response): - if environ.get('wsgi.url_scheme') == 'http' and self.enabled: - db = get_connection() - sqlite_repo = SQLiteTraceRepository(db) + if environ.get("wsgi.url_scheme") == "http" and self.enabled: + method = environ.get("REQUEST_METHOD", "").upper() + path = environ.get("REQUEST_URI") with profyle( - name=environ['REQUEST_URI'], + name=f"{method} {path}", pattern=self.pattern, - repo=sqlite_repo + max_stack_depth=self.max_stack_depth, + min_duration=self.min_duration, + repo=self.trace_repo ): return self.app(environ, start_response) return self.app(environ, start_response) diff --git a/profyle/infrastructure/sqlite3/get_connection.py b/profyle/infrastructure/sqlite3/get_connection.py index 65a2910..23cb92d 100644 --- a/profyle/infrastructure/sqlite3/get_connection.py +++ b/profyle/infrastructure/sqlite3/get_connection.py @@ -6,7 +6,7 @@ def get_connection() -> Connection: db = sqlite3.connect( - settings.get_path('profile.db'), + settings.get_path("profile.db"), check_same_thread=False ) return db diff --git a/profyle/infrastructure/sqlite3/repository.py b/profyle/infrastructure/sqlite3/repository.py index 7994c13..7026dfa 100644 --- a/profyle/infrastructure/sqlite3/repository.py +++ b/profyle/infrastructure/sqlite3/repository.py @@ -1,13 +1,16 @@ from sqlite3 import Connection, Error, Row -from typing import List, Optional +from typing import Optional import json from profyle.domain.trace import Trace, TraceCreate from profyle.domain.trace_repository import TraceRepository +from profyle.infrastructure.sqlite3.get_connection import get_connection class SQLiteTraceRepository(TraceRepository): - def __init__(self, db: Connection): + def __init__(self, db: Optional[Connection] = None): + if not db: + db = get_connection() self.db = db def create_trace_selected_table(self) -> None: @@ -96,7 +99,7 @@ def store_trace(self, trace: TraceCreate) -> None: except Error as error: print("Failed to insert data into trace table", error) - def get_all_traces(self) -> List[Trace]: + def get_all_traces(self) -> list[Trace]: self.db.row_factory = Row cursor = self.db.cursor() cursor.execute( @@ -123,7 +126,7 @@ def get_trace_selected(self) -> Optional[int]: cursor.execute( "SELECT trace_id FROM trace_selected where id = ?", (1,)) trace = cursor.fetchone() - return trace['trace_id'] if trace else None + return trace["trace_id"] if trace else None def delete_trace_by_id(self, trace_id: int): cursor = self.db.cursor() diff --git a/profyle/infrastructure/web/templates/traces.html b/profyle/infrastructure/web/templates/traces.html index 6491958..30c8542 100644 --- a/profyle/infrastructure/web/templates/traces.html +++ b/profyle/infrastructure/web/templates/traces.html @@ -10,16 +10,19 @@ -

- - Profyl - - - - -

+
+

+ + Profyl + + + + +

+
+
@@ -73,10 +76,20 @@

No traces

+ {%set method_and_name = trace.name.split(' ')%} + + {% if method_and_name[1] %} + + {{method_and_name[0]}} + + {{method_and_name[1]}} + {% else %} {{trace.name}} + {% endif %} - {{trace.duration / 1000}} + {{(trace.duration / 1000) | round(2)}} {{trace.timestamp}} diff --git a/profyle/main.py b/profyle/main.py index 9b59f67..e0e7788 100644 --- a/profyle/main.py +++ b/profyle/main.py @@ -17,8 +17,8 @@ @app.command() -def start(): - asyncio.run(start_server()) +def start(port: int = 0, host: str = "127.0.0.1"): + asyncio.run(start_server(port=port, host=host)) @app.command() @@ -27,16 +27,16 @@ def clean(): sqlite_repo = SQLiteTraceRepository(db) removed_traces = delete_all_traces(sqlite_repo) vacuum(sqlite_repo) - print(f'[green]{removed_traces} traces removed [/green]') + print(f"[green]{removed_traces} traces removed [/green]") @app.command() def check(): - db_size_in_bytes = os.path.getsize(settings.get_path('profile.db')) + db_size_in_bytes = os.path.getsize(settings.get_path("profile.db")) db_size_in_megabytes = round(db_size_in_bytes/10**6, 2) db_size_in_gigabytes = round(db_size_in_megabytes/10**3, 2) if db_size_in_megabytes > 1000: - print(f'[orange1]DB size: {db_size_in_gigabytes} GB [/orange1]') + print(f"[orange1]DB size: {db_size_in_gigabytes} GB [/orange1]") return - print(f'[orange1]DB size: {db_size_in_megabytes} MB [/orange1]') + print(f"[orange1]DB size: {db_size_in_megabytes} MB [/orange1]") diff --git a/profyle/settings.py b/profyle/settings.py index 7e630ff..2f70ead 100644 --- a/profyle/settings.py +++ b/profyle/settings.py @@ -5,11 +5,11 @@ class Settings(BaseSettings): - app_name: str = 'Profyle' + app_name: str = "Profyle" project_dir: str = os.path.normpath( os.path.join( os.path.abspath(__file__), - '..', + "..", ) ) @@ -22,8 +22,8 @@ def get_path(self, *args): def get_viztracer_static_files(self): return os.path.normpath(os.path.join( os.path.abspath(viztracer.__file__), - '..', - 'web_dist' + "..", + "web_dist" )) diff --git a/pyproject.toml b/pyproject.toml index 218f182..bceaac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ autopep8 = "^2.0.1" pytest = "^7.2.1" httpx = "^0.23.3" flask = "^2.0.0" +pytest-asyncio = "^0.21.1" [tool.poetry.group.flask.dependencies] diff --git a/tests/middleware/test_fastapi_middleware.py b/tests/middleware/test_fastapi_middleware.py deleted file mode 100644 index c6e3d8f..0000000 --- a/tests/middleware/test_fastapi_middleware.py +++ /dev/null @@ -1,29 +0,0 @@ -from fastapi import APIRouter, FastAPI -from fastapi.testclient import TestClient - -from profyle.fastapi import ProfyleMiddleware - -app = FastAPI() -app.add_middleware(ProfyleMiddleware, pattern='*test[?]*') - -router = APIRouter() - - -@router.post('/test') -def run_middleware(): - return {'message': 'OK'} - - -@router.post('/test1') -def run_middleware_1(): - return {'message': 'OK'} - - -app.include_router(router) - -client = TestClient(app) - - -def test_fastapi_middleware(): - client.post('test?demo=true') - client.post('test1') diff --git a/tests/middleware/test_flask_middleware.py b/tests/middleware/test_flask_middleware.py deleted file mode 100644 index e7190cd..0000000 --- a/tests/middleware/test_flask_middleware.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -from flask import Flask - -from profyle.flask import ProfyleMiddleware - - -app = Flask('flask_test', root_path=os.path.dirname(__file__)) -app.config.update( - TESTING=True, - SECRET_KEY='test key', -) - -app.wsgi_app = ProfyleMiddleware(app.wsgi_app, pattern='*test*') - - -@app.route('/test') -def index(): - return 'Test' - - -def test_profyle_middleware(): - app.test_client().get('/test') diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/application/__init__.py b/tests/unit/application/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/application/test_profyle.py b/tests/unit/application/test_profyle.py new file mode 100644 index 0000000..2a264fa --- /dev/null +++ b/tests/unit/application/test_profyle.py @@ -0,0 +1,64 @@ +import pytest +import asyncio + +from profyle.application.profyle import profyle +from tests.unit.repository import InMemoryTraceRepository + + +def test_should_trace_a_process(): + trace_repo = InMemoryTraceRepository() + + with profyle( + name="test", + repo=trace_repo, + ): + print("demo") + + assert len(trace_repo.traces) == 1 + assert len(trace_repo.traces[0].data) + + +@pytest.mark.asyncio +async def test_should_trace_an_async_process(): + trace_repo = InMemoryTraceRepository() + + with profyle( + name="test", + repo=trace_repo, + ): + await asyncio.sleep(0.1) + + assert len(trace_repo.traces) == 1 + assert len(trace_repo.traces[0].data) + assert trace_repo.traces[0].duration/1000 > 0.1 + + +@pytest.mark.asyncio +async def test_should_trace_a_process_with_min_duration(): + trace_repo = InMemoryTraceRepository() + + with profyle( + name="test", + repo=trace_repo, + min_duration=1 + ): + await asyncio.sleep(2) + + assert len(trace_repo.traces) == 1 + assert len(trace_repo.traces[0].data) + assert trace_repo.traces[0].duration/1000 > 2 + + +@pytest.mark.asyncio +async def test_should_not_trace_a_process_if_min_duration_not_reached(): + trace_repo = InMemoryTraceRepository() + with profyle( + name="test", + repo=trace_repo, + min_duration=3000 + ): + await asyncio.sleep(2) + + assert len(trace_repo.traces) == 1 + assert len(trace_repo.traces[0].data) + assert int(trace_repo.traces[0].duration/1000) == 0 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..533e742 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,65 @@ + +import pytest +import os + +from flask import Flask +from fastapi import APIRouter, FastAPI +from fastapi.testclient import TestClient + +@pytest.fixture +def fastapi_app(): + app = FastAPI() + + router = APIRouter() + + @router.post('/test-post') + async def test_post(): + return {'message': 'OK'} + + @router.get('/test-get') + async def test_get(demo: bool = False): + return {'message': demo} + + @router.patch('/test-patch') + async def test_patch(): + return {'message': 'OK'} + + @router.put('/test-put') + async def test_put(): + return {'message': 'OK'} + + app.include_router(router) + + yield app + + +@pytest.fixture +def flask_app(): + app = Flask('flask_test', root_path=os.path.dirname(__file__)) + app.config.update( + TESTING=True, + SECRET_KEY='test key', + ) + + @app.route('/test-post', methods=['POST']) + def test_post(): + return 'Test' + + @app.route('/test-get', methods=['GET']) + def test_get(): + return 'Test' + + @app.route('/test-patch', methods=['PATCH']) + def test_patch(): + return 'Test' + + yield app + +@pytest.fixture +def flask_client(flask_app): + yield flask_app.test_client() + + +@pytest.fixture() +def fastapi_client(fastapi_app): + yield TestClient(fastapi_app) diff --git a/tests/unit/infrastructure/__init__.py b/tests/unit/infrastructure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/infrastructure/middleware/__init__.py b/tests/unit/infrastructure/middleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/infrastructure/middleware/test_fastapi_middleware.py b/tests/unit/infrastructure/middleware/test_fastapi_middleware.py new file mode 100644 index 0000000..6ea22cb --- /dev/null +++ b/tests/unit/infrastructure/middleware/test_fastapi_middleware.py @@ -0,0 +1,48 @@ +from profyle.fastapi import ProfyleMiddleware +from tests.unit.repository import InMemoryTraceRepository + + +def test_should_trace_all_requests(fastapi_client, fastapi_app): + trace_repo = InMemoryTraceRepository() + fastapi_app.add_middleware( + ProfyleMiddleware, + trace_repo=trace_repo + ) + + fastapi_client.post("test") + fastapi_client.get("test?demo=true") + + assert len(trace_repo.traces) == 2 + assert trace_repo.traces[0].name == "POST /test" + assert trace_repo.traces[1].name == "GET /test?demo=true" + + +def test_should_trace_filtered_requests(fastapi_client, fastapi_app): + trace_repo = InMemoryTraceRepository() + fastapi_app.add_middleware( + ProfyleMiddleware, + pattern="/test*", + trace_repo=trace_repo + ) + + fastapi_client.post("test") + fastapi_client.get("test?demo=true") + fastapi_client.get("other") + + assert len(trace_repo.traces) == 2 + assert trace_repo.traces[0].name == "POST /test" + assert trace_repo.traces[1].name == "GET /test?demo=true" + + +def test_should_no_trace_if_disabled(fastapi_client, fastapi_app): + trace_repo = InMemoryTraceRepository() + fastapi_app.add_middleware( + ProfyleMiddleware, + enabled=False, + trace_repo=InMemoryTraceRepository() + ) + + fastapi_client.post("test") + fastapi_client.get("test?demo=true") + + assert len(trace_repo.traces) == 0 diff --git a/tests/unit/infrastructure/middleware/test_flask_middleware.py b/tests/unit/infrastructure/middleware/test_flask_middleware.py new file mode 100644 index 0000000..41af5ff --- /dev/null +++ b/tests/unit/infrastructure/middleware/test_flask_middleware.py @@ -0,0 +1,48 @@ +from profyle.infrastructure.middleware.flask import ProfyleMiddleware +from tests.unit.repository import InMemoryTraceRepository + + +def test_should_trace_all_requests(flask_client, flask_app): + trace_repo = InMemoryTraceRepository() + flask_app.wsgi_app = ProfyleMiddleware( + flask_app.wsgi_app, + trace_repo=trace_repo + ) + + flask_client.post("/test") + flask_client.get("/test?demo=true") + + assert len(trace_repo.traces) == 2 + assert trace_repo.traces[0].name == "POST /test" + assert trace_repo.traces[1].name == "GET /test?demo=true" + + +def test_should_trace_filtered_requests(flask_client, flask_app): + trace_repo = InMemoryTraceRepository() + flask_app.wsgi_app = ProfyleMiddleware( + flask_app.wsgi_app, + trace_repo=trace_repo, + pattern="/test*", + ) + + flask_client.post("/test") + flask_client.get("/test?demo=true") + flask_client.get("/other") + + assert len(trace_repo.traces) == 2 + assert trace_repo.traces[0].name == "POST /test" + assert trace_repo.traces[1].name == "GET /test?demo=true" + + +def test_should_no_trace_if_disabled(flask_client, flask_app): + trace_repo = InMemoryTraceRepository() + flask_app.wsgi_app = ProfyleMiddleware( + flask_app.wsgi_app, + trace_repo=trace_repo, + enabled=False, + ) + + flask_client.post("/test") + flask_client.get("/test?demo=true") + + assert len(trace_repo.traces) == 0 diff --git a/tests/unit/repository.py b/tests/unit/repository.py new file mode 100644 index 0000000..8e1cc26 --- /dev/null +++ b/tests/unit/repository.py @@ -0,0 +1,60 @@ +from typing import Optional +from uuid import uuid4 +import time +import json + +from profyle.domain.trace import TraceCreate +from profyle.domain.trace_repository import TraceRepository +from profyle.domain.trace import Trace + + +class InMemoryTraceRepository(TraceRepository): + def __init__(self): + self.traces: list[Trace] = [] + self.selected_trace: int = 0 + + def create_trace_selected_table(self) -> None: + ... + + def create_trace_table(self) -> None: + ... + + def delete_all_traces(self) -> int: + removed = len(self.traces) + self.traces = [] + return removed + + def vacuum(self) -> None: + ... + + def store_trace_selected(self, trace_id: int) -> None: + self.selected_trace = trace_id + + def store_trace(self, new_trace: TraceCreate) -> None: + + trace = Trace( + id=uuid4().int, + timestamp=str(time.time()), + data= json.dumps(new_trace.data), + duration=new_trace.duration, + name=new_trace.name, + ) + self.traces.append(trace) + + def get_all_traces(self) -> list[Trace]: + return self.traces + + def get_trace_by_id(self, id: int) -> Optional[Trace]: + for trace in self.traces: + if trace.id == id: + return trace + return + + def get_trace_selected(self) -> Optional[int]: + return self.selected_trace + + def delete_trace_by_id(self, trace_id: int): + for trace in self.traces: + if trace.id == trace_id: + self.traces.remove(trace) + return