```console
-$ profyle start
+$ profyle start --port 5432
-INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
+INFO: Uvicorn running on http://127.0.0.1:5432 (Press CTRL+C to quit)
INFO: Started reloader process [28720]
INFO: Started server process [28722]
INFO: Waiting for application startup.
diff --git a/profyle/application/profyle.py b/profyle/application/profyle.py
index 56e12a8..1dd1d4f 100644
--- a/profyle/application/profyle.py
+++ b/profyle/application/profyle.py
@@ -3,6 +3,7 @@
import re
from typing import Optional
from dataclasses import dataclass
+from tempfile import NamedTemporaryFile
from viztracer import VizTracer
@@ -15,33 +16,42 @@
class profyle:
name: str
repo: TraceRepository
+ max_stack_depth: int = -1
+ min_duration: float = 0
pattern: Optional[str] = None
- tracer: VizTracer = VizTracer(
- verbose=0,
- log_async=True
- )
+ tracer: Optional[VizTracer] = None
- def __enter__(self):
+ def __enter__(self) -> None:
if self.should_trace():
+ self.tracer = VizTracer(
+ log_func_args=True,
+ log_print=True,
+ log_func_retval=True,
+ log_async=True,
+ file_info=True,
+ min_duration=self.min_duration,
+ max_stack_depth=self.max_stack_depth
+ )
self.tracer.start()
def __exit__(
self,
*args,
- ):
- if not self.tracer.enable:
- return
- self.tracer.stop()
- self.tracer.parse()
- new_trace = TraceCreate(
- data=json.dumps(self.tracer.data),
- name=self.name
- )
- store_trace(
- new_trace=new_trace,
- repo=self.repo
- )
+ ) -> None:
+ if self.tracer and self.tracer.enable:
+ self.tracer.stop()
+ temp_file = NamedTemporaryFile(suffix=".json")
+ self.tracer.save(temp_file.name)
+ temp_file.close()
+ new_trace = TraceCreate(
+ data=json.dumps(self.tracer.data),
+ name=self.name
+ )
+ store_trace(
+ new_trace=new_trace,
+ repo=self.repo
+ )
def should_trace(self) -> bool:
if not self.pattern:
@@ -49,4 +59,7 @@ def should_trace(self) -> bool:
regex = fnmatch.translate(self.pattern)
reobj = re.compile(regex)
+ method_and_name = self.name.split(' ')
+ if len(method_and_name) > 1:
+ return bool(reobj.match(method_and_name[1]))
return bool(reobj.match(self.name))
diff --git a/profyle/application/trace/delete.py b/profyle/application/trace/delete.py
index 90097d6..84246c6 100644
--- a/profyle/application/trace/delete.py
+++ b/profyle/application/trace/delete.py
@@ -4,5 +4,5 @@
def delete_all_traces(repo: TraceRepository) -> int:
return repo.delete_all_traces()
-def delete_trace_by_id(repo: TraceRepository, trace_id: int) -> int:
- return repo.delete_trace_by_id(trace_id=trace_id)
\ No newline at end of file
+def delete_trace_by_id(repo: TraceRepository, trace_id: int):
+ repo.delete_trace_by_id(trace_id=trace_id)
\ No newline at end of file
diff --git a/profyle/application/trace/get.py b/profyle/application/trace/get.py
index 0a80e41..c3f6ed5 100644
--- a/profyle/application/trace/get.py
+++ b/profyle/application/trace/get.py
@@ -1,9 +1,9 @@
-from typing import List, Optional
+from typing import Optional
from profyle.domain.trace import Trace
from profyle.domain.trace_repository import TraceRepository
-def get_all_traces(repo: TraceRepository) -> List[Trace]:
+def get_all_traces(repo: TraceRepository) -> list[Trace]:
return repo.get_all_traces()
diff --git a/profyle/domain/trace.py b/profyle/domain/trace.py
index 7397adf..cbdd7a6 100644
--- a/profyle/domain/trace.py
+++ b/profyle/domain/trace.py
@@ -6,7 +6,7 @@ class Trace(BaseModel):
data: Json[Any] = {}
name: str
duration: float = 0
- timestamp: str = ''
+ timestamp: str = ""
id: int
@@ -17,13 +17,21 @@ class TraceCreate(BaseModel):
@computed_field
@property
def duration(self) -> float:
+ any_trace_to_analize = any(
+ True
+ for trace in self.data.get("traceEvents", [])
+ if trace.get("ts")
+ )
+ if not any_trace_to_analize:
+ return 0
+
start = min(
- trace.get('ts')
- for trace in self.data.get('traceEvents', [])
- if trace.get('ts')
+ trace.get("ts",0)
+ for trace in self.data.get("traceEvents", [])
+ if trace.get("ts")
)
end = max(
- trace.get('ts', 0)
- for trace in self.data.get('traceEvents', [])
+ trace.get("ts", 0)
+ for trace in self.data.get("traceEvents", [])
)
return end-start
diff --git a/profyle/domain/trace_repository.py b/profyle/domain/trace_repository.py
index ad012cc..0196a73 100644
--- a/profyle/domain/trace_repository.py
+++ b/profyle/domain/trace_repository.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Optional
from profyle.domain.trace import Trace, TraceCreate
@@ -31,7 +31,7 @@ def store_trace(self, new_trace: TraceCreate) -> None:
...
@abstractmethod
- def get_all_traces(self) -> List[Trace]:
+ def get_all_traces(self) -> list[Trace]:
...
@abstractmethod
@@ -43,5 +43,5 @@ def get_trace_selected(self) -> Optional[int]:
...
@abstractmethod
- def delete_trace_by_id(self, trace_id: int) -> int:
+ def delete_trace_by_id(self, trace_id: int):
...
diff --git a/profyle/fastapi.py b/profyle/fastapi.py
index 745dd55..fb9910e 100644
--- a/profyle/fastapi.py
+++ b/profyle/fastapi.py
@@ -1 +1 @@
-from .infrastructure.middleware.fastapi import ProfyleMiddleware
+from profyle.infrastructure.middleware.fastapi import ProfyleMiddleware
\ No newline at end of file
diff --git a/profyle/flask.py b/profyle/flask.py
index 2d13ed6..5884c49 100644
--- a/profyle/flask.py
+++ b/profyle/flask.py
@@ -1 +1 @@
-from .infrastructure.middleware.flask import ProfyleMiddleware
+from profyle.infrastructure.middleware.flask import ProfyleMiddleware
\ No newline at end of file
diff --git a/profyle/infrastructure/http_server.py b/profyle/infrastructure/http_server.py
index d1da864..96bd5cd 100644
--- a/profyle/infrastructure/http_server.py
+++ b/profyle/infrastructure/http_server.py
@@ -16,48 +16,48 @@
app = FastAPI(
- title='Profyle',
- version='1.0.0'
+ title="Profyle",
+ version="1.0.0"
)
app.add_middleware(
CORSMiddleware,
- allow_origins=['*'],
- allow_methods=['*'],
- allow_headers=['*'],
+ allow_origins=["*"],
+ allow_methods=["*"],
+ allow_headers=["*"],
)
-@app.on_event('startup')
+@app.on_event("startup")
async def startup_event():
db = get_connection()
sqlite_trace_repo = SQLiteTraceRepository(db)
create_trace_table(repo=sqlite_trace_repo)
create_trace_selected_table(repo=sqlite_trace_repo)
-STATIC_PATH = ('infrastructure', 'web', 'static')
+STATIC_PATH = ("infrastructure", "web", "static")
app.mount(
- '/static',
+ "/static",
StaticFiles(directory=settings.get_path(*STATIC_PATH)),
- name='static'
+ name="static"
)
app.mount(
- '/show',
+ "/show",
StaticFiles(directory=settings.get_viztracer_static_files(), html=True),
- name='perfetto'
+ name="perfetto"
)
-TEMPLATES_PATH = ('infrastructure', 'web', 'templates')
+TEMPLATES_PATH = ("infrastructure", "web", "templates")
templates = Jinja2Templates(directory=settings.get_path(*TEMPLATES_PATH))
-@app.get('/vizviewer_info')
+@app.get("/vizviewer_info")
async def vizviewer_info():
- return {'is_flamegraph': False}
+ return {"is_flamegraph": False}
-@app.get('/file_info')
+@app.get("/file_info")
async def file_info(
db: Connection = Depends(get_connection),
):
@@ -71,10 +71,10 @@ async def file_info(
)
if not trace:
return {}
- return trace.data.get('file_info')
+ return trace.data.get("file_info")
-@app.get('/localtrace')
+@app.get("/localtrace")
async def localtrace(
db: Connection = Depends(get_connection),
):
@@ -91,12 +91,12 @@ async def localtrace(
return trace.data
-@app.get('/')
+@app.get("/")
async def index():
- return RedirectResponse('/traces')
+ return RedirectResponse("/traces")
-@app.get('/traces')
+@app.get("/traces")
async def traces(
request: Request,
db: Connection = Depends(get_connection),
@@ -105,15 +105,15 @@ async def traces(
sqlite_trace_repo = SQLiteTraceRepository(db)
traces = get_all_traces(repo=sqlite_trace_repo)
return templates.TemplateResponse(
- name='traces.html',
+ name="traces.html",
context={
- 'request': request,
- 'traces': [trace.dict() for trace in traces]
+ "request": request,
+ "traces": [trace.dict() for trace in traces]
}
)
-@app.get('/traces/{id}')
+@app.get("/traces/{id}")
async def get_trace(
id: int,
db: Connection = Depends(get_connection),
@@ -123,10 +123,10 @@ async def get_trace(
trace_id=id,
repo=sqlite_trace_repo
)
- return RedirectResponse(url='/show')
+ return RedirectResponse(url="/show")
-@app.delete('/traces/{id}', status_code=204)
+@app.delete("/traces/{id}", status_code=204)
async def delete_trace(
id: int,
db: Connection = Depends(get_connection),
@@ -136,7 +136,7 @@ async def delete_trace(
sqlite_trace_repo.delete_trace_by_id(id)
-async def start_server():
- config = uvicorn.Config(app, port=0, log_level='info')
+async def start_server(port: int = 0, host: str = "127.0.0.1"):
+ config = uvicorn.Config(app, port=port, log_level="info", host=host)
server = uvicorn.Server(config)
await server.serve()
diff --git a/profyle/infrastructure/middleware/fastapi.py b/profyle/infrastructure/middleware/fastapi.py
index 5eb1b7b..4d368d6 100644
--- a/profyle/infrastructure/middleware/fastapi.py
+++ b/profyle/infrastructure/middleware/fastapi.py
@@ -2,7 +2,6 @@
from starlette.types import ASGIApp, Scope, Receive, Send
from profyle.application.profyle import profyle
-from profyle.infrastructure.sqlite3.get_connection import get_connection
from profyle.infrastructure.sqlite3.repository import SQLiteTraceRepository
@@ -11,20 +10,28 @@ def __init__(
self,
app: ASGIApp,
enabled: bool = True,
- pattern: Optional[str] = None
+ pattern: Optional[str] = None,
+ max_stack_depth: int = -1,
+ min_duration: int = 0,
+ trace_repo: SQLiteTraceRepository = SQLiteTraceRepository()
):
self.app = app
self.enabled = enabled
self.pattern = pattern
+ self.max_stack_depth = max_stack_depth
+ self.min_duration = min_duration
+ self.trace_repo = trace_repo
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- if self.enabled and scope['type'] == 'http':
- db = get_connection()
- sqlite_repo = SQLiteTraceRepository(db)
+ if self.enabled and scope["type"] == "http":
+ method = scope.get('method', '').upper()
+ path = scope.get('raw_path', b'').decode('utf-8')
with profyle(
- name=scope['raw_path'].decode("utf-8"),
+ name=f"{method} {path}",
pattern=self.pattern,
- repo=sqlite_repo
+ repo=self.trace_repo,
+ max_stack_depth=self.max_stack_depth,
+ min_duration=self.min_duration
):
await self.app(scope, receive, send)
return
diff --git a/profyle/infrastructure/middleware/flask.py b/profyle/infrastructure/middleware/flask.py
index 2f267a9..6cb18ac 100644
--- a/profyle/infrastructure/middleware/flask.py
+++ b/profyle/infrastructure/middleware/flask.py
@@ -1,6 +1,7 @@
from typing import Optional
+
from profyle.application.profyle import profyle
-from profyle.infrastructure.sqlite3.get_connection import get_connection
+from profyle.domain.trace_repository import TraceRepository
from profyle.infrastructure.sqlite3.repository import SQLiteTraceRepository
@@ -9,20 +10,28 @@ def __init__(
self,
app,
enabled: bool = True,
- pattern: Optional[str] = None
+ pattern: Optional[str] = None,
+ max_stack_depth: int = -1,
+ min_duration: int = 0,
+ trace_repo: TraceRepository = SQLiteTraceRepository()
):
self.app = app
self.enabled = enabled
self.pattern = pattern
+ self.max_stack_depth = max_stack_depth
+ self.min_duration = min_duration
+ self.trace_repo = trace_repo
def __call__(self, environ, start_response):
- if environ.get('wsgi.url_scheme') == 'http' and self.enabled:
- db = get_connection()
- sqlite_repo = SQLiteTraceRepository(db)
+ if environ.get("wsgi.url_scheme") == "http" and self.enabled:
+ method = environ.get("REQUEST_METHOD", "").upper()
+ path = environ.get("REQUEST_URI")
with profyle(
- name=environ['REQUEST_URI'],
+ name=f"{method} {path}",
pattern=self.pattern,
- repo=sqlite_repo
+ max_stack_depth=self.max_stack_depth,
+ min_duration=self.min_duration,
+ repo=self.trace_repo
):
return self.app(environ, start_response)
return self.app(environ, start_response)
diff --git a/profyle/infrastructure/sqlite3/get_connection.py b/profyle/infrastructure/sqlite3/get_connection.py
index 65a2910..23cb92d 100644
--- a/profyle/infrastructure/sqlite3/get_connection.py
+++ b/profyle/infrastructure/sqlite3/get_connection.py
@@ -6,7 +6,7 @@
def get_connection() -> Connection:
db = sqlite3.connect(
- settings.get_path('profile.db'),
+ settings.get_path("profile.db"),
check_same_thread=False
)
return db
diff --git a/profyle/infrastructure/sqlite3/repository.py b/profyle/infrastructure/sqlite3/repository.py
index 7994c13..7026dfa 100644
--- a/profyle/infrastructure/sqlite3/repository.py
+++ b/profyle/infrastructure/sqlite3/repository.py
@@ -1,13 +1,16 @@
from sqlite3 import Connection, Error, Row
-from typing import List, Optional
+from typing import Optional
import json
from profyle.domain.trace import Trace, TraceCreate
from profyle.domain.trace_repository import TraceRepository
+from profyle.infrastructure.sqlite3.get_connection import get_connection
class SQLiteTraceRepository(TraceRepository):
- def __init__(self, db: Connection):
+ def __init__(self, db: Optional[Connection] = None):
+ if not db:
+ db = get_connection()
self.db = db
def create_trace_selected_table(self) -> None:
@@ -96,7 +99,7 @@ def store_trace(self, trace: TraceCreate) -> None:
except Error as error:
print("Failed to insert data into trace table", error)
- def get_all_traces(self) -> List[Trace]:
+ def get_all_traces(self) -> list[Trace]:
self.db.row_factory = Row
cursor = self.db.cursor()
cursor.execute(
@@ -123,7 +126,7 @@ def get_trace_selected(self) -> Optional[int]:
cursor.execute(
"SELECT trace_id FROM trace_selected where id = ?", (1,))
trace = cursor.fetchone()
- return trace['trace_id'] if trace else None
+ return trace["trace_id"] if trace else None
def delete_trace_by_id(self, trace_id: int):
cursor = self.db.cursor()
diff --git a/profyle/infrastructure/web/templates/traces.html b/profyle/infrastructure/web/templates/traces.html
index 6491958..30c8542 100644
--- a/profyle/infrastructure/web/templates/traces.html
+++ b/profyle/infrastructure/web/templates/traces.html
@@ -10,16 +10,19 @@
-
-
- Profyl
-
-
-
+
+
@@ -73,10 +76,20 @@ No traces
+ {%set method_and_name = trace.name.split(' ')%}
+
+ {% if method_and_name[1] %}
+
+ {{method_and_name[0]}}
+
+ {{method_and_name[1]}}
+ {% else %}
{{trace.name}}
+ {% endif %}
|
- {{trace.duration / 1000}}
+ {{(trace.duration / 1000) | round(2)}}
|
{{trace.timestamp}}
diff --git a/profyle/main.py b/profyle/main.py
index 9b59f67..e0e7788 100644
--- a/profyle/main.py
+++ b/profyle/main.py
@@ -17,8 +17,8 @@
@app.command()
-def start():
- asyncio.run(start_server())
+def start(port: int = 0, host: str = "127.0.0.1"):
+ asyncio.run(start_server(port=port, host=host))
@app.command()
@@ -27,16 +27,16 @@ def clean():
sqlite_repo = SQLiteTraceRepository(db)
removed_traces = delete_all_traces(sqlite_repo)
vacuum(sqlite_repo)
- print(f'[green]{removed_traces} traces removed [/green]')
+ print(f"[green]{removed_traces} traces removed [/green]")
@app.command()
def check():
- db_size_in_bytes = os.path.getsize(settings.get_path('profile.db'))
+ db_size_in_bytes = os.path.getsize(settings.get_path("profile.db"))
db_size_in_megabytes = round(db_size_in_bytes/10**6, 2)
db_size_in_gigabytes = round(db_size_in_megabytes/10**3, 2)
if db_size_in_megabytes > 1000:
- print(f'[orange1]DB size: {db_size_in_gigabytes} GB [/orange1]')
+ print(f"[orange1]DB size: {db_size_in_gigabytes} GB [/orange1]")
return
- print(f'[orange1]DB size: {db_size_in_megabytes} MB [/orange1]')
+ print(f"[orange1]DB size: {db_size_in_megabytes} MB [/orange1]")
diff --git a/profyle/settings.py b/profyle/settings.py
index 7e630ff..2f70ead 100644
--- a/profyle/settings.py
+++ b/profyle/settings.py
@@ -5,11 +5,11 @@
class Settings(BaseSettings):
- app_name: str = 'Profyle'
+ app_name: str = "Profyle"
project_dir: str = os.path.normpath(
os.path.join(
os.path.abspath(__file__),
- '..',
+ "..",
)
)
@@ -22,8 +22,8 @@ def get_path(self, *args):
def get_viztracer_static_files(self):
return os.path.normpath(os.path.join(
os.path.abspath(viztracer.__file__),
- '..',
- 'web_dist'
+ "..",
+ "web_dist"
))
diff --git a/pyproject.toml b/pyproject.toml
index 218f182..bceaac0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,7 @@ autopep8 = "^2.0.1"
pytest = "^7.2.1"
httpx = "^0.23.3"
flask = "^2.0.0"
+pytest-asyncio = "^0.21.1"
[tool.poetry.group.flask.dependencies]
diff --git a/tests/middleware/test_fastapi_middleware.py b/tests/middleware/test_fastapi_middleware.py
deleted file mode 100644
index c6e3d8f..0000000
--- a/tests/middleware/test_fastapi_middleware.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from fastapi import APIRouter, FastAPI
-from fastapi.testclient import TestClient
-
-from profyle.fastapi import ProfyleMiddleware
-
-app = FastAPI()
-app.add_middleware(ProfyleMiddleware, pattern='*test[?]*')
-
-router = APIRouter()
-
-
-@router.post('/test')
-def run_middleware():
- return {'message': 'OK'}
-
-
-@router.post('/test1')
-def run_middleware_1():
- return {'message': 'OK'}
-
-
-app.include_router(router)
-
-client = TestClient(app)
-
-
-def test_fastapi_middleware():
- client.post('test?demo=true')
- client.post('test1')
diff --git a/tests/middleware/test_flask_middleware.py b/tests/middleware/test_flask_middleware.py
deleted file mode 100644
index e7190cd..0000000
--- a/tests/middleware/test_flask_middleware.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-
-from flask import Flask
-
-from profyle.flask import ProfyleMiddleware
-
-
-app = Flask('flask_test', root_path=os.path.dirname(__file__))
-app.config.update(
- TESTING=True,
- SECRET_KEY='test key',
-)
-
-app.wsgi_app = ProfyleMiddleware(app.wsgi_app, pattern='*test*')
-
-
-@app.route('/test')
-def index():
- return 'Test'
-
-
-def test_profyle_middleware():
- app.test_client().get('/test')
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/application/__init__.py b/tests/unit/application/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/application/test_profyle.py b/tests/unit/application/test_profyle.py
new file mode 100644
index 0000000..2a264fa
--- /dev/null
+++ b/tests/unit/application/test_profyle.py
@@ -0,0 +1,64 @@
+import pytest
+import asyncio
+
+from profyle.application.profyle import profyle
+from tests.unit.repository import InMemoryTraceRepository
+
+
+def test_should_trace_a_process():
+ trace_repo = InMemoryTraceRepository()
+
+ with profyle(
+ name="test",
+ repo=trace_repo,
+ ):
+ print("demo")
+
+ assert len(trace_repo.traces) == 1
+ assert len(trace_repo.traces[0].data)
+
+
+@pytest.mark.asyncio
+async def test_should_trace_an_async_process():
+ trace_repo = InMemoryTraceRepository()
+
+ with profyle(
+ name="test",
+ repo=trace_repo,
+ ):
+ await asyncio.sleep(0.1)
+
+ assert len(trace_repo.traces) == 1
+ assert len(trace_repo.traces[0].data)
+ assert trace_repo.traces[0].duration/1000 > 0.1
+
+
+@pytest.mark.asyncio
+async def test_should_trace_a_process_with_min_duration():
+ trace_repo = InMemoryTraceRepository()
+
+ with profyle(
+ name="test",
+ repo=trace_repo,
+ min_duration=1
+ ):
+ await asyncio.sleep(2)
+
+ assert len(trace_repo.traces) == 1
+ assert len(trace_repo.traces[0].data)
+ assert trace_repo.traces[0].duration/1000 > 2
+
+
+@pytest.mark.asyncio
+async def test_should_not_trace_a_process_if_min_duration_not_reached():
+ trace_repo = InMemoryTraceRepository()
+ with profyle(
+ name="test",
+ repo=trace_repo,
+ min_duration=3000
+ ):
+ await asyncio.sleep(2)
+
+ assert len(trace_repo.traces) == 1
+ assert len(trace_repo.traces[0].data)
+ assert int(trace_repo.traces[0].duration/1000) == 0
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 0000000..533e742
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,65 @@
+
+import pytest
+import os
+
+from flask import Flask
+from fastapi import APIRouter, FastAPI
+from fastapi.testclient import TestClient
+
+@pytest.fixture
+def fastapi_app():
+ app = FastAPI()
+
+ router = APIRouter()
+
+ @router.post('/test-post')
+ async def test_post():
+ return {'message': 'OK'}
+
+ @router.get('/test-get')
+ async def test_get(demo: bool = False):
+ return {'message': demo}
+
+ @router.patch('/test-patch')
+ async def test_patch():
+ return {'message': 'OK'}
+
+ @router.put('/test-put')
+ async def test_put():
+ return {'message': 'OK'}
+
+ app.include_router(router)
+
+ yield app
+
+
+@pytest.fixture
+def flask_app():
+ app = Flask('flask_test', root_path=os.path.dirname(__file__))
+ app.config.update(
+ TESTING=True,
+ SECRET_KEY='test key',
+ )
+
+ @app.route('/test-post', methods=['POST'])
+ def test_post():
+ return 'Test'
+
+ @app.route('/test-get', methods=['GET'])
+ def test_get():
+ return 'Test'
+
+ @app.route('/test-patch', methods=['PATCH'])
+ def test_patch():
+ return 'Test'
+
+ yield app
+
+@pytest.fixture
+def flask_client(flask_app):
+ yield flask_app.test_client()
+
+
+@pytest.fixture()
+def fastapi_client(fastapi_app):
+ yield TestClient(fastapi_app)
diff --git a/tests/unit/infrastructure/__init__.py b/tests/unit/infrastructure/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/infrastructure/middleware/__init__.py b/tests/unit/infrastructure/middleware/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/infrastructure/middleware/test_fastapi_middleware.py b/tests/unit/infrastructure/middleware/test_fastapi_middleware.py
new file mode 100644
index 0000000..6ea22cb
--- /dev/null
+++ b/tests/unit/infrastructure/middleware/test_fastapi_middleware.py
@@ -0,0 +1,48 @@
+from profyle.fastapi import ProfyleMiddleware
+from tests.unit.repository import InMemoryTraceRepository
+
+
+def test_should_trace_all_requests(fastapi_client, fastapi_app):
+ trace_repo = InMemoryTraceRepository()
+ fastapi_app.add_middleware(
+ ProfyleMiddleware,
+ trace_repo=trace_repo
+ )
+
+ fastapi_client.post("test")
+ fastapi_client.get("test?demo=true")
+
+ assert len(trace_repo.traces) == 2
+ assert trace_repo.traces[0].name == "POST /test"
+ assert trace_repo.traces[1].name == "GET /test?demo=true"
+
+
+def test_should_trace_filtered_requests(fastapi_client, fastapi_app):
+ trace_repo = InMemoryTraceRepository()
+ fastapi_app.add_middleware(
+ ProfyleMiddleware,
+ pattern="/test*",
+ trace_repo=trace_repo
+ )
+
+ fastapi_client.post("test")
+ fastapi_client.get("test?demo=true")
+ fastapi_client.get("other")
+
+ assert len(trace_repo.traces) == 2
+ assert trace_repo.traces[0].name == "POST /test"
+ assert trace_repo.traces[1].name == "GET /test?demo=true"
+
+
+def test_should_no_trace_if_disabled(fastapi_client, fastapi_app):
+ trace_repo = InMemoryTraceRepository()
+ fastapi_app.add_middleware(
+ ProfyleMiddleware,
+ enabled=False,
+ trace_repo=InMemoryTraceRepository()
+ )
+
+ fastapi_client.post("test")
+ fastapi_client.get("test?demo=true")
+
+ assert len(trace_repo.traces) == 0
diff --git a/tests/unit/infrastructure/middleware/test_flask_middleware.py b/tests/unit/infrastructure/middleware/test_flask_middleware.py
new file mode 100644
index 0000000..41af5ff
--- /dev/null
+++ b/tests/unit/infrastructure/middleware/test_flask_middleware.py
@@ -0,0 +1,48 @@
+from profyle.infrastructure.middleware.flask import ProfyleMiddleware
+from tests.unit.repository import InMemoryTraceRepository
+
+
+def test_should_trace_all_requests(flask_client, flask_app):
+ trace_repo = InMemoryTraceRepository()
+ flask_app.wsgi_app = ProfyleMiddleware(
+ flask_app.wsgi_app,
+ trace_repo=trace_repo
+ )
+
+ flask_client.post("/test")
+ flask_client.get("/test?demo=true")
+
+ assert len(trace_repo.traces) == 2
+ assert trace_repo.traces[0].name == "POST /test"
+ assert trace_repo.traces[1].name == "GET /test?demo=true"
+
+
+def test_should_trace_filtered_requests(flask_client, flask_app):
+ trace_repo = InMemoryTraceRepository()
+ flask_app.wsgi_app = ProfyleMiddleware(
+ flask_app.wsgi_app,
+ trace_repo=trace_repo,
+ pattern="/test*",
+ )
+
+ flask_client.post("/test")
+ flask_client.get("/test?demo=true")
+ flask_client.get("/other")
+
+ assert len(trace_repo.traces) == 2
+ assert trace_repo.traces[0].name == "POST /test"
+ assert trace_repo.traces[1].name == "GET /test?demo=true"
+
+
+def test_should_no_trace_if_disabled(flask_client, flask_app):
+ trace_repo = InMemoryTraceRepository()
+ flask_app.wsgi_app = ProfyleMiddleware(
+ flask_app.wsgi_app,
+ trace_repo=trace_repo,
+ enabled=False,
+ )
+
+ flask_client.post("/test")
+ flask_client.get("/test?demo=true")
+
+ assert len(trace_repo.traces) == 0
diff --git a/tests/unit/repository.py b/tests/unit/repository.py
new file mode 100644
index 0000000..8e1cc26
--- /dev/null
+++ b/tests/unit/repository.py
@@ -0,0 +1,60 @@
+from typing import Optional
+from uuid import uuid4
+import time
+import json
+
+from profyle.domain.trace import TraceCreate
+from profyle.domain.trace_repository import TraceRepository
+from profyle.domain.trace import Trace
+
+
+class InMemoryTraceRepository(TraceRepository):
+ def __init__(self):
+ self.traces: list[Trace] = []
+ self.selected_trace: int = 0
+
+ def create_trace_selected_table(self) -> None:
+ ...
+
+ def create_trace_table(self) -> None:
+ ...
+
+ def delete_all_traces(self) -> int:
+ removed = len(self.traces)
+ self.traces = []
+ return removed
+
+ def vacuum(self) -> None:
+ ...
+
+ def store_trace_selected(self, trace_id: int) -> None:
+ self.selected_trace = trace_id
+
+ def store_trace(self, new_trace: TraceCreate) -> None:
+
+ trace = Trace(
+ id=uuid4().int,
+ timestamp=str(time.time()),
+ data= json.dumps(new_trace.data),
+ duration=new_trace.duration,
+ name=new_trace.name,
+ )
+ self.traces.append(trace)
+
+ def get_all_traces(self) -> list[Trace]:
+ return self.traces
+
+ def get_trace_by_id(self, id: int) -> Optional[Trace]:
+ for trace in self.traces:
+ if trace.id == id:
+ return trace
+ return
+
+ def get_trace_selected(self) -> Optional[int]:
+ return self.selected_trace
+
+ def delete_trace_by_id(self, trace_id: int):
+ for trace in self.traces:
+ if trace.id == trace_id:
+ self.traces.remove(trace)
+ return
|