Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

action tracing for anomaly detection #1038

Merged
merged 31 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
cb5e0eb
Add trace log level
dragonstyle Dec 20, 2024
cbbe6d1
Add persistent trace logging
dragonstyle Dec 20, 2024
1d1c2cd
Convert sandbox messages
dragonstyle Dec 20, 2024
97f835f
Convert eval log file operations
dragonstyle Dec 20, 2024
7c603b8
Convert model calls
dragonstyle Dec 20, 2024
01a508b
Convert cache
dragonstyle Dec 20, 2024
6bdbfc2
Give trace actions unique ids
dragonstyle Dec 20, 2024
d050aa5
Add simple sampe logging
dragonstyle Dec 20, 2024
bfc1cf2
Add simple trace to task init
dragonstyle Dec 20, 2024
f10cad5
Correct old log mapping
dragonstyle Dec 20, 2024
243daea
Correct trace level
dragonstyle Dec 20, 2024
74ef411
Merge remote-tracking branch 'origin/main' into feature/trace
jjallaire Dec 22, 2024
40acf7f
fix typing error
jjallaire Dec 22, 2024
dedeb1b
tweaks
jjallaire Dec 22, 2024
a7c4a35
revisiosn to trace logging
jjallaire Dec 22, 2024
ab659f1
trace log using json lines
jjallaire Dec 23, 2024
84244d2
pydantic for trace log
jjallaire Dec 23, 2024
7ea5875
anomolies
jjallaire Dec 23, 2024
24544ec
get trace file path
jjallaire Dec 23, 2024
76963cd
Basic trace anomoly logic
dragonstyle Dec 23, 2024
b9aa123
backstop for when solvers fail to handle their own TimeoutError
jjallaire Dec 23, 2024
3f1fd06
timeout for docker listing operations
jjallaire Dec 23, 2024
1c3afd5
timeout on write file
jjallaire Dec 23, 2024
f1b6a0f
fix formatting
jjallaire Dec 23, 2024
d60813e
correct spelling for anomolies
jjallaire Dec 23, 2024
dea06b0
Don’t require trace file name (use current if none provided)
dragonstyle Dec 23, 2024
961ef22
Improve trace output
dragonstyle Dec 23, 2024
cba0904
fix formatting errors
dragonstyle Dec 23, 2024
7a51490
sort descending so last finished item is at the top
dragonstyle Dec 23, 2024
d2b821d
Update CHANGELOG.md
jjallaire Dec 23, 2024
6d84d40
Merge branch 'main' into feature/trace
jjallaire Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- [Action tracing](https://github.com/UKGovernmentBEIS/inspect_ai/pull/1038) for diagnosing runs with unterminated action (e.g. model calls, docker commands, etc.).
- Task display: Added `--no-score-display` option to disable realtime scoring metrics.
- Bugfix: Fix failure to fully clone samples that have message lists as input.

Expand Down
2 changes: 2 additions & 0 deletions src/inspect_ai/_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .log import log_command
from .sandbox import sandbox_command
from .score import score_command
from .trace import trace_command
from .view import view_command


Expand Down Expand Up @@ -46,6 +47,7 @@ def inspect(ctx: click.Context, version: bool) -> None:
inspect.add_command(score_command)
inspect.add_command(view_command)
inspect.add_command(sandbox_command)
inspect.add_command(trace_command)


def main() -> None:
Expand Down
165 changes: 165 additions & 0 deletions src/inspect_ai/_cli/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import time
from datetime import datetime
from json import dumps
from pathlib import Path

import click
from pydantic_core import to_json
from rich import print as r_print

from inspect_ai._util.error import PrerequisiteError
from inspect_ai._util.logger import TRACE_FILE_NAME
from inspect_ai._util.trace import ActionTraceRecord, inspect_trace_dir, read_trace_file


@click.group("trace")
def trace_command() -> None:
"""List and read execution traces.

Inspect includes a TRACE log-level which is right below the HTTP and INFO log levels (so not written to the console by default). However, TRACE logs are always recorded to a separate file, and the last 10 TRACE logs are preserved. The 'trace' command provides ways to list and read these traces.
"""
return None


@trace_command.command("list")
@click.option(
"--json",
type=bool,
is_flag=True,
default=False,
help="Output listing as JSON",
)
def list_command(json: bool) -> None:
"""List all trace files."""
trace_dir = inspect_trace_dir()
trace_files = [f.absolute().as_posix() for f in trace_dir.iterdir() if f.is_file()]
if json:
print(dumps(trace_files, indent=2))
else:
print("\n".join(trace_files))


@trace_command.command("read")
@click.argument("trace-file", type=str, required=True)
def read_command(trace_file: str) -> None:
"""Read a trace file as a JSON array of log records."""
trace_file_path = resolve_trace_file_path(trace_file)

traces = read_trace_file(trace_file_path)
print(
to_json(traces, indent=2, exclude_none=True, fallback=lambda _: None).decode()
)


@trace_command.command("anomalies")
@click.argument("trace-file", type=str, required=False, default=TRACE_FILE_NAME)
def anomolies_command(trace_file: str) -> None:
"""Look for anomalies in a trace file (never completed or cancelled actions)."""
trace_file_path = resolve_trace_file_path(trace_file)
traces = read_trace_file(trace_file_path)

# Track started actions
running_actions: dict[str, ActionTraceRecord] = {}
error_actions: dict[str, ActionTraceRecord] = {}
canceled_actions: dict[str, ActionTraceRecord] = {}

def action_started(trace: ActionTraceRecord) -> None:
running_actions[trace.trace_id] = trace

def action_completed(trace: ActionTraceRecord) -> ActionTraceRecord:
start_trace = running_actions.get(trace.trace_id)
if start_trace:
del running_actions[trace.trace_id]
return start_trace
else:
raise RuntimeError(f"Expected {trace.trace_id} in action dictionary.")

def action_failed(trace: ActionTraceRecord) -> None:
error_actions[start_trace.trace_id] = trace

def action_canceled(trace: ActionTraceRecord) -> None:
canceled_actions[start_trace.trace_id] = trace

for trace in traces:
if isinstance(trace, ActionTraceRecord):
match trace.event:
case "enter":
action_started(trace)
case "exit":
action_completed(trace)
case "cancel":
# Complete with a cancellation
start_trace = action_completed(trace)

# add duration
trace.start_time = start_trace.start_time

action_canceled(trace)
case "error":
# Capture error events
start_trace = action_completed(trace)

# add start time
trace.start_time = start_trace.start_time

action_failed(trace)
continue
case _:
print(f"Unknown event type: {trace.event}")

_print_bucket("Running Actions", running_actions)
_print_bucket("Canceled Actions", canceled_actions)
_print_bucket("Error Actions", error_actions)


def _print_bucket(label: str, bucket: dict[str, ActionTraceRecord]) -> None:
if len(bucket) > 0:
# Sort the items in chronological order of when
# they finished so the first finished item is at the top
sorted_actions = sorted(
bucket.values(),
key=lambda record: (record.start_time or 0) + (record.duration or 0),
reverse=True,
)

r_print(f"[bold]{label}[/bold]")
for action in sorted_actions:
# Compute duration (use the event duration or time since started)
duration = (
action.duration
if action.duration is not None
else time.time() - action.start_time
if action.start_time is not None
else 0.0
)

# The event start time
start_time = formatTime(action.start_time) if action.start_time else "None"
if action.event == "error":
# print errors
print(
f"{start_time} ({round(duration, 2)}s): {action.message} {action.error}"
)
else:
# print the action
print(f"{start_time} ({round(duration, 2)}s): {action.message}")
print("")


def resolve_trace_file_path(trace_file: str) -> Path:
trace_file_path = Path(trace_file)
if not trace_file_path.is_absolute():
trace_file_path = inspect_trace_dir() / trace_file_path

if not trace_file_path.exists():
raise PrerequisiteError(
f"The specified trace file '{trace_file_path}' does not exist."
)

return trace_file_path


def formatTime(timestamp: float) -> str:
# ISO format with timezone
dt = datetime.fromtimestamp(timestamp)
return dt.isoformat()
18 changes: 11 additions & 7 deletions src/inspect_ai/_eval/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,14 +580,18 @@ def handle_error(ex: BaseException) -> EvalError:
state = await plan(state, generate)

except TimeoutError:
# notify the user
transcript()._event(
SampleLimitEvent(
type="time",
message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
limit=time_limit,
if time_limit is not None:
transcript()._event(
SampleLimitEvent(
type="time",
message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
limit=time_limit,
)
)
else:
py_logger.warning(
"Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
)
)

# capture most recent state for scoring
state = sample_state() or state
Expand Down
6 changes: 3 additions & 3 deletions src/inspect_ai/_util/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
DEFAULT_SERVER_HOST = "127.0.0.1"
HTTP = 15
HTTP_LOG_LEVEL = "HTTP"
SANDBOX = 17
SANDBOX_LOG_LEVEL = "SANDBOX"
TRACE = 13
TRACE_LOG_LEVEL = "TRACE"
ALL_LOG_LEVELS = [
"DEBUG",
TRACE_LOG_LEVEL,
HTTP_LOG_LEVEL,
SANDBOX_LOG_LEVEL,
"INFO",
"WARNING",
"ERROR",
Expand Down
41 changes: 33 additions & 8 deletions src/inspect_ai/_util/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,28 @@
getLevelName,
getLogger,
)
from logging.handlers import RotatingFileHandler

import rich
from rich.console import ConsoleRenderable
from rich.logging import RichHandler
from rich.text import Text
from typing_extensions import override

from inspect_ai._util.constants import (
from .constants import (
ALL_LOG_LEVELS,
DEFAULT_LOG_LEVEL,
DEFAULT_LOG_LEVEL_TRANSCRIPT,
HTTP,
HTTP_LOG_LEVEL,
PKG_NAME,
SANDBOX,
SANDBOX_LOG_LEVEL,
TRACE,
TRACE_LOG_LEVEL,
)
from inspect_ai._util.error import PrerequisiteError
from .error import PrerequisiteError
from .trace import TraceFormatter, inspect_trace_dir

TRACE_FILE_NAME = "trace.log"


# log handler that filters messages to stderr and the log file
Expand All @@ -52,6 +56,23 @@ def __init__(self, levelno: int, transcript_levelno: int) -> None:
else:
self.file_logger_level = 0

# add a trace handler
default_trace_file = inspect_trace_dir() / TRACE_FILE_NAME
have_existing_trace_file = default_trace_file.exists()
trace_file = os.environ.get("INSPECT_TRACE_FILE", default_trace_file)
trace_total_files = 10
self.trace_logger = RotatingFileHandler(
trace_file,
backupCount=trace_total_files - 1, # exclude the current file (10 total)
)
self.trace_logger.setFormatter(TraceFormatter())
if have_existing_trace_file:
self.trace_logger.doRollover()

# set trace level
trace_level = os.environ.get("INSPECT_TRACE_LEVEL", TRACE_LOG_LEVEL)
self.trace_logger_level = int(getLevelName(trace_level.upper()))

@override
def emit(self, record: LogRecord) -> None:
# demote httpx and return notifications to log_level http
Expand Down Expand Up @@ -79,6 +100,10 @@ def emit(self, record: LogRecord) -> None:
):
self.file_logger.emit(record)

# write to trace if the trace level matches.
if self.trace_logger and record.levelno >= self.trace_logger_level:
self.trace_logger.emit(record)

# eval log always gets info level and higher records
# eval log only gets debug or http if we opt-in
write = record.levelno >= self.transcript_levelno
Expand All @@ -95,12 +120,12 @@ def init_logger(
log_level: str | None = None, log_level_transcript: str | None = None
) -> None:
# backwards compatibility for 'tools'
if log_level == "tools":
log_level = "sandbox"
if log_level == "sandbox" or log_level == "tools":
log_level = "trace"

# register http and tools levels
addLevelName(HTTP, HTTP_LOG_LEVEL)
addLevelName(SANDBOX, SANDBOX_LOG_LEVEL)
addLevelName(TRACE, TRACE_LOG_LEVEL)

def validate_level(option: str, level: str) -> None:
if level not in ALL_LOG_LEVELS:
Expand Down Expand Up @@ -134,7 +159,7 @@ def validate_level(option: str, level: str) -> None:
getLogger().addHandler(_logHandler)

# establish default capture level
capture_level = min(HTTP, levelno)
capture_level = min(TRACE, levelno)

# see all the messages (we won't actually display/write all of them)
getLogger().setLevel(capture_level)
Expand Down
Loading
Loading