Add memray outputs to traces
Jay Chia committed Dec 11, 2024
1 parent 4266708 commit 1fe09de
Showing 3 changed files with 50 additions and 5 deletions.
2 changes: 1 addition & 1 deletion benchmarking/ooms/big_task_heap_usage.py
@@ -1,5 +1,5 @@
 # /// script
-# dependencies = ['numpy']
+# dependencies = ['numpy', 'memray']
 # ///
 
 import argparse
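For context, the "# /// script" block above is PEP 723 inline script metadata, so a runner such as uv run can install numpy and memray before executing the benchmark. A minimal, hypothetical sketch (not part of this commit) of how such a standalone script can capture a heap profile with memray; the output path and workload below are illustrative only:

# /// script
# dependencies = ['numpy', 'memray']
# ///
import memray
import numpy as np

# Record allocations into a capture file that can be inspected later,
# e.g. with "memray stats profile.memray.bin" or "memray flamegraph profile.memray.bin".
with memray.Tracker("profile.memray.bin"):
    arrays = [np.zeros((1024, 1024)) for _ in range(8)]  # ~64 MiB of float64 allocations
    del arrays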
20 changes: 18 additions & 2 deletions daft/runners/ray_metrics.py
@@ -51,6 +51,14 @@ class EndTaskEvent(TaskEvent):
 
     # End Unix timestamp
     end: float
+    memory_stats: TaskMemoryStats
+
+
+@dataclasses.dataclass(frozen=True)
+class TaskMemoryStats:
+    peak_memory_allocated: int
+    total_memory_allocated: int
+    total_num_allocations: int
 
 
 class _NodeInfo:
@@ -123,9 +131,15 @@ def mark_task_start(
             )
         )
 
-    def mark_task_end(self, execution_id: str, task_id: str, end: float):
+    def mark_task_end(
+        self,
+        execution_id: str,
+        task_id: str,
+        end: float,
+        memory_stats: TaskMemoryStats,
+    ):
         # Add an EndTaskEvent
-        self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end))
+        self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end, memory_stats=memory_stats))
 
     def get_task_events(self, execution_id: str, idx: int) -> tuple[list[TaskEvent], int]:
         events = self._task_events[execution_id]
@@ -177,11 +191,13 @@ def mark_task_end(
         self,
         task_id: str,
         end: float,
+        memory_stats: TaskMemoryStats,
     ) -> None:
         self.actor.mark_task_end.remote(
             self.execution_id,
             task_id,
             end,
+            memory_stats,
         )
 
     def get_task_events(self, idx: int) -> tuple[list[TaskEvent], int]:
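The new TaskMemoryStats values are produced on the worker that ran the task and handed to the metrics actor via mark_task_end; as a frozen dataclass it travels as an ordinary pickled argument of the remote call. A minimal, hypothetical sketch of that pattern (MetricsActorSketch is a stand-in for illustration, not the real daft metrics actor or its API):

from __future__ import annotations

import dataclasses

import ray


@dataclasses.dataclass(frozen=True)
class TaskMemoryStats:
    peak_memory_allocated: int
    total_memory_allocated: int
    total_num_allocations: int


@ray.remote
class MetricsActorSketch:
    """Hypothetical metrics actor that just accumulates end-of-task records."""

    def __init__(self) -> None:
        self.end_events: list[tuple[str, float, TaskMemoryStats]] = []

    def mark_task_end(self, task_id: str, end: float, memory_stats: TaskMemoryStats) -> None:
        self.end_events.append((task_id, end, memory_stats))

    def get_end_events(self) -> list[tuple[str, float, TaskMemoryStats]]:
        return self.end_events


if __name__ == "__main__":
    ray.init()
    actor = MetricsActorSketch.remote()
    stats = TaskMemoryStats(peak_memory_allocated=1 << 20, total_memory_allocated=4 << 20, total_num_allocations=1234)
    ray.get(actor.mark_task_end.remote("task-1", 1733900000.0, stats))
    print(ray.get(actor.get_end_events.remote()))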
33 changes: 31 additions & 2 deletions daft/runners/ray_tracing.py
@@ -10,6 +10,7 @@
 import dataclasses
 import json
 import logging
+import os
 import pathlib
 import time
 from datetime import datetime
@@ -255,6 +256,11 @@ def _flush_task_metrics(self):
                     "ph": RunnerTracer.PHASE_ASYNC_END,
                     "pid": 1,
                     "tid": 2,
+                    "args": {
+                        "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated,
+                        "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated,
+                        "memray_total_num_allocations": task_event.memory_stats.total_num_allocations,
+                    },
                 },
                 ts=end_ts,
             )
@@ -272,6 +278,11 @@ def _flush_task_metrics(self):
                     "ph": RunnerTracer.PHASE_DURATION_END,
                     "pid": node_idx + RunnerTracer.NODE_PIDS_START,
                     "tid": worker_idx,
+                    "args": {
+                        "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated,
+                        "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated,
+                        "memray_total_num_allocations": task_event.memory_stats.total_num_allocations,
+                    },
                 },
                 ts=end_ts,
             )
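Both end events above go into the Chrome Trace Event Format stream that the tracer writes, so the memray numbers ride in the event's args field and become visible in viewers such as Perfetto or chrome://tracing. Roughly what one such end event could look like, sketched as a Python dict with made-up placeholder values (the real phase value comes from RunnerTracer.PHASE_ASYNC_END or PHASE_DURATION_END):

example_end_event = {
    "ph": "e",  # async-end phase marker; placeholder for RunnerTracer.PHASE_ASYNC_END
    "pid": 1,
    "tid": 2,
    "ts": 1733900000000000,  # timestamp in microseconds
    "args": {
        "memray_peak_memory_allocated": 268435456,    # bytes, placeholder
        "memray_total_memory_allocated": 1073741824,  # bytes, placeholder
        "memray_total_num_allocations": 52341,        # placeholder
    },
}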
@@ -656,8 +667,12 @@ def __next__(self):
 def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, execution_config: PyDaftExecutionConfig):
     """Context manager that will ping the metrics actor to record various execution metrics about a given task."""
     if execution_config.enable_ray_tracing:
+        import tempfile
         import time
 
+        import memray
+        from memray._memray import compute_statistics
+
         runtime_context = ray.get_runtime_context()
 
         metrics_actor = ray_metrics.get_metrics_actor(execution_id)
@@ -670,7 +685,21 @@ def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, exe
             runtime_context.get_assigned_resources(),
             runtime_context.get_task_id(),
         )
-        yield
-        metrics_actor.mark_task_end(task_id, time.time())
+        with tempfile.TemporaryDirectory() as tmpdir:
+            memray_tmpfile = os.path.join(tmpdir, f"task-{task_id}.memray.bin")
+            try:
+                with memray.Tracker(memray_tmpfile):
+                    yield
+            finally:
+                stats = compute_statistics(memray_tmpfile)
+                metrics_actor.mark_task_end(
+                    task_id,
+                    time.time(),
+                    ray_metrics.TaskMemoryStats(
+                        peak_memory_allocated=stats.peak_memory_allocated,
+                        total_memory_allocated=stats.total_memory_allocated,
+                        total_num_allocations=stats.total_num_allocations,
+                    ),
+                )
     else:
         yield
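With this change, each traced Ray task runs under a memray.Tracker that writes a capture file into a per-task temporary directory; because mark_task_end sits in a finally block, the summary statistics are reported even when the task body raises. A minimal standalone sketch of the same capture-then-summarize pattern, outside of Ray; note that compute_statistics comes from a private memray module (as in the diff above), so its location and the fields it returns may differ between memray releases:

import os
import tempfile

import memray
from memray._memray import compute_statistics  # private API, mirrored from the diff above


def some_workload() -> list[bytearray]:
    return [bytearray(1024 * 1024) for _ in range(16)]  # roughly 16 MiB of allocations


with tempfile.TemporaryDirectory() as tmpdir:
    capture_path = os.path.join(tmpdir, "example.memray.bin")
    with memray.Tracker(capture_path):
        data = some_workload()
        del data
    stats = compute_statistics(capture_path)
    print(stats.peak_memory_allocated, stats.total_memory_allocated, stats.total_num_allocations)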
