-
-
Notifications
You must be signed in to change notification settings - Fork 720
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
293 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
from __future__ import annotations | ||
|
||
import contextlib | ||
import logging | ||
import pathlib | ||
import subprocess | ||
import time | ||
import uuid | ||
from collections.abc import Iterator, Sequence | ||
from typing import Any, Literal | ||
from urllib.parse import quote | ||
|
||
from toolz.itertoolz import partition | ||
|
||
from distributed import get_client | ||
from distributed.worker import Worker | ||
|
||
try: | ||
import memray | ||
except ImportError: | ||
raise ImportError("You have to install memray to use this module.") | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _start_memray(dask_worker: Worker, filename: str, **kwargs: Any) -> bool:
    """Start a memray ``Tracker`` on a worker.

    This runs *on the worker* (injected via ``Client.run``).

    Parameters
    ----------
    dask_worker : Worker
        The worker this function executes on.
    filename : str
        Base name of the capture file; the worker id is appended so multiple
        workers sharing a local directory do not collide.
    **kwargs
        Forwarded to ``memray.Tracker``.

    Returns
    -------
    bool
        Always ``True`` so the caller can assert that tracking started.
    """
    # Close any tracker left over from a previous (possibly aborted) run.
    if hasattr(dask_worker, "_memray"):
        dask_worker._memray.close()

    path = pathlib.Path(dask_worker.local_directory) / (filename + str(dask_worker.id))
    # The stale capture is a regular file: unlink() it. rmdir() only removes
    # directories and would raise NotADirectoryError here.
    if path.exists():
        path.unlink()

    # Default to native traces but let callers override it via memray_kwargs
    # (passing native_traces explicitly used to raise a duplicate-kwarg
    # TypeError).
    kwargs.setdefault("native_traces", True)
    dask_worker._memray = contextlib.ExitStack()  # type: ignore[attr-defined]
    dask_worker._memray.enter_context(  # type: ignore[attr-defined]
        memray.Tracker(path, **kwargs)
    )

    return True
|
||
|
||
def _fetch_memray_profile(
    dask_worker: Worker, filename: str, report_args: Sequence[str] | Literal[False]
) -> bytes:
    """Stop tracking and return the memray capture or a rendered report.

    This runs *on the worker* (injected via ``Client.run``).

    Parameters
    ----------
    dask_worker : Worker
        The worker this function executes on.
    filename : str
        Base name used by ``_start_memray``; the worker id is appended.
    report_args : Sequence[str] | Literal[False]
        CLI arguments for ``memray`` report generation, or ``False`` to
        return the raw capture file instead of a rendered report.

    Returns
    -------
    bytes
        Raw capture bytes (``report_args is False``) or the generated HTML
        report; empty bytes if no tracker was running on this worker.

    Raises
    ------
    ValueError
        If ``report_args`` contains ``-f`` or ``-o`` (both are appended here).
    subprocess.CalledProcessError
        If report generation fails.
    """
    if not hasattr(dask_worker, "_memray"):
        return b""
    path = pathlib.Path(dask_worker.local_directory) / (filename + str(dask_worker.id))
    # Close the Tracker so the capture file is flushed and complete.
    dask_worker._memray.close()
    del dask_worker._memray

    if not report_args:
        return path.read_bytes()

    report_filename = path.with_suffix(".html")
    report_args = list(report_args)
    if report_args[0] != "memray":
        report_args = ["memray"] + report_args
    # -f/-o are appended below; raise (not assert, which is stripped under -O)
    # if the caller tries to supply them.
    if "-f" in report_args:
        raise ValueError("Cannot provide filename for report generation")
    if "-o" in report_args:
        raise ValueError("Cannot provide output filename for report generation")
    report_args += ["-f", str(path), "-o", str(report_filename)]
    # check=True surfaces a failed report generation instead of the confusing
    # FileNotFoundError we'd otherwise get reading a report never written.
    subprocess.run(report_args, check=True)
    return report_filename.read_bytes()
