Skip to content

Commit

Permalink
Merge pull request #13 from msnidal/feature/dot-prof-support
Browse files Browse the repository at this point in the history
Implement .prof output
  • Loading branch information
sunhailin-Leo authored May 5, 2023
2 parents abbc4e1 + 36a7327 commit 21ca3c0
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 17 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,7 @@ dmypy.json

# Pyre type checker
.pyre/

# Profiler outputs
*.prof
*.html
34 changes: 34 additions & 0 deletions example/fastapi_to_prof_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
This example shows how to output the profile
to a .prof file.
"""
import os
import uvicorn

from fastapi import FastAPI
from fastapi.responses import JSONResponse

from fastapi_profiler import PyInstrumentProfilerMiddleware


app = FastAPI()
app.add_middleware(
PyInstrumentProfilerMiddleware,
server_app=app, # Required to output the profile on server shutdown
profiler_output_type="prof",
is_print_each_request=False, # Set to True to show request profile on
# stdout on each request
prof_file_name="example_profile.prof", # Filename for output
)


@app.get("/test")
async def normal_request():
return JSONResponse({"retMsg": "Hello World!"})


# Or you can use the console with command "uvicorn" to run this example.
# Command: uvicorn fastapi_example:app --host="0.0.0.0" --port=8080
if __name__ == "__main__":
app_name = os.path.basename(__file__).replace(".py", "")
uvicorn.run(app=f"{app_name}:app", host="0.0.0.0", port=8080, workers=1)
2 changes: 2 additions & 0 deletions example/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
starlette
uvicorn
58 changes: 46 additions & 12 deletions fastapi_profiler/profiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import time
import codecs
import cProfile
from io import StringIO
from typing import Optional
from logging import getLogger

Expand All @@ -17,24 +19,27 @@

class PyInstrumentProfilerMiddleware:
DEFAULT_HTML_FILENAME = "./fastapi-profiler.html"
DEFAULT_PROF_FILENAME = "./fastapi-profiler.prof"

def __init__(
self, app: ASGIApp,
self,
app: ASGIApp,
*,
server_app: Optional[Router] = None,
profiler_interval: float = 0.0001,
profiler_output_type: str = "text",
is_print_each_request: bool = True,
async_mode: str = "enabled",
html_file_name: Optional[str] = None,
prof_file_name: Optional[str] = None,
open_in_browser: bool = False,
**profiler_kwargs
**profiler_kwargs,
):
self.app = app
self._profiler = Profiler(interval=profiler_interval, async_mode=async_mode)
self._output_type = profiler_output_type
self._print_each_request = is_print_each_request
self._html_file_name: Optional[str] = html_file_name
self._prof_file_name: Optional[str] = prof_file_name
self._open_in_browser: bool = open_in_browser
self._profiler_kwargs: dict = profiler_kwargs

Expand All @@ -44,6 +49,15 @@ def __init__(
"to set shutdown event handler to output profile."
)

if profiler_output_type == "prof":
self._profiler = cProfile.Profile()
self._start_profiler = self._profiler.enable
self._stop_profiler = self._profiler.disable
else:
self._profiler = Profiler(interval=profiler_interval, async_mode=async_mode)
self._start_profiler = self._profiler.start
self._stop_profiler = self._profiler.stop

# register an event handler for profiler stop
if server_app is not None:
server_app.add_event_handler("shutdown", self.get_profiler_result)
Expand All @@ -53,7 +67,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

self._profiler.start()
self._start_profiler()

request = Request(scope, receive=receive)
method = request.method
Expand All @@ -65,23 +79,31 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
status_code = 500

async def wrapped_send(message: Message) -> None:
if message['type'] == 'http.response.start':
if message["type"] == "http.response.start":
nonlocal status_code
status_code = message['status']
status_code = message["status"]
await send(message)

try:
await self.app(scope, receive, wrapped_send)
finally:
if scope["type"] == "http":
self._profiler.stop()
self._stop_profiler()
end = time.perf_counter()
if self._print_each_request:
print(f"Method: {method}, "
f"Path: {path}, "
f"Duration: {end - begin}, "
f"Status: {status_code}")
print(self._profiler.output_text(**self._profiler_kwargs))
print(
f"Method: {method}, "
f"Path: {path}, "
f"Duration: {end - begin}, "
f"Status: {status_code}"
)

if self._output_type == "prof":
s = StringIO()
self._profiler.print_stats(stream=s)
print(s.getvalue())
else:
print(self._profiler.output_text(**self._profiler_kwargs))

async def get_profiler_result(self):
if self._output_type == "text":
Expand Down Expand Up @@ -109,3 +131,15 @@ async def get_profiler_result(self):
f.write(html_code)

logger.info("Done writing profile to %r", html_file_name)
elif self._output_type == "prof":
prof_file_name = self.DEFAULT_PROF_FILENAME
if self._prof_file_name is not None:
prof_file_name = self._prof_file_name

logger.info(
"Compiling and dumping final profile to %r - this may take some time",
prof_file_name,
)

self._profiler.dump_stats(prof_file_name)
logger.info("Done writing profile to %r", prof_file_name)
31 changes: 26 additions & 5 deletions test/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import sys
import pytest
from io import StringIO
Expand All @@ -13,7 +12,6 @@

@pytest.fixture(name="test_middleware")
def test_middleware():

def _test_middleware(**profiler_kwargs):
app = FastAPI()
if profiler_kwargs.get("profiler_output_type") != "text":
Expand All @@ -25,6 +23,7 @@ async def normal_request(request):
return JSONResponse({"retMsg": "Normal Request test Success!"})

return app

return _test_middleware


Expand All @@ -43,19 +42,41 @@ def test_profiler_print_at_console(self, client):
client.get(request_path)

sys.stdout = temp_stdout
assert (f"Path: {request_path}" in stdout_redirect.fp.getvalue())
assert f"Path: {request_path}" in stdout_redirect.fp.getvalue()

def test_profiler_export_to_html(self, test_middleware, tmpdir):
full_path = tmpdir / "test.html"

with TestClient(test_middleware(
with TestClient(
test_middleware(
profiler_output_type="html",
is_print_each_request=False,
profiler_interval=0.0000001,
html_file_name=str(full_path))) as client:
html_file_name=str(full_path),
)
) as client:
# request
request_path = "/test"
client.get(request_path)

# HTML will record the py file name.
assert "profiler.py" in full_path.read_text("utf-8")

def test_profiler_export_to_prof(self, test_middleware, tmpdir):
full_path = tmpdir / "test.prof"

with TestClient(
test_middleware(
profiler_output_type="prof",
is_print_each_request=False,
profiler_interval=0.0000001,
prof_file_name=str(full_path),
)
) as client:
# request
request_path = "/test"
client.get(request_path)

# Check if the .prof file has been created and has content
assert full_path.exists()
assert full_path.read_binary()

0 comments on commit 21ca3c0

Please sign in to comment.