Skip to content

Commit

Permalink
action tracing for anomaly detection (#1038)
Browse files Browse the repository at this point in the history
* Add trace log level

add

trace fix

* Add persistent trace logging

* Convert sandbox messages

* Convert eval log file operations

* Convert model calls

* Convert cache

* Give trace actions unique ids

* Add simple sample logging

* Add simple trace to task init

* Correct old log mapping

* Correct trace level

* fix typing error

* tweaks

* revisions to trace logging

* trace log using json lines

* pydantic for trace log

* anomalies

* get trace file path

* Basic trace anomaly logic

* backstop for when solvers fail to handle their own TimeoutError

* timeout for docker listing operations

* timeout on write file

* fix formatting

* correct spelling for anomolies

* Don’t require trace file name (use current if none provided)

* Improve trace output

- sort by completed time (not start time)
- display errors
- display duration

* fix formatting errors

* sort descending so last finished item is at the top

* Update CHANGELOG.md

---------

Co-authored-by: J.J. Allaire <[email protected]>
  • Loading branch information
dragonstyle and jjallaire authored Dec 23, 2024
1 parent 2552383 commit 36b2a7f
Show file tree
Hide file tree
Showing 18 changed files with 612 additions and 186 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- [Action tracing](https://github.com/UKGovernmentBEIS/inspect_ai/pull/1038) for diagnosing runs with unterminated actions (e.g. model calls, docker commands, etc.).
- Task display: Added `--no-score-display` option to disable realtime scoring metrics.
- Bugfix: Fix failure to fully clone samples that have message lists as input.

Expand Down
2 changes: 2 additions & 0 deletions src/inspect_ai/_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .log import log_command
from .sandbox import sandbox_command
from .score import score_command
from .trace import trace_command
from .view import view_command


Expand Down Expand Up @@ -46,6 +47,7 @@ def inspect(ctx: click.Context, version: bool) -> None:
inspect.add_command(score_command)
inspect.add_command(view_command)
inspect.add_command(sandbox_command)
inspect.add_command(trace_command)


def main() -> None:
Expand Down
165 changes: 165 additions & 0 deletions src/inspect_ai/_cli/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import time
from datetime import datetime
from json import dumps
from pathlib import Path

import click
from pydantic_core import to_json
from rich import print as r_print

from inspect_ai._util.error import PrerequisiteError
from inspect_ai._util.logger import TRACE_FILE_NAME
from inspect_ai._util.trace import ActionTraceRecord, inspect_trace_dir, read_trace_file


@click.group("trace")
def trace_command() -> None:
    """List and read execution traces.
    Inspect includes a TRACE log-level which is right below the HTTP and INFO log levels (so not written to the console by default). However, TRACE logs are always recorded to a separate file, and the last 10 TRACE logs are preserved. The 'trace' command provides ways to list and read these traces.
    """
    # The group callback does no work of its own; the subcommands registered
    # on it (list / read / anomalies) carry all of the behaviour.


@trace_command.command("list")
@click.option(
    "--json",
    type=bool,
    is_flag=True,
    default=False,
    help="Output listing as JSON",
)
def list_command(json: bool) -> None:
    """List all trace files."""
    # Gather the absolute (posix-style) path of every regular file found
    # directly inside the trace directory.
    paths: list[str] = []
    for entry in inspect_trace_dir().iterdir():
        if entry.is_file():
            paths.append(entry.absolute().as_posix())

    # Emit either a JSON array or one path per line.
    listing = dumps(paths, indent=2) if json else "\n".join(paths)
    print(listing)


@trace_command.command("read")
@click.argument("trace-file", type=str, required=True)
def read_command(trace_file: str) -> None:
    """Read a trace file as a JSON array of log records."""
    # Relative names are looked up by resolve_trace_file_path().
    records = read_trace_file(resolve_trace_file_path(trace_file))

    # Serialize with None-valued fields omitted; any value the serializer
    # cannot handle is replaced by null via the fallback.
    payload = to_json(records, indent=2, exclude_none=True, fallback=lambda _: None)
    print(payload.decode())


@trace_command.command("anomalies")
@click.argument("trace-file", type=str, required=False, default=TRACE_FILE_NAME)
def anomolies_command(trace_file: str) -> None:
"""Look for anomalies in a trace file (never completed or cancelled actions)."""
trace_file_path = resolve_trace_file_path(trace_file)
traces = read_trace_file(trace_file_path)

# Track started actions
running_actions: dict[str, ActionTraceRecord] = {}
error_actions: dict[str, ActionTraceRecord] = {}
canceled_actions: dict[str, ActionTraceRecord] = {}

def action_started(trace: ActionTraceRecord) -> None:
running_actions[trace.trace_id] = trace

def action_completed(trace: ActionTraceRecord) -> ActionTraceRecord:
start_trace = running_actions.get(trace.trace_id)
if start_trace:
del running_actions[trace.trace_id]
return start_trace
else:
raise RuntimeError(f"Expected {trace.trace_id} in action dictionary.")

def action_failed(trace: ActionTraceRecord) -> None:
error_actions[start_trace.trace_id] = trace

def action_canceled(trace: ActionTraceRecord) -> None:
canceled_actions[start_trace.trace_id] = trace

for trace in traces:
if isinstance(trace, ActionTraceRecord):
match trace.event:
case "enter":
action_started(trace)
case "exit":
action_completed(trace)
case "cancel":
# Complete with a cancellation
start_trace = action_completed(trace)

# add duration
trace.start_time = start_trace.start_time

action_canceled(trace)
case "error":
# Capture error events
start_trace = action_completed(trace)

# add start time
trace.start_time = start_trace.start_time

action_failed(trace)
continue
case _:
print(f"Unknown event type: {trace.event}")

_print_bucket("Running Actions", running_actions)
_print_bucket("Canceled Actions", canceled_actions)
_print_bucket("Error Actions", error_actions)


def _print_bucket(label: str, bucket: dict[str, ActionTraceRecord]) -> None:
    """Print a labelled listing of the actions in a bucket.

    Nothing is printed when the bucket is empty. Otherwise the label is
    printed in bold, followed by one line per action (start time, duration
    in seconds, message, and the error for "error" events), followed by a
    blank line.

    Args:
        label: Heading printed above the listing.
        bucket: Action records keyed by trace id.
    """
    if not bucket:
        return

    # Sort by finish time (start time + duration), descending, so the most
    # recently finished item is at the top. Missing values count as 0.
    sorted_actions = sorted(
        bucket.values(),
        key=lambda record: (record.start_time or 0) + (record.duration or 0),
        reverse=True,
    )

    r_print(f"[bold]{label}[/bold]")
    for action in sorted_actions:
        # Duration: prefer the recorded duration; otherwise the time elapsed
        # since the action started; otherwise 0 when nothing is known.
        if action.duration is not None:
            duration = action.duration
        elif action.start_time is not None:
            duration = time.time() - action.start_time
        else:
            duration = 0.0

        # Start time for display (same None test as the duration logic above).
        start_time = (
            formatTime(action.start_time) if action.start_time is not None else "None"
        )

        line = f"{start_time} ({round(duration, 2)}s): {action.message}"
        if action.event == "error":
            # Error events additionally show the captured error.
            line = f"{line} {action.error}"
        print(line)
    print("")


def resolve_trace_file_path(trace_file: str) -> Path:
    """Turn a trace file name or path into an existing file path.

    Absolute paths are used as given; anything else is taken relative to
    the directory returned by inspect_trace_dir().

    Raises:
        PrerequisiteError: If the resulting path does not exist.
    """
    requested = Path(trace_file)
    trace_file_path = (
        requested if requested.is_absolute() else inspect_trace_dir() / requested
    )

    if trace_file_path.exists():
        return trace_file_path

    raise PrerequisiteError(
        f"The specified trace file '{trace_file_path}' does not exist."
    )


def formatTime(timestamp: float) -> str:
    """Format a POSIX timestamp as an ISO 8601 string with a UTC offset.

    Args:
        timestamp: Seconds since the epoch.

    Returns:
        The local date/time in ISO format, including the local UTC offset
        (e.g. ``2024-12-23T10:15:00.123456-05:00``).
    """
    # fromtimestamp() alone yields a naive datetime, whose isoformat() has no
    # timezone information; astimezone() attaches the system local timezone
    # so the offset is included in the output.
    dt = datetime.fromtimestamp(timestamp).astimezone()
    return dt.isoformat()
18 changes: 11 additions & 7 deletions src/inspect_ai/_eval/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,14 +580,18 @@ def handle_error(ex: BaseException) -> EvalError:
state = await plan(state, generate)

except TimeoutError:
# notify the user
transcript()._event(
SampleLimitEvent(
type="time",
message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
limit=time_limit,
if time_limit is not None:
transcript()._event(
SampleLimitEvent(
type="time",
message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
limit=time_limit,
)
)
else:
py_logger.warning(
"Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
)
)

# capture most recent state for scoring
state = sample_state() or state
Expand Down
6 changes: 3 additions & 3 deletions src/inspect_ai/_util/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
DEFAULT_SERVER_HOST = "127.0.0.1"
HTTP = 15
HTTP_LOG_LEVEL = "HTTP"
SANDBOX = 17
SANDBOX_LOG_LEVEL = "SANDBOX"
TRACE = 13
TRACE_LOG_LEVEL = "TRACE"
ALL_LOG_LEVELS = [
"DEBUG",
TRACE_LOG_LEVEL,
HTTP_LOG_LEVEL,
SANDBOX_LOG_LEVEL,
"INFO",
"WARNING",
"ERROR",
Expand Down
41 changes: 33 additions & 8 deletions src/inspect_ai/_util/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,28 @@
getLevelName,
getLogger,
)
from logging.handlers import RotatingFileHandler

import rich
from rich.console import ConsoleRenderable
from rich.logging import RichHandler
from rich.text import Text
from typing_extensions import override

from inspect_ai._util.constants import (
from .constants import (
ALL_LOG_LEVELS,
DEFAULT_LOG_LEVEL,
DEFAULT_LOG_LEVEL_TRANSCRIPT,
HTTP,
HTTP_LOG_LEVEL,
PKG_NAME,
SANDBOX,
SANDBOX_LOG_LEVEL,
TRACE,
TRACE_LOG_LEVEL,
)
from inspect_ai._util.error import PrerequisiteError
from .error import PrerequisiteError
from .trace import TraceFormatter, inspect_trace_dir

TRACE_FILE_NAME = "trace.log"


# log handler that filters messages to stderr and the log file
Expand All @@ -52,6 +56,23 @@ def __init__(self, levelno: int, transcript_levelno: int) -> None:
else:
self.file_logger_level = 0

# add a trace handler
default_trace_file = inspect_trace_dir() / TRACE_FILE_NAME
have_existing_trace_file = default_trace_file.exists()
trace_file = os.environ.get("INSPECT_TRACE_FILE", default_trace_file)
trace_total_files = 10
self.trace_logger = RotatingFileHandler(
trace_file,
backupCount=trace_total_files - 1, # exclude the current file (10 total)
)
self.trace_logger.setFormatter(TraceFormatter())
if have_existing_trace_file:
self.trace_logger.doRollover()

# set trace level
trace_level = os.environ.get("INSPECT_TRACE_LEVEL", TRACE_LOG_LEVEL)
self.trace_logger_level = int(getLevelName(trace_level.upper()))

@override
def emit(self, record: LogRecord) -> None:
# demote httpx and return notifications to log_level http
Expand Down Expand Up @@ -79,6 +100,10 @@ def emit(self, record: LogRecord) -> None:
):
self.file_logger.emit(record)

# write to trace if the trace level matches.
if self.trace_logger and record.levelno >= self.trace_logger_level:
self.trace_logger.emit(record)

# eval log always gets info level and higher records
# eval log only gets debug or http if we opt-in
write = record.levelno >= self.transcript_levelno
Expand All @@ -95,12 +120,12 @@ def init_logger(
log_level: str | None = None, log_level_transcript: str | None = None
) -> None:
# backwards compatibility for 'tools'
if log_level == "tools":
log_level = "sandbox"
if log_level == "sandbox" or log_level == "tools":
log_level = "trace"

# register http and tools levels
addLevelName(HTTP, HTTP_LOG_LEVEL)
addLevelName(SANDBOX, SANDBOX_LOG_LEVEL)
addLevelName(TRACE, TRACE_LOG_LEVEL)

def validate_level(option: str, level: str) -> None:
if level not in ALL_LOG_LEVELS:
Expand Down Expand Up @@ -134,7 +159,7 @@ def validate_level(option: str, level: str) -> None:
getLogger().addHandler(_logHandler)

# establish default capture level
capture_level = min(HTTP, levelno)
capture_level = min(TRACE, levelno)

# see all the messages (we won't actually display/write all of them)
getLogger().setLevel(capture_level)
Expand Down
Loading

0 comments on commit 36b2a7f

Please sign in to comment.