Skip to content

Commit

Permalink
Improve trace output
Browse files Browse the repository at this point in the history
- sort by completed time (not start time)
- display errors
- display duration
  • Loading branch information
dragonstyle committed Dec 23, 2024
1 parent dea06b0 commit 961ef22
Showing 1 changed file with 52 additions and 19 deletions.
71 changes: 52 additions & 19 deletions src/inspect_ai/_cli/trace.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from datetime import datetime
from json import dumps
from pathlib import Path
import time

import click
from pydantic_core import to_json
from rich import print as r_print

from inspect_ai._util.error import PrerequisiteError
from inspect_ai._util.logger import TRACE_FILE_NAME
Expand Down Expand Up @@ -57,17 +59,17 @@ def anomolies_command(trace_file: str) -> None:
traces = read_trace_file(trace_file_path)

# Track started actions
started_actions: dict[str, ActionTraceRecord] = {}
running_actions: dict[str, ActionTraceRecord] = {}
error_actions: dict[str, ActionTraceRecord] = {}
canceled_actions: dict[str, ActionTraceRecord] = {}

def action_started(trace: ActionTraceRecord) -> None:
started_actions[trace.trace_id] = trace
running_actions[trace.trace_id] = trace

def action_completed(trace: ActionTraceRecord) -> ActionTraceRecord:
start_trace = started_actions.get(trace.trace_id)
start_trace = running_actions.get(trace.trace_id)
if start_trace:
del started_actions[trace.trace_id]
del running_actions[trace.trace_id]
return start_trace
else:
raise RuntimeError(f"Expected {trace.trace_id} in action dictionary.")
Expand All @@ -78,16 +80,6 @@ def action_failed(trace: ActionTraceRecord) -> None:
def action_canceled(trace: ActionTraceRecord) -> None:
canceled_actions[start_trace.trace_id] = trace

def print_bucket(label: str, bucket: dict[str, ActionTraceRecord]) -> None:
if len(bucket) > 0:
print(label)
for id, action in bucket.items():
start_time = (
formatTime(action.start_time) if action.start_time else "None"
)
print(f"{start_time} {action.message}")
print("")

for trace in traces:
if isinstance(trace, ActionTraceRecord):
match trace.event:
Expand All @@ -98,18 +90,59 @@ def print_bucket(label: str, bucket: dict[str, ActionTraceRecord]) -> None:
case "cancel":
# Complete with a cancellation
start_trace = action_completed(trace)
action_canceled(start_trace)

# add duration
trace.start_time = start_trace.start_time

action_canceled(trace)
case "error":
# Capture error events
start_trace = action_completed(trace)
action_failed(start_trace)

# add start time
trace.start_time = start_trace.start_time

action_failed(trace)
continue
case _:
print(f"Unknown event type: {trace.event}")

print_bucket("Incomplete Actions", started_actions)
print_bucket("Error Actions", error_actions)
print_bucket("Canceled Actions", canceled_actions)
_print_bucket("Running Actions", running_actions)
_print_bucket("Canceled Actions", canceled_actions)
_print_bucket("Error Actions", error_actions)


def _print_bucket(label: str, bucket: dict[str, ActionTraceRecord]) -> None:
if len(bucket) > 0:
# Sort the items in chronological order of when
# they finished so the first finished item is at the top
sorted_actions = sorted(
bucket.values(),
key=lambda record: (record.start_time or 0) + (record.duration or 0),
)

r_print(f"[bold]{label}[/bold]")
for action in sorted_actions:
# Compute duration (use the event duration or time since started)
duration = (
action.duration
if action.duration is not None
else time.time() - action.start_time
if action.start_time is not None
else 0.0
)

# The event start time
start_time = formatTime(action.start_time) if action.start_time else "None"
if action.event == "error":
# print errors
print(
f"{start_time} ({round(duration, 2)}s): {action.message} {action.error}"
)
else:
# print the action
print(f"{start_time} ({round(duration, 2)}s): {action.message}")
print("")


def resolve_trace_file_path(trace_file: str) -> Path:
Expand Down

0 comments on commit 961ef22

Please sign in to comment.