diff --git a/.gitignore b/.gitignore index 5ae96bb8d..8da51edfc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,13 @@ -# Our re-organized logs. +# Our re-organized logs sanitized_logs -investigation-data +/investigation-data + _backup/** -# their logs. +# their logs logs -trajectories/** !trajectories/demonstrations/** +trajectories/** # Other @@ -15,6 +16,8 @@ service-account-key.json **/investigations/investigation-data **/investigations/run-logs* *~ +/*.log +*tmp.log* # Byte-compiled / optimized / DLL files __pycache__/ @@ -74,7 +77,6 @@ coverage.xml *.pot # Django stuff: -*.log local_settings.py db.sqlite3 db.sqlite3-journal diff --git a/.vscode/settings.json b/.vscode/settings.json index aa8f220dc..299a70c4b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,25 +1,25 @@ { "python.terminal.activateEnvironment": true, "python.defaultInterpreterPath": "/.venv/bin/python", - "peacock.color": "#15141f", - "peacock.remoteColor": "#15141f", + "peacock.color": "#2b283d", + "peacock.remoteColor": "#2b283d", "workbench.colorCustomizations": { - "activityBar.activeBackground": "#2a283e", - "activityBar.background": "#2a283e", + "activityBar.activeBackground": "#413c5c", + "activityBar.background": "#413c5c", "activityBar.foreground": "#e7e7e7", "activityBar.inactiveForeground": "#e7e7e799", - "activityBarBadge.background": "#744f4b", + "activityBarBadge.background": "#000000", "activityBarBadge.foreground": "#e7e7e7", "commandCenter.border": "#e7e7e799", - "sash.hoverBorder": "#2a283e", - "statusBar.background": "#15141f", + "sash.hoverBorder": "#413c5c", + "statusBar.background": "#2b283d", "statusBar.foreground": "#e7e7e7", - "statusBarItem.hoverBackground": "#2a283e", - "statusBarItem.remoteBackground": "#15141f", + "statusBarItem.hoverBackground": "#413c5c", + "statusBarItem.remoteBackground": "#2b283d", "statusBarItem.remoteForeground": "#e7e7e7", - "titleBar.activeBackground": "#15141f", + "titleBar.activeBackground": "#2b283d", 
"titleBar.activeForeground": "#e7e7e7", - "titleBar.inactiveBackground": "#15141f99", + "titleBar.inactiveBackground": "#2b283d99", "titleBar.inactiveForeground": "#e7e7e799" }, } \ No newline at end of file diff --git a/config/_tdd_repro_prompt.md b/config/_tdd_repro_prompt.md index 09cf1b7db..834efa729 100644 --- a/config/_tdd_repro_prompt.md +++ b/config/_tdd_repro_prompt.md @@ -1,9 +1,7 @@ -## Requirements +# Requirements * You are provided `tdd_*` tools to reproduce the issue with one or more golden tests and also check for regressions. -* Always start your investigations from the failing test. -* When deciding to investigate some part of the code, EXPLAIN CLEARLY why you think that this is the next step to take. * Don't submit until the reproduction command proves your fix. - -## HOW TO BEGIN -To start things off, run the 'tdd_repro' command to reproduce the issue. \ No newline at end of file +* IMPORTANT: Always FIRST RUN the `tdd_repro` command to reproduce the issue. +* This provides you with in-depth test failure and runtime information. +* This includes `CALL_GRAPH_ON_EXCEPTION`: It contains the entire call hierarchy of all functions from that test. diff --git a/config/commands/_tdd.sh b/config/commands/_tdd.sh index c03d038fb..a1ead85c8 100644 --- a/config/commands/_tdd.sh +++ b/config/commands/_tdd.sh @@ -1,6 +1,20 @@ # @yaml -# docstring: Reproduces the current bug by running a test that was designed to fail as long as the bug exists. # tdd: true +# signature: tdd_repro [""] [""] [] +# docstring: Reproduces the bug by running bug-specific tests. Provide optional target arguments to get runtime context for the target function. +# arguments: +# target_file: +# type: string +# description: The file containing the target function, relative to CWD. If provided, target_function_name is also required. +# required: false +# target_function_name: +# type: string +# description: The UNQUALIFIED(!) name of the target function or method. 
+# required: false +# decl_lineno: +# type: integer +# description: The lineno of target_function's declaration. Only required if target_function_name is ambiguous within the file. +# required: false tdd_repro() { # set -euo pipefail if [ -z "$TEST_CMD_FAIL_TO_PASS" ]; then @@ -9,6 +23,10 @@ tdd_repro() { fi pushd $REPO_ROOT > /dev/null echo -e "Running tests to reproduce the bug (from $PWD):\n >$TEST_CMD_FAIL_TO_PASS\n" + if [ $# -ge 1 ]; then + line_no=${3:-0} + export TDD_TRACE_TARGET_CONFIG="{ \"target_file\": \"$1\", \"target_function_name\": \"$2\", \"decl_lineno\": $line_no}" + fi eval "$TEST_CMD_FAIL_TO_PASS" # include the continuation file if it exists diff --git a/config/commands/defaults.sh b/config/commands/defaults.sh index f0e046694..79d218c7e 100644 --- a/config/commands/defaults.sh +++ b/config/commands/defaults.sh @@ -221,9 +221,8 @@ create() { # arguments: # command: # type: string -# description: Shell command string to execute. Will execute with `eval "$@"`. -# required: truewith $PWD. -# required: false +# description: Shell command string to execute. Executes `eval "$@"`. +# required: true exec() { if [ $# -eq 0 ]; then echo "Usage: exec " diff --git a/prediction_assets/call_graph_tracer.py b/prediction_assets/call_graph_tracer.py new file mode 100644 index 000000000..7ccb31e40 --- /dev/null +++ b/prediction_assets/call_graph_tracer.py @@ -0,0 +1,660 @@ +# noqa: I002 +# ruff: noqa: UP006 + +# NOTE: We ignore some warnings in this file to allow for backward compatability. 
+ +import inspect +import json # noqa: I002 +import linecache +import os +import sys +import threading +import traceback +from json.decoder import JSONDecodeError +from types import FrameType, TracebackType +from typing import Any, Callable, Dict, List, Optional, Union + +from tests.prediction_assets.use_tracer import register_call_graph + +TargetConfig = Dict[str, Union[Optional[str], Optional[int]]] + +# ############################################################################ +# parse_json util +# ############################################################################ + + +def parse_json(json_string): + try: + return json.loads(json_string) + except JSONDecodeError as e: + # Get the position of the error + pos = e.pos + + # Get the line and column of the error + lineno = json_string.count("\n", 0, pos) + 1 + colno = pos - json_string.rfind("\n", 0, pos) + + # Get the problematic lines (including context) + lines = json_string.splitlines() + context_range = 2 # Number of lines to show before and after the error + start = max(0, lineno - context_range - 1) + end = min(len(lines), lineno + context_range) + context_lines = lines[start:end] + + # Create the context string with line numbers + context = "" + for i, line in enumerate(context_lines, start=start + 1): + if i == lineno: + context += f"{i:4d} > {line}\n" + context += " " + " " * (colno - 1) + "^\n" + else: + context += f"{i:4d} {line}\n" + + # Construct and raise a new error with more information + error_msg = f"JSON parsing failed at line {lineno}, column {colno}:\n\n{context.rstrip()}\nError: {str(e)}" + raise ValueError(error_msg) from e + + +# ############################################################################ +# Config +# ############################################################################ + +# Parse target config. 
+TRACE_TARGET_CONFIG_STR = os.environ.get("TDD_TRACE_TARGET_CONFIG") +TRACE_TARGET_CONFIG: Optional[TargetConfig] = None +if TRACE_TARGET_CONFIG_STR: + TRACE_TARGET_CONFIG = parse_json(TRACE_TARGET_CONFIG_STR) + if TRACE_TARGET_CONFIG: + if "target_file" not in TRACE_TARGET_CONFIG: + raise ValueError("TDD_TRACE_TARGET_CONFIG must provide 'target_file'.") + if "target_function_name" not in TRACE_TARGET_CONFIG: + raise ValueError( + "TDD_TRACE_TARGET_CONFIG must provide 'target_function_name' if 'target_file' is provided." + ) + +# Whether to only report call graphs on assert failures +ASSERTS_ONLY = True + +# Record parameter values and return values, only if target region is sufficiently scoped. +RECORD_VALUES = not not TRACE_TARGET_CONFIG +RECORD_PARAMS = True +RECORD_RETURN_VALUES = True + +# NOTE: We need to mute exceptions because some of them get thrown during teardown where builtins are straight up gone. +MUTE_EXCEPTIONS = True + +# Whether to print stringify parameter and return values. +# If set to `False`, will only print primitive values and for objects only a reference id. 
+DO_STRINGIFY = True + +MAX_PRINTED_CALL_GRAPHS = 3 + +# OVERRIDES +RECORD_VALUES = True +# RECORD_PARAMS = False +MUTE_EXCEPTIONS = False +DO_STRINGIFY = False +MAX_PRINTED_CALL_GRAPHS = 1 + +# REPO_ROOT = os.environ.get("REPO_ROOT") or os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +# INSTANCE_NAME = os.environ.get("TDD_INSTANCE_NAME") + +_n_printed_call_graphs = 0 + +# ############################################################################ +# FrameInfo +# ############################################################################ + + +class FrameInfo: + """Source-location data for one call, built from a frame, a traceback entry or a FrameSummary.""" + + def __init__( + self, + frame: Optional[FrameType] = None, + decl_filename: Optional[str] = None, + decl_lineno: Optional[int] = None, + call_filename: Optional[str] = None, + call_lineno: Optional[int] = None, + function_name: Optional[str] = None, + code_context: Optional[List[str]] = None, + index: Optional[int] = None, + ): + self.frame = frame + # decl_*: where the function's code object is defined. + self.decl_filename = decl_filename or (frame.f_code.co_filename if frame else None) + self.decl_lineno = decl_lineno or (frame.f_code.co_firstlineno if frame else None) + + # call_*: taken from the caller's frame (f_back) unless given explicitly. + self.call_filename = call_filename or (frame.f_back.f_code.co_filename if frame and frame.f_back else None) + self.call_lineno = call_lineno or (frame.f_back.f_lineno if frame and frame.f_back else None) + + self.function_name = function_name or (frame.f_code.co_name if frame else None) + self.code_context = code_context + self.index = index + + @classmethod + def from_frame(cls, frame: FrameType) -> "FrameInfo": + return cls(frame=frame) + + @classmethod + def from_traceback(cls, tb: TracebackType) -> "FrameInfo": + return cls(frame=tb.tb_frame, call_lineno=tb.tb_lineno) + + @classmethod + def from_frame_summary(cls, summary: traceback.FrameSummary) -> "FrameInfo": + # No live frame here: only the fields a FrameSummary carries are filled in. + return cls( + call_filename=summary.filename, + call_lineno=summary.lineno, + function_name=summary.name, + code_context=[summary.line] if summary.line else None, + ) + + def get_name(self) -> str: 
+ """ + This does NOT actually provide the *qualified* name. It looks like we can't do that reliably across versions. + """ + frame = self.frame + if frame is None: + return self.function_name or "" + + code = frame.f_code + return code.co_name + + def get_locals(self) -> Dict[str, any]: + return self.frame.f_locals + + +# ############################################################################ +# Polyfill stuff +# ############################################################################ + +python_version = sys.version_info + +# Compatibility for ContextVar +if python_version >= (3, 7): + from contextvars import ContextVar +else: + + class ContextVar: + def __init__(self, name, default=None): + self.local = threading.local() + self.default = default + + def get(self): + return getattr(self.local, "value", self.default) + + def set(self, value): + setattr(self.local, "value", value) + + +# ############################################################################ +# util +# ############################################################################ + + +def get_relative_filename(filename: str) -> str: + try: + rel_path = os.path.relpath(filename) + if rel_path.startswith("..") or os.path.isabs(rel_path): + return f"EXTERNAL/{os.path.basename(filename)}" + else: + return rel_path + except Exception: + return filename + + +# ############################################################################ +# CallGraphNode + Object and Value tracking +# ############################################################################ + + +class ObjectTracker: + def __init__(self): + self.object_id_counter = 0 + self.object_ids = {} + + def get_object_id(self, obj: Any) -> int: + if id(obj) not in self.object_ids: + self.object_id_counter += 1 + self.object_ids[id(obj)] = self.object_id_counter + return self.object_ids[id(obj)] + + +object_tracker = ObjectTracker() + + +def stringify_with_ref(value: Any) -> str: + if isinstance(value, (int, float, str, bool, 
type(None))): + if DO_STRINGIFY: + return str(value) + return f"({type(value).__name__})" + + obj_id = object_tracker.get_object_id(value) + if DO_STRINGIFY: + # Keep the ": " separator between the reference id and the value text. + try: + stringified = ": " + str(value) + except Exception as e: + stringified = f": (Error when stringifying: {e})" + else: + stringified = "" + + return f"REF#{obj_id}{stringified}" + + +def stringify_dict(d: Dict[str, Any]) -> str: + result = {} + for key, value in d.items(): + result[key] = stringify_with_ref(value) + return str(result) + + +class BaseNode: + def __init__(self): + self.children: List[BaseNode] = [] + self.parent: Optional[CallGraphNode] = None + + def add_child(self, child: "BaseNode"): + self.children.append(child) + child.parent = self + + def __str__(self, level=0, visited=None): + indent = " " * level + return f"{indent}{self.name}\n" + + +class LineNode(BaseNode): + def __init__(self, value: str): + super().__init__() + self.name = value + + +class OmittedNode(BaseNode): + def __init__(self): + super().__init__() + self.name = "(omitted child)" + + +class CallGraphNode(BaseNode): + is_partial: bool = False + + def __init__(self, frame_info: FrameInfo): + super().__init__() + self.frame_info = frame_info + self.children: list[CallGraphNode] = [] + self.exception: Optional[str] = None + self.parameters: Dict[str, Any] = {} + self.return_value: Any = None + + @property + def name(self) -> str: + return self.frame_info.get_name() + + @property + def decl_filename(self) -> str: + return get_relative_filename(self.frame_info.decl_filename) + + @property + def call_filename(self) -> str: + return get_relative_filename(self.frame_info.call_filename) + + def set_exception(self, exc_name: str): + self.exception = exc_name + + def set_parameters(self, params: Dict[str, Any]): + if not RECORD_VALUES or not RECORD_PARAMS: + return + self.parameters = stringify_dict(params) + + def set_return_value(self, value: Any): + if not RECORD_VALUES or not RECORD_RETURN_VALUES: + return + 
self.return_value = stringify_with_ref(value) + + def __str__(self, level=0, visited=None): + if visited is None: + visited = set() + if id(self) in visited: + return " " * level + "[Recursion]\n" + visited.add(id(self)) + indent = " " * level + result = f"{indent}{self.name}" + f" [decl: {self.decl_filename or '?'}:{self.frame_info.decl_lineno or '?'}" + if self.frame_info.call_lineno != self.frame_info.decl_lineno: + result += f", called_from: {self.call_filename or '?'}:{self.frame_info.call_lineno or '?'}" + result += "]" + if self.parameters: + result += f", Parameters:{self.parameters}" + if self.return_value is not None: + result += f", ReturnValue:{self.return_value}" + result += "\n" + for child in self.children: + result += child.__str__(level + 1, visited) + return result + + +# ############################################################################ +# CallGraph impl +# ############################################################################ + + +class CallGraph: + def __init__(self): + self.call_stack: List[CallGraphNode] = ContextVar("call_stack", default=[]) + self.root: Optional[CallGraphNode] = None + self.is_partial: bool = False + try: + self.cwd = os.getcwd() + except Exception: + self.cwd = None + + def should_trace(self, frame: FrameType) -> bool: + if frame is None or frame.f_code is None or frame.f_code.co_filename is None: + # Ignore code without code or filename. + # TODO: Not sure why frames can have no code or filename. Might be some builtins? + return False + if os.path.dirname(frame.f_code.co_filename) == os.path.dirname(__file__): + # Ignore all code from within this directory. + return False + filename = frame.f_code.co_filename + abs_filename = os.path.abspath(filename) + if self.cwd: + # Ignore external code. 
+ return abs_filename.startswith(self.cwd) + else: + return True + + def access_call_stack(self) -> List[CallGraphNode]: + call_stack = self.call_stack.get() + res: List[CallGraphNode] = call_stack.copy() + return res + + def trace_line(self, frame: FrameType, arg: any): + code = frame.f_code + function_name = code.co_name + lineno = frame.f_lineno + filename = code.co_filename + line_of_code = linecache.getline(filename, lineno).strip() + + # Get the variables in the current frame + variables = frame.f_locals + + code_query = "_iterable_class" + if code_query in line_of_code: + node_str = f"CODE_EXECUTED: `{line_of_code} (at {filename}:{lineno})`" + self.append_to_current_node(node_str) + + def append_to_current_node(self, s: str): + call_stack = self.access_call_stack() + if call_stack: + call_stack[-1].add_child(LineNode(f'"{s}"')) + + def trace_calls(self, event_frame: FrameType, event: str, arg: any) -> Optional[Callable]: + try: + frame_info = FrameInfo.from_frame(event_frame) + + if not self.should_trace(event_frame): + return None + if event == "call": + call_stack = self.access_call_stack() + + node = CallGraphNode(frame_info) + + # Store parameter values + node.set_parameters(frame_info.get_locals()) + + if call_stack: + call_stack[-1].add_child(node) + else: + self.root = node + call_stack.append(node) + self.call_stack.set(call_stack) + event_frame.f_trace = self.trace_calls + elif event == "return": + call_stack = self.access_call_stack() + if call_stack: + # Store return value + call_stack[-1].set_return_value(arg) + call_stack.pop() + self.call_stack.set(call_stack) + elif event == "exception": + exc_type, exc_str, _ = arg + if exc_type is GeneratorExit: + return None + call_stack = self.access_call_stack() + if call_stack: + call_stack[-1].set_exception(exc_type.__name__) + test_node = next( + (node for node in reversed(call_stack) if node.name.startswith("test_")), + None, + ) + if test_node and (not ASSERTS_ONLY or exc_type is AssertionError): + 
self.print_graph_on_exception("EXCEPTION", test_node, exc_str, exc_type) + elif event == "line": + self.trace_line(event_frame, arg) + return self.trace_calls + except Exception as err: + if not MUTE_EXCEPTIONS: + print("\n\n\nERROR IN trace_calls:\n\n\n", file=sys.stderr) + traceback.print_exc() + return None + + def find_node(self, target_config: TargetConfig) -> Optional[CallGraphNode]: + root = self.root + stack = [root] + while stack: + node = stack.pop() + if ( + node.name == target_config.get("target_function_name") + and node.decl_filename == target_config.get("target_file") + and ( + not target_config.get("decl_lineno") or node.frame_info.decl_lineno == target_config["decl_lineno"] + ) + ): + return node + stack.extend(reversed(node.children)) + return None + + def get_partial_graph(self, target_config: TargetConfig) -> Optional[BaseNode]: + """ + Create a partial graph of the first function call of given name. + Should only contain the flat call graph surrounding that node, i.e. its parent and its children. + """ + + def create_partial_node(node: CallGraphNode) -> CallGraphNode: + partial_node = CallGraphNode(node.frame_info) + partial_node.parameters = node.parameters + partial_node.return_value = node.return_value + partial_node.exception = node.exception + partial_node.is_partial = True + return partial_node + + if not self.root: + return None + + target_node = self.find_node(target_config) + if not target_node: + return None + + partial_node = create_partial_node(target_node) + + if target_node.parent: + # Pick parent of target node, if available. + root = create_partial_node(target_node.parent) + root.add_child(partial_node) # Add the target node + if len(target_node.parent.children) > 1: + # Add OmittedNode to represent siblings. 
+ root.add_child(OmittedNode()) + else: + root = partial_node + + # Add children + for child in target_node.children: + if not isinstance(child, CallGraphNode): + # LineNode children (added by append_to_current_node) have no frame_info to copy. + partial_node.add_child(LineNode(child.name)) + continue + child_partial = create_partial_node(child) + partial_node.add_child(child_partial) + if child.children: + # Add OmittedNode to represent children. + child_partial.add_child(OmittedNode()) + + return root + + def print_graph_on_exception( + self, + cause: str, + node: BaseNode, + exception_details: Optional[Any] = None, + exc_type: Optional[Any] = None, + ): + """ + Print the call graph under `node` to stderr, at most MAX_PRINTED_CALL_GRAPHS times in total. + `exception_details` and `exc_type` are optional, so callers without exception info may omit them. + """ + global _n_printed_call_graphs + if _n_printed_call_graphs >= MAX_PRINTED_CALL_GRAPHS: + return + _n_printed_call_graphs += 1 + try: + result: Optional[str] = None + if TRACE_TARGET_CONFIG: + partial_graph = self.get_partial_graph(TRACE_TARGET_CONFIG) + partial_info = f" PARTIAL='{str(TRACE_TARGET_CONFIG)}'" + if partial_graph: + result = str(partial_graph) + else: + # Hackfix: Stringify without values, if we could not target the function. + global RECORD_VALUES + RECORD_VALUES = False + result = ( + "(❌ ERROR: Could not find target function. Providing high-level call graph instead. 
❌)\n" + + str(node) + ) + RECORD_VALUES = True + else: + partial_info = "" + result = str(node) + + ln = "\n" + print(f"{ln}{ln}", file=sys.stderr) + if exception_details or exc_type: + print( + f"{ln}{str(exc_type)}{ln}{str(exception_details)}{ln}", + file=sys.stderr, + ) + print(f"", file=sys.stderr) + print(result, file=sys.stderr) + print("", file=sys.stderr) + print("", file=sys.stderr) + except Exception as err: + print(f"INTERNAL ERROR when printing EXCEPTION_EVENT: {err}", file=sys.stderr) + + if python_version >= (3, 7): + + def task_done_callback(self, task): + # Task.exception() raises CancelledError for a cancelled task, so bail out first. + if task.cancelled(): + return + exc = task.exception() + if exc: + try: + tb = exc.__traceback__ + nodes: list[CallGraphNode] = [] + while tb: + frame_info = FrameInfo.from_traceback(tb) + if not self.should_trace(frame_info.frame): + tb = tb.tb_next + continue + node = CallGraphNode(frame_info) + node.set_parameters(frame_info.get_locals()) + node.set_exception(type(exc).__name__) + nodes.append(node) + tb = tb.tb_next + creation_stack: list[CallGraphNode] = getattr(task, "_creation_stack", []) + full_stack = creation_stack + nodes + for i in range(len(full_stack) - 1): + full_stack[i].add_child(full_stack[i + 1]) + if full_stack: + # Pass the exception value and type explicitly, in the order print_graph_on_exception declares them. + self.print_graph_on_exception("FUTURE_DONE_CALLBACK", full_stack[0], exc, type(exc)) + except Exception: + if not MUTE_EXCEPTIONS: + print("\n\n\nERROR IN task_done_callback:\n\n\n", file=sys.stderr) + traceback.print_exc() + + +# ############################################################################ +# register_runtime_trace +# ############################################################################ + + +def register_runtime_trace(): + global _current_graph + _current_graph = CallGraph() + register_call_graph(_current_graph) + sys.settrace(_current_graph.trace_calls) + threading.settrace(_current_graph.trace_calls) + + if python_version >= (3, 7): + import asyncio + + class TracingEventLoopPolicy(asyncio.DefaultEventLoopPolicy): + def new_event_loop(self): + loop = super().new_event_loop() + 
loop.set_task_factory(self.tracing_task_factory) + loop.set_exception_handler(self.exception_handler) + return loop + + def tracing_task_factory(self, loop, coro): + task = asyncio.Task(coro, loop=loop) + creation_stack: list[CallGraphNode] = [] + + # Get the current stack with frame objects + stack = inspect.stack() + + # Skip the last frame (this method) + for stack_frame in stack[1:]: + if stack_frame.filename is None: + continue + abs_filename = os.path.abspath(stack_frame.filename) + if _current_graph.cwd and not abs_filename.startswith(_current_graph.cwd): + continue + + # Create FrameInfo object directly from the frame + frame_info = FrameInfo.from_frame(stack_frame.frame) + + node = CallGraphNode(frame_info) + creation_stack.append(node) + + task._creation_stack = creation_stack + task.add_done_callback(_current_graph.task_done_callback) + return task + + def exception_handler(self, loop, context): + if "exception" in context and "task" in context: + _current_graph.task_done_callback(context["task"]) + loop.default_exception_handler(context) + + asyncio.set_event_loop_policy(TracingEventLoopPolicy()) + + +_current_graph = None + +# ############################################################################ +# exception_handler +# ############################################################################ + + +def exception_handler(exc_type, exc_value, exc_traceback): + try: + if _current_graph and exc_type is not GeneratorExit: + tb = exc_traceback + nodes: list[CallGraphNode] = [] + while tb: + frame_info = FrameInfo.from_traceback(tb) + if not _current_graph.should_trace(frame_info.frame): + tb = tb.tb_next + continue + node = CallGraphNode(frame_info) + node.set_exception(exc_type.__name__) + node.set_parameters(frame_info.get_locals()) + nodes.append(node) + tb = tb.tb_next + nodes.reverse() + for i in range(len(nodes) - 1): + nodes[i].add_child(nodes[i + 1]) + if nodes: + _current_graph.print_graph_on_exception("UNCAUGHT_EXCEPTION_HANDLER", nodes[0]) 
+ except Exception: + if not MUTE_EXCEPTIONS: + print("\n\n\nERROR IN exception_handler:\n\n\n", file=sys.stderr) + traceback.print_exc() + sys.__excepthook__(exc_type, exc_value, exc_traceback) + + +sys.excepthook = exception_handler + +if __name__ == "__main__": + register_runtime_trace() diff --git a/prediction_assets/call_graph_tracer_demo.py b/prediction_assets/call_graph_tracer_demo.py new file mode 100644 index 000000000..359a76e7d --- /dev/null +++ b/prediction_assets/call_graph_tracer_demo.py @@ -0,0 +1,100 @@ +import asyncio # noqa: I002 +import importlib.util +import os +import sys + +os.environ["TDD_TRACE_TARGET_CONFIG"] = """ +{ + "target_file": "prediction_assets/call_graph_tracer_demo.py", + "target_function_name": "function_that_throws", + "decl_lineno": 48 +} +""".strip() + +# Add this directory to sys.path. +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +# Dynamically import ./call_graph_tracer.py and inject it: +module = importlib.import_module("call_graph_tracer") +module.register_runtime_trace() + + +class A: + class B: + @classmethod + def my_method(cls): + print("A.B.my_method") + +def function_a(): + print("Function A") + A().B.my_method() + function_b() + function_b2() + +def function_b(x = 2): + print("Function B") + +def function_b2(x = 2): + print("Function B2") + function_c(x) + function_that_throws(x) + +def function_c(x = 3): + print("Function C") + function_d() + +def function_d(x = 3): + print("Function D") + +def function_that_throws(x = 3): + print("Function that throws") + raise ValueError("Simulated error") + +async def async_function_1(): + print("Async function 1") + await asyncio.sleep(0.1) + function_that_throws() + +async def async_function_2(): + print("Async function 2") + await asyncio.sleep(0.2) + function_that_throws() + +def test_example(): + print("Test example") + function_a() + +async def test_async_example1(): + print("Test async example 1") + await asyncio.gather( + 
async_function_1(), + async_function_2(), + return_exceptions=True + ) + +async def test_async_example2(): + print("Test async example 2") + await asyncio.gather( + async_function_1(), + async_function_2(), + return_exceptions=True + ) + +if __name__ == "__main__": + try: + test_example() + except ValueError: + print("Caught ValueError") + + if sys.version_info >= (3, 7): + asyncio.run(test_async_example1()) + asyncio.run(test_async_example2()) + else: + # For Python 3.6, use alternative to asyncio.run() + loop = asyncio.get_event_loop() + loop.run_until_complete(test_async_example1()) + loop.run_until_complete(test_async_example2()) + loop.close() + + # Demonstrate uncaught exception + test_example() diff --git a/prediction_assets/use_tracer.py b/prediction_assets/use_tracer.py new file mode 100644 index 000000000..1b66920c1 --- /dev/null +++ b/prediction_assets/use_tracer.py @@ -0,0 +1,11 @@ +_current_call_graph = None + +def trace_line(s): + global _current_call_graph + if _current_call_graph: + _current_call_graph.append_to_current_node(f"TRACE_LINE: {s}") + + +def register_call_graph(call_graph): + global _current_call_graph + _current_call_graph = call_graph diff --git a/sweagent/agent/agents.py b/sweagent/agent/agents.py index efd562eb3..d03ac08d4 100644 --- a/sweagent/agent/agents.py +++ b/sweagent/agent/agents.py @@ -23,15 +23,15 @@ ModelQueryResult, get_last_valid_tool_use_name, get_model, - make_model_query_result, make_assistant_content, + make_model_query_result, make_user_reply_content, ) from sweagent.agent.parsing import FormatError, ParseFunction from sweagent.environment.swe_env import SWEEnv from sweagent.utils.config import convert_path_to_abspath, convert_paths_to_abspath -from sweagent.utils.log import get_logger from sweagent.utils.instrumentation import instrument +from sweagent.utils.log import get_logger logger = get_logger("agents") @@ -644,6 +644,13 @@ def forward_model(self, observation: str, state: str) -> ModelQueryResult: # Show 
instance template if prev. obs. was initial system message self.made_initial_prompt = True templates = [self.config.instance_template] + # ## [PRO-864] Dynamic analysis work + # NOTE: Uncomment this to force-feed manually constructed data into initial prompt. + # from sweagent.agent.manual_prompt_input import MANUAL_ANALYSIS_PROMPT + # templates = [ + # self.config.instance_template, + # self.config.next_step_template] + # observation = MANUAL_ANALYSIS_PROMPT if self.config.strategy_template is not None: templates.append(self.config.strategy_template) elif observation is None or observation.strip() == "": diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index 00147afc3..91b69570e 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -285,6 +285,13 @@ def _copy_repo(self) -> str: timeout_duration=LONG_TIMEOUT, ) return self._repo_name + + def _write_cached_image(self): + self.communicate("env >> /.env") + assert self.container_obj is not None # mypy + cached_image_name = self._get_cached_task_image_name() + self.container_obj.commit(cached_image_name) + self.logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image_name}") def reset(self, index: int | None = None) -> tuple[str | None, dict]: """ @@ -326,6 +333,10 @@ def reset(self, index: int | None = None) -> tuple[str | None, dict]: self.communicate("export $(xargs tuple[str | None, dict]: self.init_container_prebake() if self.args.cache_task_images: - self.communicate("env >> /.env") - assert self.container_obj is not None # mypy - self.container_obj.commit(cached_image_name) - self.logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image_name}") + self._write_cached_image() self.init_container_postbake() @@ -416,15 +424,37 @@ def init_container_prebake(self): 3. Re-run to have settings baked into the image. 
""" self.communicate(f"export REPO_ROOT=\"/{self._repo_name}\"") + self.communicate(f"export CONDA_ENV_NAME=\"{self.env_name}\"") if self.tdd: - # Provide test_cmd for the instance's repo. + # Provide tdd data to the container. + self._prepare_test_patch() install_configs = self._get_install_configs() - # self.logger.warning(f"test_cmd: {repr(install_configs['test_cmd'])}") + instance_id = self.record['instance_id'] if not install_configs["test_cmd"]: - raise RuntimeError(f"No test_cmd found in install configs for instance {self.record['instance_id']}: {repr(install_configs)}") + raise RuntimeError(f"No test_cmd found in install configs for instance {instance_id}: {repr(install_configs)}") fail_to_pass_cmd = make_fail_to_pass_test_cmd(self.record, install_configs["test_cmd"]) self.communicate_with_handling(f'export TEST_CMD_FAIL_TO_PASS="{fail_to_pass_cmd}"') + self.communicate_with_handling(f'export TDD_INSTANCE_NAME="{instance_id}"') + + # Copy prediction assets to the container + host_root = os.path.join(os.path.dirname(__file__), "../..") + host_asset_folder = os.path.abspath(os.path.join(host_root, "prediction_assets")) + container_test_folder = os.path.join("/" + self._repo_name, "tests") + self.communicate_with_handling(f"ls -l {container_test_folder}") + copy_anything_to_container( + self.container_obj, + host_asset_folder, + container_test_folder + ) + + # ## [PRO-864] Dynamic analysis work + # # TODO: generalize the file location + # # Inject test injecter: + # # (Choosing urls.py is a hack: we know that `urls.py` is one of the dynamic imports from runtests.py, so appending to it works.) 
+ # test_file_to_override = os.path.join(container_test_folder, "urls.py") + # self.communicate_with_handling(f"cd /{self._repo_name} && git checkout -- tests/urls.py") + # self.communicate_with_handling(f"echo 'from tests.prediction_assets.call_graph_tracer import register_runtime_trace; register_runtime_trace()' >> {test_file_to_override}") # pass_to_pass_cmd = make_pass_to_pass_test_cmd() @@ -437,8 +467,9 @@ def init_container_postbake(self): self.communicate_with_handling(f"source activate {self.env_name}") self.logger.debug(f"Activated container environment: {self.env_name}") if self.tdd: - # Apply test patch so the bug can be repro'ed at all. + # Apply test patch so the bug can be repro'ed. self._apply_test_patch() + def copy_string_to_container_file(self, content: str, container_file_path: str) -> None: with tempfile.NamedTemporaryFile(mode='w', delete=True) as temp_file: @@ -453,15 +484,22 @@ def copy_string_to_container_file(self, content: str, container_file_path: str) text=True ) + @property + def _container_patch_path(self): + return "/root/test.patch" + + def _prepare_test_patch(self): + patch = self.record["test_patch"] + self.copy_string_to_container_file(patch, self._container_patch_path) + + def _apply_test_patch(self): """ Apply test patch for oracle setting """ assert self.record is not None - container_patch_path = "/root/test.patch" - self.copy_string_to_container_file(self.record["test_patch"], container_patch_path) res = self.communicate_with_handling( - input=f"cd /{self._repo_name} && git apply -v {container_patch_path}", + input=f"cd /{self._repo_name} && git apply -v {self._container_patch_path}", error_msg="Failed to apply test patch correctly", ) self.logger.debug(f"[TDD] Applied test patch - output:\n{res}") diff --git a/sweagent/environment/utils.py b/sweagent/environment/utils.py index e65a48b62..c2dd98016 100644 --- a/sweagent/environment/utils.py +++ b/sweagent/environment/utils.py @@ -227,13 +227,15 @@ def ready_to_read(fd) -> 
bool: if ready_to_read(fd): try: data = os.read(fd, 4096) - except BlockingIOError: + except BlockingIOError as err: logger.error("BlockingIOError while reading from subprocess.", exc_info=True) - break + # break + raise err if data: buffer += data decoded = buffer.decode("utf-8", errors="backslashreplace") - if PROCESS_DONE_MARKER_START in decoded: + # if PROCESS_DONE_MARKER_START in decoded: + if PROCESS_DONE_REGEX.search(decoded): break time.sleep(0.01) # Prevents CPU hogging diff --git a/sweagent/investigations/instance_data.py b/sweagent/investigations/instance_data.py index 9aa711776..e4ce9e46b 100644 --- a/sweagent/investigations/instance_data.py +++ b/sweagent/investigations/instance_data.py @@ -6,38 +6,49 @@ import pandas as pd +_swe_bench_data = None + def get_swe_bench_data(): - return pd.read_parquet("hf://datasets/princeton-nlp/SWE-bench_Verified/data/test-00000-of-00001.parquet") + # Keep it cached + global _swe_bench_data + if _swe_bench_data is None: + _swe_bench_data = pd.read_parquet("hf://datasets/princeton-nlp/SWE-bench_Verified/data/test-00000-of-00001.parquet") + return _swe_bench_data + def truncate_string(text, max_length=100): - return (text[:max_length] + '...') if len(text) > max_length else text + return (text[:max_length] + "...") if len(text) > max_length else text + +def get_swe_bench_cell(instance_id: str, col: str): + return get_swe_bench_data().loc[get_swe_bench_data()["instance_id"] == instance_id, col].iloc[0] def get_swe_bench_instance_markdown(instance_id: str): # Get the DataFrame df = get_swe_bench_data() - + # Select the specific row - specific_row = df[df['instance_id'] == instance_id] - + specific_row = df[df["instance_id"] == instance_id] + if specific_row.empty: return "No data found for the given instance_id." 
- + # Truncation - if 'PASS_TO_PASS' in specific_row.columns: - specific_row.loc[:, 'PASS_TO_PASS'] = specific_row['PASS_TO_PASS'].apply(truncate_string) - + if "PASS_TO_PASS" in specific_row.columns: + specific_row.loc[:, "PASS_TO_PASS"] = specific_row["PASS_TO_PASS"].apply(truncate_string) + # Transpose the row transposed = specific_row.transpose() - + # Reset the index to turn the column names into a regular column transposed = transposed.reset_index() - + # Rename the columns - transposed.columns = ['Field', 'Value'] - + transposed.columns = ["Field", "Value"] + # Convert to Markdown return transposed.to_markdown(index=False) + def generate_cached_image_id(instance_id: str, environment_setup: str = "no_setup") -> str: cached_image_prefix = "swe-agent-task-env-" @@ -76,4 +87,3 @@ def generate_cached_image_id(instance_id: str, environment_setup: str = "no_setu result = generate_cached_image_id(instance_id, environment_setup) print(result) - diff --git a/sweagent/investigations/summarize_instance.py b/sweagent/investigations/summarize_instance.py index f7d7525b0..1aef74385 100755 --- a/sweagent/investigations/summarize_instance.py +++ b/sweagent/investigations/summarize_instance.py @@ -5,7 +5,7 @@ import sys from argparse import ArgumentParser -from sweagent.investigations.instance_data import get_swe_bench_instance_markdown +from sweagent.investigations.instance_data import get_swe_bench_cell, get_swe_bench_instance_markdown from sweagent.investigations.run_logs_sync import RunLogsSync investigation_data_folder_name = "investigation-data" @@ -60,7 +60,7 @@ def summarize_instance(instance_id: str): run_data.append(f""" ### {run_name} -* [PR Link]({make_bug_href(instance_id)}) +* [Golden Patch Link]({make_bug_href(instance_id)}) * Prediction * [Run Log]({make_relative_path(prediction_run_logs)}) * [Trajectory json]({make_relative_path(prediction_trajectories)}) @@ -83,6 +83,14 @@ def summarize_instance(instance_id: str): 
{get_swe_bench_instance_markdown(instance_id)} +### test_patch + +Copy-and-pasteable version of the test patch: + +```patch +{get_swe_bench_cell(instance_id, "test_patch")} +``` + """.strip() # * {f"[Evaluation Results Folder]({eval_folder_href})" if eval_folder_href else "(no evaluation found)"} f.write(contents)