Skip to content

Commit

Permalink
Add memray integration (#8044)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Aug 14, 2023
1 parent b3dde5c commit a1659fd
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 0 deletions.
1 change: 1 addition & 0 deletions continuous_integration/environment-3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- jinja2 >=2.10.3
- locket >=1.0
- lz4 >=0.23.1 # Only tested here
- memray # Only tested here
- msgpack-python
- netcdf4
- paramiko
Expand Down
245 changes: 245 additions & 0 deletions distributed/diagnostics/memray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
from __future__ import annotations

import contextlib
import logging
import pathlib
import subprocess
import time
import uuid
from collections.abc import Iterator, Sequence
from typing import Any, Literal
from urllib.parse import quote

from toolz.itertoolz import partition

from distributed import get_client
from distributed.worker import Worker

try:
    import memray
except ImportError as e:
    # Chain the original exception so the real import failure (e.g. a broken
    # native extension rather than a missing package) stays visible.
    raise ImportError("You have to install memray to use this module.") from e

logger = logging.getLogger(__name__)


def _start_memray(dask_worker: Worker, filename: str, **kwargs: Any) -> bool:
    """Start a memray Tracker on a worker.

    Executed on the worker via ``Client.run``.

    Parameters
    ----------
    dask_worker:
        The worker this runs on (injected by ``Client.run``).
    filename:
        Base name of the capture file; the worker id is appended so
        multiple workers on the same host do not collide.
    **kwargs:
        Forwarded to ``memray.Tracker``.

    Returns
    -------
    True once the tracker has been installed.
    """
    # Close a tracker left over from a previous profiling session, if any.
    if hasattr(dask_worker, "_memray"):
        dask_worker._memray.close()

    path = pathlib.Path(dask_worker.local_directory) / (filename + str(dask_worker.id))
    if path.exists():
        # The capture target is a *file*; ``rmdir`` would raise
        # NotADirectoryError here.
        path.unlink()

    # Default to native traces, but let callers override via **kwargs
    # (passing native_traces explicitly would otherwise raise a duplicate
    # keyword TypeError).
    kwargs.setdefault("native_traces", True)
    dask_worker._memray = contextlib.ExitStack()  # type: ignore[attr-defined]
    dask_worker._memray.enter_context(  # type: ignore[attr-defined]
        memray.Tracker(path, **kwargs)
    )

    return True


def _fetch_memray_profile(
    dask_worker: Worker, filename: str, report_args: Sequence[str] | Literal[False]
) -> bytes:
    """Stop the tracker on a worker and return the memray profile.

    Executed on the worker via ``Client.run``.

    Parameters
    ----------
    dask_worker:
        The worker this runs on (injected by ``Client.run``).
    filename:
        Base name used by ``_start_memray``; the worker id is appended.
    report_args:
        CLI arguments used to render a memray report on the worker. If
        falsy, the raw capture file is returned instead.

    Returns
    -------
    The rendered HTML report (or raw capture) as bytes; empty bytes if no
    tracker was running on this worker.
    """
    if not hasattr(dask_worker, "_memray"):
        return b""
    path = pathlib.Path(dask_worker.local_directory) / (filename + str(dask_worker.id))
    # Flush and close the tracker so the capture file is complete on disk.
    dask_worker._memray.close()
    del dask_worker._memray

    if not report_args:
        with open(path, "rb") as fd:
            return fd.read()

    report_filename = path.with_suffix(".html")
    if report_args[0] != "memray":
        report_args = ["memray"] + list(report_args)
    # Explicit raises instead of ``assert`` so validation also runs under -O.
    if "-f" in report_args:
        raise ValueError("Cannot provide filename for report generation")
    if "-o" in report_args:
        raise ValueError("Cannot provide output filename for report generation")
    report_args = list(report_args) + ["-f", str(path), "-o", str(report_filename)]
    # check=True surfaces report-generation failures directly instead of
    # failing later with a confusing FileNotFoundError when opening the
    # (never-written) report file.
    subprocess.run(report_args, check=True)
    with open(report_filename, "rb") as fd:
        return fd.read()


@contextlib.contextmanager
def memray_workers(
    directory: str | pathlib.Path = "memray-profiles",
    workers: int | None | list[str] = None,
    report_args: Sequence[str]
    | Literal[False] = ("flamegraph", "--temporal", "--leaks"),
    fetch_reports_parallel: bool | int = True,
    **memray_kwargs: Any,
) -> Iterator[None]:
    """Generate a Memray profile on the workers and download the generated report.

    Example::

        with memray_workers():
            client.submit(my_function).result()

        # Or even while the computation is already running
        fut = client.submit(my_function)
        with memray_workers():
            time.sleep(10)
        fut.result()

    Parameters
    ----------
    directory : str
        The directory to save the reports to.
    workers : int | None | list[str]
        The workers to profile. If int, the first n workers will be used.
        If None, all workers will be used.
        If list[str], the workers with the given addresses will be used.
    report_args : tuple[str]
        Particularly for native_traces=True, the reports have to be
        generated on the same host using the same Python interpreter as the
        profile was generated. Otherwise, native traces will yield unusable
        results. Therefore, we're generating the reports on the workers and
        download them afterwards. You can modify the report generation by
        providing additional arguments and we will generate the reports as::

            memray *report_args -f <filename> -o <filename>.html

        If the raw data should be fetched instead of the report, set this to
        False.
    fetch_reports_parallel : bool | int
        Fetching results is sometimes slow and it's sometimes not desired to
        wait for all workers to finish before receiving the first reports.
        This controls how many workers are fetched concurrently.

        int: Number of workers to fetch concurrently
        True: All workers concurrently
        False: One worker at a time
    **memray_kwargs
        Keyword arguments to be passed to memray.Tracker, e.g.
        {"native_traces": True}
    """
    directory = pathlib.Path(directory)
    client = get_client()
    scheduler_info = client.scheduler_info()
    worker_addresses = scheduler_info["workers"]
    worker_names = {
        addr: winfo["name"] for addr, winfo in worker_addresses.items()
    }
    if not workers or isinstance(workers, int):
        nworkers = len(worker_addresses)
        if isinstance(workers, int):
            nworkers = workers
        workers = list(worker_addresses)[:nworkers]
    workers = list(workers)
    filename = uuid.uuid4().hex
    # Start the trackers outside an ``assert`` so they are installed even
    # under ``python -O``, and restrict the call to the selected workers —
    # otherwise unselected workers would keep a tracker open that is never
    # fetched or closed.
    started = client.run(
        _start_memray, filename=filename, workers=workers, **memray_kwargs
    )
    if not all(started.values()):
        raise RuntimeError(f"Failed to start memray on workers: {started}")
    # Sleep for a brief moment such that we get
    # a clear profiling signal when everything starts
    time.sleep(0.1)
    yield
    directory.mkdir(exist_ok=True)

    client = get_client()
    if fetch_reports_parallel is True:
        fetch_parallel = len(workers)
    elif fetch_reports_parallel is False:
        fetch_parallel = 1
    else:
        fetch_parallel = fetch_reports_parallel

    # Plain slicing keeps the final, possibly smaller, batch;
    # ``toolz.partition`` would silently drop an incomplete tail, leaving
    # some workers' profiles unfetched.
    for start in range(0, len(workers), fetch_parallel):
        batch = workers[start : start + fetch_parallel]
        try:
            profiles = client.run(
                _fetch_memray_profile,
                filename=filename,
                report_args=report_args,
                workers=batch,
            )
            suffix = ".html" if report_args else ".memray"
            for addr, profile in profiles.items():
                # Quote the worker name so it is a safe file name component.
                path = directory / quote(worker_names[addr], safe="")
                with open(str(path) + suffix, "wb") as fd:
                    fd.write(profile)

        except Exception:
            logger.exception(
                "Exception during report downloading from workers %s", batch
            )


@contextlib.contextmanager
def memray_scheduler(
    directory: str | pathlib.Path = "memray-profiles",
    report_args: Sequence[str]
    | Literal[False] = ("flamegraph", "--temporal", "--leaks"),
    **memray_kwargs: Any,
) -> Iterator[None]:
    """Generate a Memray profile on the Scheduler and download the generated report.

    Example::

        with memray_scheduler():
            client.submit(my_function).result()

        # Or even while the computation is already running
        fut = client.submit(my_function)
        with memray_scheduler():
            time.sleep(10)
        fut.result()

    Parameters
    ----------
    directory : str
        The directory to save the reports to.
    report_args : tuple[str]
        Particularly for native_traces=True, the reports have to be
        generated on the same host using the same Python interpreter as the
        profile was generated. Otherwise, native traces will yield unusable
        results. Therefore, we're generating the reports on the Scheduler and
        download them afterwards. You can modify the report generation by
        providing additional arguments and we will generate the reports as::

            memray *report_args -f <filename> -o <filename>.html

        If the raw data should be fetched instead of the report, set this to
        False.
    **memray_kwargs
        Keyword arguments to be passed to memray.Tracker, e.g.
        {"native_traces": True}
    """
    directory = pathlib.Path(directory)
    client = get_client()
    filename = uuid.uuid4().hex
    # Start the tracker outside an ``assert`` so it is installed even when
    # Python runs with ``-O`` (asserts — including their side effects — are
    # stripped there).
    started = client.run_on_scheduler(_start_memray, filename=filename, **memray_kwargs)
    if not started:
        raise RuntimeError("Failed to start memray on the scheduler")
    # Sleep for a brief moment such that we get
    # a clear profiling signal when everything starts
    time.sleep(0.1)
    yield
    directory.mkdir(exist_ok=True)

    client = get_client()

    profile = client.run_on_scheduler(
        _fetch_memray_profile,
        filename=filename,
        report_args=report_args,
    )
    path = directory / "scheduler"
    suffix = ".html" if report_args else ".memray"
    with open(str(path) + suffix, "wb") as fd:
        fd.write(profile)
47 changes: 47 additions & 0 deletions distributed/diagnostics/tests/test_memray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import pytest

memray = pytest.importorskip("memray")


from distributed.diagnostics.memray import memray_scheduler, memray_workers


@pytest.mark.parametrize("fetch_reports_parallel", [True, False, 1])
def test_basic_integration_workers(client, tmp_path, fetch_reports_parallel):
    """Profiling the workers yields one HTML report per worker."""
    with memray_workers(tmp_path, fetch_reports_parallel=fetch_reports_parallel):
        pass

    html_reports = list(tmp_path.glob("*.html"))
    assert len(html_reports) == 2


@pytest.mark.parametrize("report_args", [("flamegraph", "--leaks"), False])
def test_basic_integration_workers_report_args(client, tmp_path, report_args):
    """With report_args=False the raw .memray captures are fetched instead of HTML."""
    with memray_workers(tmp_path, report_args=report_args):
        pass

    html_count = len(list(tmp_path.glob("*.html")))
    if report_args:
        assert html_count == 2
    else:
        assert html_count == 0
        assert len(list(tmp_path.glob("*.memray"))) == 2


def test_basic_integration_scheduler(client, tmp_path):
    """Profiling the scheduler yields exactly one HTML report."""
    with memray_scheduler(tmp_path):
        pass

    html_reports = list(tmp_path.glob("*.html"))
    assert len(html_reports) == 1


@pytest.mark.parametrize("report_args", [("flamegraph", "--leaks"), False])
def test_basic_integration_scheduler_report_args(client, tmp_path, report_args):
    """With report_args=False the raw .memray capture is fetched instead of HTML."""
    with memray_scheduler(tmp_path, report_args=report_args):
        pass

    html_count = len(list(tmp_path.glob("*.html")))
    if report_args:
        assert html_count == 1
    else:
        assert html_count == 0
        assert len(list(tmp_path.glob("*.memray"))) == 1

0 comments on commit a1659fd

Please sign in to comment.