|
||
|
||
@contextlib.contextmanager
def memray_workers(
    directory: str | pathlib.Path = "memray-profiles",
    workers: int | None | list[str] = None,
    report_args: Sequence[str]
    | Literal[False] = ("flamegraph", "--temporal", "--leaks"),
    fetch_reports_parallel: bool | int = True,
    **memray_kwargs: Any,
) -> Iterator[None]:
    """Generate a Memray profile on the workers and download the generated report.

    Example::

        with memray_workers():
            client.submit(my_function).result()

        # Or even while the computation is already running
        fut = client.submit(my_function)
        with memray_workers():
            time.sleep(10)
        fut.result()

    Parameters
    ----------
    directory : str
        The directory to save the reports to.
    workers : int | None | list[str]
        The workers to profile. If int, the first n workers will be used.
        If None, all workers will be used.
        If list[str], the workers with the given addresses will be used.
    report_args : tuple[str]
        Particularly for native_traces=True, the reports have to be
        generated on the same host using the same Python interpreter as the
        profile was generated. Otherwise, native traces will yield unusable
        results. Therefore, we're generating the reports on the workers and
        download them afterwards. You can modify the report generation by
        providing additional arguments and we will generate the reports as::

            memray *report_args -f <filename> -o <filename>.html

        If the raw data should be fetched instead of the report, set this to
        False.
    fetch_reports_parallel : bool | int
        Fetching results is sometimes slow and it's sometimes not desired to
        wait for all workers to finish before receiving the first reports.
        This controls how many workers are fetched concurrently.

        int: Number of workers to fetch concurrently
        True: All workers concurrently
        False: One worker at a time
    **memray_kwargs
        Keyword arguments to be passed to memray.Tracker, e.g.
        {"native_traces": True}
    """
    directory = pathlib.Path(directory)
    client = get_client()
    scheduler_info = client.scheduler_info()
    worker_addr = scheduler_info["workers"]
    worker_names = {
        addr: winfo["name"] for addr, winfo in scheduler_info["workers"].items()
    }
    if not workers or isinstance(workers, int):
        nworkers = len(worker_addr)
        if isinstance(workers, int):
            nworkers = workers
        workers = list(worker_addr)[:nworkers]
    workers = list(workers)
    filename = uuid.uuid4().hex
    assert all(client.run(_start_memray, filename=filename, **memray_kwargs).values())
    # Sleep for a brief moment such that we get
    # a clear profiling signal when everything starts
    time.sleep(0.1)
    yield
    directory.mkdir(exist_ok=True)

    client = get_client()
    if fetch_reports_parallel is True:
        fetch_parallel = len(workers)
    elif fetch_reports_parallel is False:
        fetch_parallel = 1
    else:
        fetch_parallel = fetch_reports_parallel

    # Batch by slicing instead of toolz.partition: partition() silently drops
    # a final batch smaller than fetch_parallel, which would lose the reports
    # of the trailing workers whenever len(workers) % fetch_parallel != 0.
    for start in range(0, len(workers), fetch_parallel):
        batch = workers[start : start + fetch_parallel]
        try:
            profiles = client.run(
                _fetch_memray_profile,
                filename=filename,
                report_args=report_args,
                workers=batch,
            )
            # NB: distinct name from the outer `worker_addr` mapping to avoid
            # shadowing it.
            for addr, profile in profiles.items():
                # Worker names may contain path separators; quote them so the
                # report lands inside `directory`.
                path = directory / quote(worker_names[addr], safe="")
                suffix = ".html" if report_args else ".memray"
                with open(str(path) + suffix, "wb") as fd:
                    fd.write(profile)

        except Exception:
            logger.exception(
                "Exception during report downloading from worker %s", batch
            )
