From 3a9866d0a245fe4188872bcff53358f489de5776 Mon Sep 17 00:00:00 2001 From: Chris Toshok Date: Tue, 17 Sep 2024 12:40:21 -0700 Subject: [PATCH 1/6] fix some typing issues, and fix json serialization/deserialization of anthropic results <-> the cache --- sweagent/agent/model_cache.py | 90 +++++++++++++++++++++++++++++----- sweagent/agent/model_result.py | 21 ++++++++ sweagent/agent/models.py | 23 +-------- 3 files changed, 101 insertions(+), 33 deletions(-) create mode 100644 sweagent/agent/model_result.py diff --git a/sweagent/agent/model_cache.py b/sweagent/agent/model_cache.py index c9d9e06e4..0d6c6ad65 100644 --- a/sweagent/agent/model_cache.py +++ b/sweagent/agent/model_cache.py @@ -1,26 +1,86 @@ - import json import hashlib import os +from anthropic.types import ContentBlock, TextBlock, ToolUseBlock + from sweagent.utils.log import get_logger +from sweagent.agent.model_result import AnthropicModelResult, ModelQueryResult ModelCacheEnvVar = "MODEL_CACHE_DIRECTORY" logger = get_logger("model_cache") + +class CacheEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, AnthropicModelResult): + return { + "type": "anthropic_model_result", + "blocks": [ + encode_anthropic_content_block(block) for block in obj.blocks + ], + } + + if isinstance(obj, TextBlock) or isinstance(obj, ToolUseBlock): + return encode_anthropic_content_block(obj) + + return super().default(obj) + + +def encode_anthropic_content_block(block: ContentBlock) -> dict[str, str]: + # it should be possible to use the Anthropic library to encode these (using BaseModel#to_dict), + # but there doesn't seem to be a way to reconstruct instances from those dicts directly. Given + # that it seems I have to write the latter, I'd like to keep code for both close to the other to + # make it easier to keep them in sync. + if isinstance(block, TextBlock): + return { + "type": "text", + "text": block.text, + } + if isinstance(block, ToolUseBlock): + return { + "type": "tool_use", + "id": block.id, + "name": block.name, + "input": block.input, + } + + +def cache_decoder(dct: dict[str, str]): + if "type" not in dct: + return dct + + if dct["type"] == "anthropic_model_result": + return AnthropicModelResult(blocks=dct["blocks"]) + + if dct["type"] == "text": + return TextBlock(type="text", text=dct["text"]) + + if dct["type"] == "tool_use": + return ToolUseBlock( + type="tool_use", + id=dct["id"], + name=dct["name"], + input=dct["input"], + ) + + class ModelCache: def __init__(self): self.directory = None if ModelCacheEnvVar in os.environ: + logger.warning("⚠ ModelCache is enabled") self.directory = os.environ[ModelCacheEnvVar] def _get_file(self, history: list[dict[str, str]]) -> str: - hash_input = str(history) - hash_object = hashlib.sha256(hash_input.encode('utf-8')) - return f"{self.directory}/model-query-{hash_object.hexdigest()}.json" + hash_input = str(history) + hash_object = hashlib.sha256(hash_input.encode("utf-8")) + return f"{self.directory}/model-query-{hash_object.hexdigest()}.json" - def query(self, history: list[dict[str, str]]) -> tuple[str, list[dict[str, str]]] | None: + def query( + self, history: list[dict[str, str]] + ) -> tuple[ModelQueryResult, list[dict[str, str]]] | None: if self.directory is None: return None file = self._get_file(history) @@ -28,14 +88,22 @@ def query(self, history: list[dict[str, str]]) -> tuple[str, list[dict[str, str] logger.info(f"ModelCacheMiss file={file}") return None logger.info(f"ModelCacheHit file={file}") - file_handle = open(file, 'r') - entries = json.load(file_handle) - return entries[1], entries[2] + file_handle = open(file, "r") + [_, model_result, stats_calls] = json.load( + file_handle, object_hook=cache_decoder + ) + return model_result, stats_calls - def insert(self, history: list[dict[str, str]], result_string: str, stats_calls: list[dict[str,str]]): + def insert( + self, + history: list[dict[str, str]], + model_result: ModelQueryResult, + stats_calls: list[dict[str, str]], + ): if self.directory is None: return file = self._get_file(history) logger.info(f"ModelCacheInsert file={file}") - file_handle = open(file, 'w') - json.dump([history, result_string, stats_calls], file_handle) + + file_handle = open(file, "w") + json.dump([history, model_result, stats_calls], file_handle, cls=CacheEncoder) diff --git a/sweagent/agent/model_result.py b/sweagent/agent/model_result.py new file mode 100644 index 000000000..5b59b5ba0 --- /dev/null +++ b/sweagent/agent/model_result.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from anthropic.types import ContentBlock + +@dataclass +class AnthropicModelResult: + blocks: list[ContentBlock] + + def get_tool_uses(self): + return [block for block in self.blocks if block.type == "tool_use"] + + def get_last_tool_use(self): + return next(reversed(self.get_tool_uses()), None) + + def __init__(self, blocks): + self.blocks = blocks + + def __repr__(self) -> str: + return f"AnthropicModelResult(blocks={repr(self.blocks)})" + +ModelQueryResult = str | AnthropicModelResult diff --git a/sweagent/agent/models.py b/sweagent/agent/models.py index eabb0bbfe..7ddfe3911 100644 --- a/sweagent/agent/models.py +++ b/sweagent/agent/models.py @@ -23,6 +23,7 @@ from sweagent.agent.commands import Command from sweagent.agent.model_cache import ModelCache +from sweagent.agent.model_result import AnthropicModelResult, ModelQueryResult from sweagent.utils.config import keys_config from sweagent.utils.log import get_logger @@ -32,28 +33,6 @@ _MAX_RETRIES = keys_config.get("SWE_AGENT_MODEL_MAX_RETRIES", 0) -@dataclass -class AnthropicModelResult(dict): - blocks: list[ContentBlock] - - def get_tool_uses(self): - return [block for block in self.blocks if block.type == "tool_use"] - - def get_last_tool_use(self): - return next(reversed(self.get_tool_uses()), None) - - def __init__(self, blocks): - # Inherit from dict to make it JSON-serializable. - dict.__init__(self, blocks=blocks) - self.blocks = blocks - - def __repr__(self) -> str: - return f"AnthropicModelResult(blocks={repr(self.blocks)})" - - - -ModelQueryResult = str | AnthropicModelResult - def make_assistant_content(output: ModelQueryResult): if isinstance(output, str): return output From 28dd407b1bdc730defa918005b93701b72950dcf Mon Sep 17 00:00:00 2001 From: Chris Toshok Date: Tue, 17 Sep 2024 13:29:58 -0700 Subject: [PATCH 2/6] might as well make these all methods on AnthropicModelResult instead of doing some inline --- sweagent/agent/model_result.py | 10 ++++++++-- sweagent/agent/parsing.py | 6 +++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sweagent/agent/model_result.py b/sweagent/agent/model_result.py index 5b59b5ba0..5c132ded1 100644 --- a/sweagent/agent/model_result.py +++ b/sweagent/agent/model_result.py @@ -6,11 +6,17 @@ class AnthropicModelResult: blocks: list[ContentBlock] - def get_tool_uses(self): + def get_tool_use_blocks(self): return [block for block in self.blocks if block.type == "tool_use"] def get_last_tool_use(self): - return next(reversed(self.get_tool_uses()), None) + return next(reversed(self.get_tool_use_blocks()), None) + + def get_text_blocks(self): + return [block for block in self.blocks if block.type == "text"] + + def get_non_content_blocks(self): + return [block for block in self.blocks if block.type not in ["tool_use", "text"]] def __init__(self, blocks): self.blocks = blocks diff --git a/sweagent/agent/parsing.py b/sweagent/agent/parsing.py index 6b280c449..cf4995166 100644 --- a/sweagent/agent/parsing.py +++ b/sweagent/agent/parsing.py @@ -176,14 +176,14 @@ def __call__(self, model_response: ModelQueryResult, commands: list[Command], st msg = f"{model_response.__class__.__name__}: model_response must be AnthropicModelResult. Can only work with Anthropic models. Found instead: {repr(model_response)}" raise TypeError(msg) - tool_blocks = list(model_response.get_tool_uses()) - texts = [block.text for block in model_response.blocks if block.type == "text"] + tool_blocks = model_response.get_tool_use_blocks() + texts = [block.text for block in model_response.get_text_blocks()] if len(tool_blocks) != 1: msg = "Exactly one tool_use block must be present in the model response." raise FormatError(msg) - other_blocks = [block for block in model_response.blocks if block.type not in ["tool_use", "text"]] + other_blocks = model_response.get_non_content_blocks() if other_blocks: msg = f"NYI: Found {len(other_blocks)} unknown blocks in model response. Only tool_use and text blocks are supported: {repr(model_response.blocks)}" raise FormatError(msg) From 09d4e444d6e841288e79778ef73abb8164532905 Mon Sep 17 00:00:00 2001 From: Chris Toshok Date: Wed, 18 Sep 2024 12:38:12 -0700 Subject: [PATCH 3/6] weird. E504 -> RET504 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index eaf9f6172..41aa7cbbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,12 +170,12 @@ select = [ ignore = [ # flake8-return + "RET504", # Unnecessary assignment * before `return` statement "RET505", # can't autofix "RET506", # can't autofix "RET507", # can't autofix # error (E) "E501", # line too long - "E504", # Unnecessary assignment * before `return` statement "E402", # import not on top of file "E722", # bare except "E741", # ambiguous symbol From 0a1be02db8e2e7f81abb25546c65f0c808647a96 Mon Sep 17 00:00:00 2001 From: Chris Toshok Date: Wed, 18 Sep 2024 12:40:08 -0700 Subject: [PATCH 4/6] label our new deps, and add pyright --- requirements-dev.lock | 4 ++++ requirements.lock | 4 ++++ requirements.txt | 3 ++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/requirements-dev.lock b/requirements-dev.lock index f069efe65..a102214e4 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -110,6 +110,7 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.6.1 # via datasets + # via fsspec # via huggingface-hub ghapi==1.0.6 # via swebench @@ -187,6 +188,7 @@ multiprocess==0.70.16 # via datasets nodeenv==1.9.1 # via pre-commit + # via pyright numpy==2.1.1 # via datasets # via gymnasium @@ -257,6 +259,8 @@ pygments==2.18.0 # via rich pyparsing==3.1.4 # via httplib2 +pyright==1.1.380 + # via sweagent python-dateutil==2.9.0.post0 # via botocore # via pandas diff --git a/requirements.lock b/requirements.lock index f069efe65..a102214e4 100644 --- a/requirements.lock +++ b/requirements.lock @@ -110,6 +110,7 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.6.1 # via datasets + # via fsspec # via huggingface-hub ghapi==1.0.6 # via swebench @@ -187,6 +188,7 @@ multiprocess==0.70.16 # via datasets nodeenv==1.9.1 # via pre-commit + # via pyright numpy==2.1.1 # via datasets # via gymnasium @@ -257,6 +259,8 @@ pygments==2.18.0 # via rich pyparsing==3.1.4 # via httplib2 +pyright==1.1.380 + # via sweagent python-dateutil==2.9.0.post0 # via botocore # via pandas diff --git a/requirements.txt b/requirements.txt index 7a78ee61d..e80b6eb74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,11 +22,12 @@ flask-socketio groq pandas -# Some of our own dependencies +# Replay: Some of our own dependencies google-api-python-client google-auth-httplib2 google-auth-oauthlib python-dotenv +pyright tqdm opentelemetry-api>=1.25.0 opentelemetry-exporter-otlp-proto-http>=1.25.0 From fdbe9391b1f9e7925938732b248ba4131acf5529 Mon Sep 17 00:00:00 2001 From: Chris Toshok Date: Wed, 18 Sep 2024 12:43:02 -0700 Subject: [PATCH 5/6] add support for json serializing/deserializing the anthropic types we use, use json serialized form to generate the hash, and add support for raising an error if we get a cache miss (to verify that things are 100% reproducible) --- sweagent/agent/model_cache.py | 112 ++++++++++++++++++++++++++-------- sweagent/agent/models.py | 4 ++ 2 files changed, 91 insertions(+), 25 deletions(-) diff --git a/sweagent/agent/model_cache.py b/sweagent/agent/model_cache.py index 0d6c6ad65..9a0661def 100644 --- a/sweagent/agent/model_cache.py +++ b/sweagent/agent/model_cache.py @@ -1,8 +1,11 @@ +from __future__ import annotations import json import hashlib import os +import copy +from typing import Any -from anthropic.types import ContentBlock, TextBlock, ToolUseBlock +from anthropic.types import ContentBlock, TextBlock, ToolUseBlock, ToolResultBlockParam from sweagent.utils.log import get_logger from sweagent.agent.model_result import AnthropicModelResult, ModelQueryResult @@ -13,41 +16,51 @@ class CacheEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, AnthropicModelResult): + def default(self, o): + if isinstance(o, AnthropicModelResult): return { "type": "anthropic_model_result", - "blocks": [ - encode_anthropic_content_block(block) for block in obj.blocks - ], + "blocks": [encode_anthropic_types(block) for block in o.blocks], } - if isinstance(obj, TextBlock) or isinstance(obj, ToolUseBlock): - return encode_anthropic_content_block(obj) + encoded = encode_anthropic_types(o) + if encoded is not None: + return encoded - return super().default(obj) + return super().default(o) -def encode_anthropic_content_block(block: ContentBlock) -> dict[str, str]: +def encode_anthropic_types(obj) -> dict[str, Any] | None: # it should be possible to use the Anthropic library to encode these (using BaseModel#to_dict), # but there doesn't seem to be a way to reconstruct instances from those dicts directly. Given # that it seems I have to write the latter, I'd like to keep code for both close to the other to # make it easier to keep them in sync. - if isinstance(block, TextBlock): + if isinstance(obj, TextBlock): return { "type": "text", - "text": block.text, + "text": obj.text, } - if isinstance(block, ToolUseBlock): + + if isinstance(obj, ToolUseBlock): return { "type": "tool_use", - "id": block.id, - "name": block.name, - "input": block.input, + "id": obj.id, + "name": obj.name, + "input": obj.input, } + if isinstance(obj, dict) and "type" in obj and obj["type"] == "tool_result": + return { + "type": "tool_result", + "tool_use_id": obj["id"], + "is_error": obj["is_error"], + "content": [encode_anthropic_types(c) for c in obj["content"]], + } + + return None + -def cache_decoder(dct: dict[str, str]): +def cache_decoder(dct: dict[str, Any]): if "type" not in dct: return dct @@ -65,6 +78,55 @@ def cache_decoder(dct: dict[str, str]): input=dct["input"], ) + if dct["type"] == "tool_result": + return ToolResultBlockParam( + type="tool_result", + tool_use_id=dct["tool_use_id"], + is_error=dct["is_error"], + content=dct["content"], + ) + + +def normalize_tool_use_ids(history: list[dict[str, Any]]) -> list[dict[str, Any]]: + # grovel around in the history and find all tool_result blocks. for all those tool_use_ids + # generate a dictionary from id -> int (starting at 1). then go through the history and + # replace all tool_use_id's/id's with the corresponding integer if the id is a key in the dict. + + mapping = {} + for entry in history: + if not isinstance(entry["content"], list): + continue + for c in entry["content"]: + if not isinstance(c, dict) or "tool_use_id" not in c: + continue + id = c["tool_use_id"] + mapping[id] = len(mapping) + 1 + + if len(mapping) == 0: + return history + + normalized = copy.deepcopy(history) + + logger.warn("Normalizing tool use ids") + for entry in normalized: + if not isinstance(entry["content"], list): + continue + for c in entry["content"]: + if isinstance(c, ToolUseBlock): + if c.id in mapping: + mapped = mapping[c.id] + c.id = f"toolu_normalized_{mapped}" + continue + + if "tool_use_id" in c: + id = c["tool_use_id"] + if id in mapping: + mapped = mapping[id] + c["tool_use_id"] = f"toolu_normalized_{mapped}" + continue + + return normalized + class ModelCache: def __init__(self): @@ -74,13 +136,12 @@ def __init__(self): self.directory = os.environ[ModelCacheEnvVar] def _get_file(self, history: list[dict[str, str]]) -> str: - hash_input = str(history) + hash_input = json.dumps(history, cls=CacheEncoder) + logger.warn(f"HASH_INPUT\n{hash_input}\nEND_OF_HASH_INPUT") hash_object = hashlib.sha256(hash_input.encode("utf-8")) return f"{self.directory}/model-query-{hash_object.hexdigest()}.json" - def query( - self, history: list[dict[str, str]] - ) -> tuple[ModelQueryResult, list[dict[str, str]]] | None: + def query(self, history: list[dict[str, str]]) -> tuple[ModelQueryResult, list[dict[str, str]]] | None: if self.directory is None: return None file = self._get_file(history) @@ -89,9 +150,7 @@ def query( return None logger.info(f"ModelCacheHit file={file}") file_handle = open(file, "r") - [_, model_result, stats_calls] = json.load( - file_handle, object_hook=cache_decoder - ) + [_, model_result, stats_calls] = json.load(file_handle, object_hook=cache_decoder) return model_result, stats_calls def insert( @@ -102,7 +161,10 @@ def insert( ): if self.directory is None: return - file = self._get_file(history) + + normalized_history = normalize_tool_use_ids(history) + + file = self._get_file(normalized_history) logger.info(f"ModelCacheInsert file={file}") file_handle = open(file, "w") diff --git a/sweagent/agent/models.py b/sweagent/agent/models.py index 7ddfe3911..736fd9314 100644 --- a/sweagent/agent/models.py +++ b/sweagent/agent/models.py @@ -155,6 +155,8 @@ def __init__(self, args: ModelArguments, commands: list[Command]): self.model_metadata = {} self.stats = APIStats() self.cache = ModelCache() + # set this to true to raise an exception on cache misses + self.cache_only = False # Map `model_name` to API-compatible name `api_model` self.api_model = ( @@ -258,6 +260,8 @@ def query(self, history: list[dict[str, str]]) -> ModelQueryResult: for call in stats_calls: self.update_stats(call["input_tokens"], call["output_tokens"]) else: + if self.cache_only: + raise Exception("ModelCache miss") result_string = self._query_raw(history) self.cache.insert(history, result_string, self.update_stats_calls) self.update_stats_calls = None From fc201eb9274327b48c8de9143f894810c7316a6b Mon Sep 17 00:00:00 2001 From: Chris Toshok Date: Mon, 23 Sep 2024 12:51:48 -0700 Subject: [PATCH 6/6] might as well get this stuff commited even if we don't use it. factor out the important cache serialization/deserialization functions so we can write tests for them. they're runnable with 'rye run replayio-tests' --- pyproject.toml | 9 ++++++ requirements-dev.lock | 9 ++++++ sweagent/agent/model_cache.py | 60 ++++++++++++++++++++--------------- sweagent/agent/models.py | 5 +-- tests/test_model_cache.py | 35 ++++++++++++++++++++ 5 files changed, 91 insertions(+), 27 deletions(-) create mode 100644 tests/test_model_cache.py diff --git a/pyproject.toml b/pyproject.toml index 41aa7cbbd..418111033 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,10 +65,16 @@ namespaces = false "Documentation" = "https://github.com/princeton-nlp/SWE-agent" "Source" = "http://github.com/princeton-nlp/SWE-agent" +[tool.rye] +dev-dependencies = [ + "pytest>=8.3.3", +] + [tool.pytest.ini_options] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "replayio: marks tests added as part of replay.io work (deselect with '-m \"not replayio\"')", ] testpaths = [ "tests" @@ -216,3 +222,6 @@ aci = "aci" [tool.ruff.lint.isort] required-imports = ["from __future__ import annotations"] + +[tool.rye.scripts] +replayio-tests = "pytest -m replayio" \ No newline at end of file diff --git a/requirements-dev.lock b/requirements-dev.lock index a102214e4..2de466120 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -88,6 +88,7 @@ eval-type-backport==0.2.0 # via together exceptiongroup==1.2.2 # via anyio + # via pytest farama-notifications==0.0.4 # via gymnasium fastcore==1.7.5 @@ -164,6 +165,8 @@ idna==3.8 # via yarl importlib-metadata==8.4.0 # via opentelemetry-api +iniconfig==2.0.0 + # via pytest itsdangerous==2.2.0 # via flask jinja2==3.1.4 @@ -224,6 +227,7 @@ packaging==24.1 # via fastcore # via ghapi # via huggingface-hub + # via pytest pandas==2.2.2 # via datasets # via sweagent @@ -231,6 +235,8 @@ pillow==10.4.0 # via together platformdirs==4.3.2 # via virtualenv +pluggy==1.5.0 + # via pytest pre-commit==3.8.0 # via swebench proto-plus==1.24.0 @@ -261,6 +267,7 @@ pyparsing==3.1.4 # via httplib2 pyright==1.1.380 # via sweagent +pytest==8.3.3 python-dateutil==2.9.0.post0 # via botocore # via pandas @@ -331,6 +338,8 @@ together==1.2.12 # via sweagent tokenizers==0.20.0 # via anthropic +tomli==2.0.1 + # via pytest tqdm==4.66.5 # via datasets # via huggingface-hub diff --git a/sweagent/agent/model_cache.py b/sweagent/agent/model_cache.py index 9a0661def..9d8d859b9 100644 --- a/sweagent/agent/model_cache.py +++ b/sweagent/agent/model_cache.py @@ -5,7 +5,7 @@ import copy from typing import Any -from anthropic.types import ContentBlock, TextBlock, ToolUseBlock, ToolResultBlockParam +from anthropic.types import TextBlock, ToolUseBlock, ToolResultBlockParam from sweagent.utils.log import get_logger from sweagent.agent.model_result import AnthropicModelResult, ModelQueryResult @@ -50,13 +50,16 @@ def encode_anthropic_types(obj) -> dict[str, Any] | None: } if isinstance(obj, dict) and "type" in obj and obj["type"] == "tool_result": - return { + result = { "type": "tool_result", "tool_use_id": obj["id"], - "is_error": obj["is_error"], - "content": [encode_anthropic_types(c) for c in obj["content"]], } - + if "is_error" in obj: + result["is_error"] = obj["is_error"] + if "content" in obj: + result["content"] = [encode_anthropic_types(c) for c in obj["content"]] + return result + return None @@ -79,28 +82,19 @@ def cache_decoder(dct: dict[str, Any]): ) if dct["type"] == "tool_result": - return ToolResultBlockParam( - type="tool_result", - tool_use_id=dct["tool_use_id"], - is_error=dct["is_error"], - content=dct["content"], - ) + return dct + raise ValueError(f"Unknown type {dct['type']} in cache_decoder") def normalize_tool_use_ids(history: list[dict[str, Any]]) -> list[dict[str, Any]]: - # grovel around in the history and find all tool_result blocks. for all those tool_use_ids - # generate a dictionary from id -> int (starting at 1). then go through the history and - # replace all tool_use_id's/id's with the corresponding integer if the id is a key in the dict. - mapping = {} for entry in history: if not isinstance(entry["content"], list): continue for c in entry["content"]: - if not isinstance(c, dict) or "tool_use_id" not in c: + if not isinstance(c, ToolUseBlock): continue - id = c["tool_use_id"] - mapping[id] = len(mapping) + 1 + mapping[c.id] = len(mapping) + 1 if len(mapping) == 0: return history @@ -127,6 +121,21 @@ def normalize_tool_use_ids(history: list[dict[str, Any]]) -> list[dict[str, Any] return normalized +def hash_string(s: str) -> str: + hash_object = hashlib.sha256(s.encode("utf-8")) + return hash_object.hexdigest() + +def json_serialize_str(obj: Any, **kwargs) -> str: + return json.dumps(obj, **kwargs, cls=CacheEncoder) + +def json_serialize_file(obj: Any, fp: Any, **kwargs): # SupportsWrite[str] on fp here + json.dump(obj, fp, **kwargs, cls=CacheEncoder) + +def json_deserialize_str(s: str, **kwargs) -> Any: + return json.loads(s, **kwargs, object_hook=cache_decoder) + +def json_deserialize_file(fp: Any, **kwargs) -> Any: # SupportsRead[str] on fp here + return json.load(fp, **kwargs, object_hook=cache_decoder) class ModelCache: def __init__(self): @@ -136,21 +145,22 @@ def __init__(self): self.directory = os.environ[ModelCacheEnvVar] def _get_file(self, history: list[dict[str, str]]) -> str: - hash_input = json.dumps(history, cls=CacheEncoder) - logger.warn(f"HASH_INPUT\n{hash_input}\nEND_OF_HASH_INPUT") - hash_object = hashlib.sha256(hash_input.encode("utf-8")) - return f"{self.directory}/model-query-{hash_object.hexdigest()}.json" + hash_input = json_serialize_str(history) + print(f"HASH_INPUT\n{hash_input}\nEND_OF_HASH_INPUT") + hash = hash_string(hash_input) + return f"{self.directory}/model-query-{hash}.json" def query(self, history: list[dict[str, str]]) -> tuple[ModelQueryResult, list[dict[str, str]]] | None: if self.directory is None: return None - file = self._get_file(history) + normalized_history = normalize_tool_use_ids(history) + file = self._get_file(normalized_history) if not os.path.exists(file): logger.info(f"ModelCacheMiss file={file}") return None logger.info(f"ModelCacheHit file={file}") file_handle = open(file, "r") - [_, model_result, stats_calls] = json.load(file_handle, object_hook=cache_decoder) + [_, model_result, stats_calls] = json_deserialize_file(file_handle) return model_result, stats_calls def insert( @@ -168,4 +178,4 @@ def insert( logger.info(f"ModelCacheInsert file={file}") file_handle = open(file, "w") - json.dump([history, model_result, stats_calls], file_handle, cls=CacheEncoder) + json_serialize_file([history, model_result, stats_calls], file_handle) diff --git a/sweagent/agent/models.py b/sweagent/agent/models.py index 736fd9314..f677c5604 100644 --- a/sweagent/agent/models.py +++ b/sweagent/agent/models.py @@ -22,7 +22,7 @@ ) from sweagent.agent.commands import Command -from sweagent.agent.model_cache import ModelCache +from sweagent.agent.model_cache import ModelCache, json_serialize_str from sweagent.agent.model_result import AnthropicModelResult, ModelQueryResult from sweagent.utils.config import keys_config from sweagent.utils.log import get_logger @@ -86,7 +86,7 @@ def compress_history_entry(input_entry: any): elif isinstance(content, list): for b in content: cont = b["content"] - cont = cont if isinstance(cont, str) else json.dumps(cont, indent=2) + cont = cont if isinstance(cont, str) else json_serialize_str(cont, indent=2) b["content"] = f'(omitted {len(cont.splitlines())} lines)' @@ -156,6 +156,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]): self.stats = APIStats() self.cache = ModelCache() # set this to true to raise an exception on cache misses + # this doesn't work since there is so much non-determinism within our command execution. self.cache_only = False # Map `model_name` to API-compatible name `api_model` diff --git a/tests/test_model_cache.py b/tests/test_model_cache.py new file mode 100644 index 000000000..f0ccc2666 --- /dev/null +++ b/tests/test_model_cache.py @@ -0,0 +1,35 @@ +from __future__ import annotations +import pytest + + +from anthropic.types import TextBlock, ToolUseBlock, ToolResultBlockParam +from sweagent.agent.model_cache import json_serialize_str, json_deserialize_str, json_deserialize_file, hash_string +from sweagent.agent.model_result import AnthropicModelResult + +@pytest.mark.replayio +def test_json_serialize_str(): + assert json_serialize_str(1) == "1" + assert json_serialize_str("hello") == '"hello"' + assert json_serialize_str([1, 2, 3]) == "[1, 2, 3]" + assert json_serialize_str({"a": 1, "b": 2}) == '{"a": 1, "b": 2}' + +@pytest.mark.replayio +def test_json_serialize_str_anthropic_types(): + assert json_serialize_str(TextBlock(type="text", text="hello")) == '{"type": "text", "text": "hello"}' + assert json_serialize_str(ToolUseBlock(type="tool_use", id="123", name="tool", input="input")) == '{"type": "tool_use", "id": "123", "name": "tool", "input": "input"}' + assert json_serialize_str({ "type": "tool_result", "tool_use_id": "123", "content": [TextBlock(type="text", text="hello")]}) == '{"type": "tool_result", "tool_use_id": "123", "content": [{"type": "text", "text": "hello"}]}' + assert json_serialize_str(AnthropicModelResult(blocks=[TextBlock(type="text", text="hello")])) == '{"type": "anthropic_model_result", "blocks": [{"type": "text", "text": "hello"}]}' + + +# this is the history from a local file with hash = b5f404f6b8e2ae3c709c7fbfc05e60b7a56863db4419ddfe0addea7e90bc3221 + +input = ''' +[{"role": "system", "content": "# BACKGROUND\\nYou are an autonomous programmer, and you're working directly in the command line with a special interface.\\nThe special interface consists of a file editor that shows you 100 lines of a file at a time.\\nIn addition to typical bash commands, you can also use the given commands (tools) to help you navigate and edit files.\\n\\n# RESPONSE FORMAT\\nAt every iteration, you should only call a *SINGLE* tool. Don't issue two tool calls at once.\\n\\n## Requirements\\n\\n* You are provided `tdd_*` tools to reproduce the issue with one or more golden tests and also check for regressions.\\n* The output of the last reproduction is always provided to you.\\n* Always start your investigations from the failing test.\\n* When deciding to investigate some part of the code, EXPLAIN CLEARLY why you think that this is the next step to take.\\n* Don't submit until the reproduction command proves your fix.\\n", "agent": "primary"}, {"role": "user", "content": "# ISSUE\\nA user reported the following issue:\\n\\nAdd secure default SECURE_REFERRER_POLICY / Referrer-policy header\\nDescription\\n\\t\\n#29406 added the ability for the SECURE_REFERRER_POLICY setting to set Referrer-Policy, released in Django 3.0.\\nI propose we change the default for this to \\"same-origin\\" to make Django applications leak less information to third party sites.\\nThe main risk of breakage here would be linked websites breaking, if they depend on verification through the Referer header. This is a pretty fragile technique since it can be spoofed.\\nDocumentation: \\u200bhttps://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy\\nThe MDN support grid is out of date: \\u200bhttps://caniuse.com/#search=Referrer-Policy\\n\\n\\n\\n# ISSUE REPRODUCTION RESULTS\\n\\nRunning tests to reproduce the bug (from /django__django):\\n >./tests/runtests.py --verbosity 2 --settings=test_sqlite --parallel 1 -- project_template.test_settings.TestStartProjectSettings.test_middleware_headers\\n\\ntest_middleware_headers (project_template.test_settings.TestStartProjectSettings) ... FAIL\\n\\n======================================================================\\nFAIL: test_middleware_headers (project_template.test_settings.TestStartProjectSettings)\\n----------------------------------------------------------------------\\nTraceback (most recent call last):\\n File \\"/django__django/tests/project_template/test_settings.py\\", line 43, in test_middleware_headers\\n b'X-Frame-Options: DENY',\\nAssertionError: Lists differ: [b'Co[58 chars]', b'X-Content-Type-Options: nosniff', b'X-Fra[13 chars]ENY'] != [b'Co[58 chars]', b'Referrer-Policy: same-origin', b'X-Conten[46 chars]ENY']\\n\\nFirst differing element 2:\\nb'X-Content-Type-Options: nosniff'\\nb'Referrer-Policy: same-origin'\\n\\nSecond list contains 1 additional elements.\\nFirst extra element 4:\\nb'X-Frame-Options: DENY'\\n\\n [b'Content-Length: 0',\\n b'Content-Type: text/html; charset=utf-8',\\n+ b'Referrer-Policy: same-origin',\\n b'X-Content-Type-Options: nosniff',\\n b'X-Frame-Options: DENY']\\n\\n----------------------------------------------------------------------\\nRan 1 test in 0.005s\\n\\nFAILED (failures=1)\\nTesting against Django installed in '/django__django/django'\\nImporting application project_template\\nSkipping setup of unused database(s): default, other.\\nSystem check identified no issues (0 silenced).\\n\\n\\n\\n# INSTRUCTIONS\\n* Solve this issue on your own. Your terminal session has started and you're in the repository's root directory. Edit files and run any checks or tests as needed.\\n* YOU CAN ONLY MAKE ONE TOOL CALL (RUN A COMMAND) AT A TIME. You should always wait for feedback after every command.\\n* Once you have verified that you have solved the problem (or unable to make progress), you must call the final `submit` command.\\n* Don't use any interactive session commands (e.g. python, vim) in this environment.\\n* You can write scripts and run them. E.g. you can write a python script and then run it with `python .py`.\\n* WARNING: Indentation matters! When editing a file, make sure to insert appropriate indentation before each line!\\n\\n# PROBLEM SOLVING STRATEGY\\n1. If you open a file and need to get to an area around a specific line that is not in the first 100 lines, say line 583, don't just use the scroll_down command multiple times. Instead, use the goto 583 command. It's much quicker.\\n2. Always make sure to look at the current open file and the CWD. The currently open file might be in a different directory than the working directory! Note that some commands, such as 'create', open files, so they might change the current open file.\\n3. CWD is the directory of the repo you are supposed to edit. Only modify files inside this directory. Always provide absolute file paths, prefixed with $PWD.\\n4. When editing files, it is easy to accidentally specify a wrong line number or to write code with incorrect indentation. Always check the code after you issue an edit to make sure that it reflects what you wanted to accomplish. If it didn't, issue another command to fix it.\\n\\n\\n(Open file: n/a)\\n(Current directory: /django__django)\\nbash-$", "agent": "primary", "tdd": true}, {"role": "assistant", "content": [{"type": "text", "text": "To start investigating this issue, I'll first examine the test file that's failing. The error message indicates that the failing test is in the file \\"tests/project_template/test_settings.py\\". Let's open this file and look at the relevant test."}, {"type": "tool_use", "id": "toolu_012meYzV4qR1vJd9LMMXK1kK", "name": "open", "input": {"path": "/django__django/tests/project_template/test_settings.py"}}], "thought": "To start investigating this issue, I'll first examine the test file that's failing. The error message indicates that the failing test is in the file \\"tests/project_template/test_settings.py\\". Let's open this file and look at the relevant test.", "action": "open \\"/django__django/tests/project_template/test_settings.py\\"", "agent": "primary", "tdd": false}, {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "toolu_012meYzV4qR1vJd9LMMXK1kK", "content": "[File: /django__django/tests/project_template/test_settings.py (44 lines total)]\\n1:import os\\n2:import shutil\\n3:import tempfile\\n4:\\n5:from django import conf\\n6:from django.test import SimpleTestCase\\n7:from django.test.utils import extend_sys_path\\n8:\\n9:\\n10:class TestStartProjectSettings(SimpleTestCase):\\n11: def setUp(self):\\n12: self.temp_dir = tempfile.TemporaryDirectory()\\n13: self.addCleanup(self.temp_dir.cleanup)\\n14: template_settings_py = os.path.join(\\n15: os.path.dirname(conf.__file__),\\n16: 'project_template',\\n17: 'project_name',\\n18: 'settings.py-tpl',\\n19: )\\n20: test_settings_py = os.path.join(self.temp_dir.name, 'test_settings.py')\\n21: shutil.copyfile(template_settings_py, test_settings_py)\\n22:\\n23: def test_middleware_headers(self):\\n24: \\"\\"\\"\\n25: Ensure headers sent by the default MIDDLEWARE don't inadvertently\\n26: change. For example, we never want \\"Vary: Cookie\\" to appear in the list\\n27: since it prevents the caching of responses.\\n28: \\"\\"\\"\\n29: with extend_sys_path(self.temp_dir.name):\\n30: from test_settings import MIDDLEWARE\\n31:\\n32: with self.settings(\\n33: MIDDLEWARE=MIDDLEWARE,\\n34: ROOT_URLCONF='project_template.urls',\\n35: ):\\n36: response = self.client.get('/empty/')\\n37: headers = sorted(response.serialize_headers().split(b'\\\\r\\\\n'))\\n38: self.assertEqual(headers, [\\n39: b'Content-Length: 0',\\n40: b'Content-Type: text/html; charset=utf-8',\\n41: b'Referrer-Policy: same-origin',\\n42: b'X-Content-Type-Options: nosniff',\\n43: b'X-Frame-Options: DENY',\\n44: ])\\n(Open file: /django__django/tests/project_template/test_settings.py)\\n(Current directory: /django__django)\\nbash-$"}], "agent": "primary", "tdd": false}] +'''.strip() + +@pytest.mark.replayio +def test_json_round_trip(): + history = json_deserialize_str(input) + history_str = json_serialize_str(history) + + assert history_str == input