|
||
|
||
@contextlib.contextmanager
def memray_scheduler(
    directory: str | pathlib.Path = "memray-profiles",
    report_args: Sequence[str]
    | Literal[False] = ("flamegraph", "--temporal", "--leaks"),
    **memray_kwargs: Any,
) -> Iterator[None]:
    """Generate a Memray profile on the Scheduler and download the generated report.

    Example::

        with memray_scheduler():
            client.submit(my_function).result()

        # Or even while the computation is already running
        fut = client.submit(my_function)
        with memray_scheduler():
            time.sleep(10)
        fut.result()

    Parameters
    ----------
    directory : str
        The directory to save the reports to.
    report_args : tuple[str]
        Particularly for native_traces=True, the reports have to be
        generated on the same host using the same Python interpreter as the
        profile was generated. Otherwise, native traces will yield unusable
        results. Therefore, we're generating the reports on the Scheduler and
        download them afterwards. You can modify the report generation by
        providing additional arguments and we will generate the reports as::

            memray *report_args -f <filename> -o <filename>.html

        If the raw data should be fetched instead of the report, set this to
        False.
    **memray_kwargs
        Keyword arguments to be passed to memray.Tracker, e.g.
        {"native_traces": True}
    """
    target_dir = pathlib.Path(directory)
    client = get_client()
    capture_name = uuid.uuid4().hex
    assert client.run_on_scheduler(
        _start_memray, filename=capture_name, **memray_kwargs
    )
    # Brief pause so we get a clear profiling signal once everything starts.
    time.sleep(0.1)
    yield
    target_dir.mkdir(exist_ok=True)

    client = get_client()

    profile = client.run_on_scheduler(
        _fetch_memray_profile,
        filename=capture_name,
        report_args=report_args,
    )
    suffix = ".html" if report_args else ".memray"
    out_path = pathlib.Path(str(target_dir / "scheduler") + suffix)
    out_path.write_bytes(profile)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
memray = pytest.importorskip("memray") | ||
|
||
|
||
from distributed.diagnostics.memray import memray_scheduler, memray_workers | ||
|
||
|
||
@pytest.mark.parametrize("fetch_reports_parallel", [True, False, 1])
def test_basic_integration_workers(client, tmp_path, fetch_reports_parallel):
    """Profiling the workers yields one HTML report per worker (fixture has 2)."""
    with memray_workers(tmp_path, fetch_reports_parallel=fetch_reports_parallel):
        pass

    reports = list(tmp_path.glob("*.html"))
    assert len(reports) == 2
|
||
|
||
@pytest.mark.parametrize("report_args", [("flamegraph", "--leaks"), False])
def test_basic_integration_workers_report_args(client, tmp_path, report_args):
    """report_args=False downloads raw .memray captures instead of HTML reports."""
    with memray_workers(tmp_path, report_args=report_args):
        pass

    html_reports = list(tmp_path.glob("*.html"))
    raw_captures = list(tmp_path.glob("*.memray"))
    if report_args:
        assert len(html_reports) == 2
    else:
        assert not html_reports
        assert len(raw_captures) == 2
|
||
|
||
def test_basic_integration_scheduler(client, tmp_path):
    """Profiling the scheduler produces exactly one HTML report."""
    with memray_scheduler(tmp_path):
        pass

    assert sum(1 for _ in tmp_path.glob("*.html")) == 1
|
||
|
||
@pytest.mark.parametrize("report_args", [("flamegraph", "--leaks"), False])
def test_basic_integration_scheduler_report_args(client, tmp_path, report_args):
    """report_args=False downloads the raw .memray capture instead of HTML."""
    with memray_scheduler(tmp_path, report_args=report_args):
        pass

    html_reports = list(tmp_path.glob("*.html"))
    raw_captures = list(tmp_path.glob("*.memray"))
    if report_args:
        assert len(html_reports) == 1
    else:
        assert not html_reports
        assert len(raw_captures) == 1