Commit

Merge pull request #16 from replayio/toshok/flesh-out-model-cache-and-add-tests

flesh out the model cache code (wrt object serialization) and add tests
toshok authored Sep 23, 2024
2 parents 1d32db4 + fc201eb commit 8cb45c2
Showing 9 changed files with 255 additions and 42 deletions.
11 changes: 10 additions & 1 deletion pyproject.toml
@@ -65,10 +65,16 @@ namespaces = false
"Documentation" = "https://github.com/princeton-nlp/SWE-agent"
"Source" = "http://github.com/princeton-nlp/SWE-agent"

[tool.rye]
dev-dependencies = [
"pytest>=8.3.3",
]


[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"replayio: marks tests added as part of replay.io work (deselect with '-m \"not replayio\"')",
]
testpaths = [
"tests"
@@ -170,12 +176,12 @@ select = [

ignore = [
# flake8-return
"RET504", # Unnecessary assignment * before `return` statement
"RET505", # can't autofix
"RET506", # can't autofix
"RET507", # can't autofix
# error (E)
"E501", # line too long
"E504", # Unnecessary assignment * before `return` statement
"E402", # import not on top of file
"E722", # bare except
"E741", # ambiguous symbol
@@ -216,3 +222,6 @@ aci = "aci"

[tool.ruff.lint.isort]
required-imports = ["from __future__ import annotations"]

[tool.rye.scripts]
replayio-tests = "pytest -m replayio"
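The `replayio` pytest marker registered above pairs with the `replayio-tests` rye script, so the Replay-specific tests can be selected or deselected as a group. A minimal sketch of a test that opts in (the test body is illustrative, not part of this commit):

import pytest

from sweagent.agent.model_cache import json_deserialize_str, json_serialize_str


@pytest.mark.replayio
def test_plain_json_round_trip():
    # Selected by `rye run replayio-tests` (i.e. `pytest -m replayio`);
    # excluded by `pytest -m "not replayio"`.
    obj = {"role": "user", "content": "hello"}
    assert json_deserialize_str(json_serialize_str(obj)) == obj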
13 changes: 13 additions & 0 deletions requirements-dev.lock
@@ -88,6 +88,7 @@ eval-type-backport==0.2.0
# via together
exceptiongroup==1.2.2
# via anyio
# via pytest
farama-notifications==0.0.4
# via gymnasium
fastcore==1.7.5
@@ -110,6 +111,7 @@ frozenlist==1.4.1
# via aiosignal
fsspec==2024.6.1
# via datasets
# via fsspec
# via huggingface-hub
ghapi==1.0.6
# via swebench
@@ -163,6 +165,8 @@ idna==3.8
# via yarl
importlib-metadata==8.4.0
# via opentelemetry-api
iniconfig==2.0.0
# via pytest
itsdangerous==2.2.0
# via flask
jinja2==3.1.4
@@ -187,6 +191,7 @@ multiprocess==0.70.16
# via datasets
nodeenv==1.9.1
# via pre-commit
# via pyright
numpy==2.1.1
# via datasets
# via gymnasium
@@ -222,13 +227,16 @@ packaging==24.1
# via fastcore
# via ghapi
# via huggingface-hub
# via pytest
pandas==2.2.2
# via datasets
# via sweagent
pillow==10.4.0
# via together
platformdirs==4.3.2
# via virtualenv
pluggy==1.5.0
# via pytest
pre-commit==3.8.0
# via swebench
proto-plus==1.24.0
@@ -257,6 +265,9 @@ pygments==2.18.0
# via rich
pyparsing==3.1.4
# via httplib2
pyright==1.1.380
# via sweagent
pytest==8.3.3
python-dateutil==2.9.0.post0
# via botocore
# via pandas
@@ -327,6 +338,8 @@ together==1.2.12
# via sweagent
tokenizers==0.20.0
# via anthropic
tomli==2.0.1
# via pytest
tqdm==4.66.5
# via datasets
# via huggingface-hub
4 changes: 4 additions & 0 deletions requirements.lock
@@ -110,6 +110,7 @@ frozenlist==1.4.1
# via aiosignal
fsspec==2024.6.1
# via datasets
# via fsspec
# via huggingface-hub
ghapi==1.0.6
# via swebench
@@ -187,6 +188,7 @@ multiprocess==0.70.16
# via datasets
nodeenv==1.9.1
# via pre-commit
# via pyright
numpy==2.1.1
# via datasets
# via gymnasium
@@ -257,6 +259,8 @@ pygments==2.18.0
# via rich
pyparsing==3.1.4
# via httplib2
pyright==1.1.380
# via sweagent
python-dateutil==2.9.0.post0
# via botocore
# via pandas
3 changes: 2 additions & 1 deletion requirements.txt
@@ -22,11 +22,12 @@ flask-socketio
groq
pandas

-# Some of our own dependencies
+# Replay: Some of our own dependencies
google-api-python-client
google-auth-httplib2
google-auth-oauthlib
python-dotenv
+pyright
tqdm
opentelemetry-api>=1.25.0
opentelemetry-exporter-otlp-proto-http>=1.25.0
166 changes: 153 additions & 13 deletions sweagent/agent/model_cache.py
@@ -1,41 +1,181 @@

from __future__ import annotations
import json
import hashlib
import os
import copy
from typing import Any

from anthropic.types import TextBlock, ToolUseBlock, ToolResultBlockParam

from sweagent.utils.log import get_logger
from sweagent.agent.model_result import AnthropicModelResult, ModelQueryResult

ModelCacheEnvVar = "MODEL_CACHE_DIRECTORY"

logger = get_logger("model_cache")


class CacheEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, AnthropicModelResult):
            return {
                "type": "anthropic_model_result",
                "blocks": [encode_anthropic_types(block) for block in o.blocks],
            }

        encoded = encode_anthropic_types(o)
        if encoded is not None:
            return encoded

        return super().default(o)


def encode_anthropic_types(obj) -> dict[str, Any] | None:
    # it should be possible to use the Anthropic library to encode these (using
    # BaseModel#to_dict), but there doesn't seem to be a way to reconstruct
    # instances from those dicts directly. Given that the decoder has to be
    # written by hand anyway, the encoder lives next to it to make the two
    # easier to keep in sync.
    if isinstance(obj, TextBlock):
        return {
            "type": "text",
            "text": obj.text,
        }

    if isinstance(obj, ToolUseBlock):
        return {
            "type": "tool_use",
            "id": obj.id,
            "name": obj.name,
            "input": obj.input,
        }

    if isinstance(obj, dict) and obj.get("type") == "tool_result":
        result = {
            "type": "tool_result",
            "tool_use_id": obj["id"],
        }
        if "is_error" in obj:
            result["is_error"] = obj["is_error"]
        if "content" in obj:
            result["content"] = [encode_anthropic_types(c) for c in obj["content"]]
        return result

    return None


def cache_decoder(dct: dict[str, Any]):
    if "type" not in dct:
        return dct

    if dct["type"] == "anthropic_model_result":
        return AnthropicModelResult(blocks=dct["blocks"])

    if dct["type"] == "text":
        return TextBlock(type="text", text=dct["text"])

    if dct["type"] == "tool_use":
        return ToolUseBlock(
            type="tool_use",
            id=dct["id"],
            name=dct["name"],
            input=dct["input"],
        )

    if dct["type"] == "tool_result":
        return dct

    raise ValueError(f"Unknown type {dct['type']} in cache_decoder")

def normalize_tool_use_ids(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Anthropic mints a fresh tool_use id on every request, so two otherwise
    # identical histories would never hash to the same cache file. Rewrite the
    # ids to stable sequential ones before hashing.
    mapping = {}
    for entry in history:
        if not isinstance(entry["content"], list):
            continue
        for c in entry["content"]:
            if not isinstance(c, ToolUseBlock):
                continue
            mapping[c.id] = len(mapping) + 1

    if len(mapping) == 0:
        return history

    normalized = copy.deepcopy(history)

    logger.warning("Normalizing tool use ids")
    for entry in normalized:
        if not isinstance(entry["content"], list):
            continue
        for c in entry["content"]:
            if isinstance(c, ToolUseBlock):
                if c.id in mapping:
                    mapped = mapping[c.id]
                    c.id = f"toolu_normalized_{mapped}"
                continue

            # guard: text blocks are pydantic models and don't support `in`
            if isinstance(c, dict) and "tool_use_id" in c:
                tool_use_id = c["tool_use_id"]
                if tool_use_id in mapping:
                    mapped = mapping[tool_use_id]
                    c["tool_use_id"] = f"toolu_normalized_{mapped}"
                continue

    return normalized
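# Illustrative effect of the normalization above (ids are invented): a
# ToolUseBlock with id="toolu_01A9q3" and a matching
# {"tool_use_id": "toolu_01A9q3"} result block both come back as
# "toolu_normalized_1", so the same logical conversation hashes identically
# no matter which ids the API handed out.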

def hash_string(s: str) -> str:
    hash_object = hashlib.sha256(s.encode("utf-8"))
    return hash_object.hexdigest()


def json_serialize_str(obj: Any, **kwargs) -> str:
    return json.dumps(obj, **kwargs, cls=CacheEncoder)


def json_serialize_file(obj: Any, fp: Any, **kwargs):  # SupportsWrite[str] on fp here
    json.dump(obj, fp, **kwargs, cls=CacheEncoder)


def json_deserialize_str(s: str, **kwargs) -> Any:
    return json.loads(s, **kwargs, object_hook=cache_decoder)


def json_deserialize_file(fp: Any, **kwargs) -> Any:  # SupportsRead[str] on fp here
    return json.load(fp, **kwargs, object_hook=cache_decoder)

class ModelCache:
    def __init__(self):
        self.directory = None
        if ModelCacheEnvVar in os.environ:
            logger.warning("⚠ ModelCache is enabled")
            self.directory = os.environ[ModelCacheEnvVar]

    def _get_file(self, history: list[dict[str, str]]) -> str:
-        hash_input = str(history)
-        hash_object = hashlib.sha256(hash_input.encode('utf-8'))
-        return f"{self.directory}/model-query-{hash_object.hexdigest()}.json"
+        hash_input = json_serialize_str(history)
+        logger.debug(f"HASH_INPUT\n{hash_input}\nEND_OF_HASH_INPUT")
+        hash = hash_string(hash_input)
+        return f"{self.directory}/model-query-{hash}.json"

-    def query(self, history: list[dict[str, str]]) -> tuple[str, list[dict[str, str]]] | None:
+    def query(self, history: list[dict[str, str]]) -> tuple[ModelQueryResult, list[dict[str, str]]] | None:
        if self.directory is None:
            return None
-        file = self._get_file(history)
+        normalized_history = normalize_tool_use_ids(history)
+        file = self._get_file(normalized_history)
        if not os.path.exists(file):
            logger.info(f"ModelCacheMiss file={file}")
            return None
        logger.info(f"ModelCacheHit file={file}")
-        file_handle = open(file, 'r')
-        entries = json.load(file_handle)
-        return entries[1], entries[2]
+        file_handle = open(file, "r")
+        [_, model_result, stats_calls] = json_deserialize_file(file_handle)
+        return model_result, stats_calls

-    def insert(self, history: list[dict[str, str]], result_string: str, stats_calls: list[dict[str, str]]):
+    def insert(
+        self,
+        history: list[dict[str, str]],
+        model_result: ModelQueryResult,
+        stats_calls: list[dict[str, str]],
+    ):
        if self.directory is None:
            return
-        file = self._get_file(history)
+        normalized_history = normalize_tool_use_ids(history)
+        file = self._get_file(normalized_history)
        logger.info(f"ModelCacheInsert file={file}")
-        file_handle = open(file, 'w')
-        json.dump([history, result_string, stats_calls], file_handle)
+        file_handle = open(file, "w")
+        json_serialize_file([history, model_result, stats_calls], file_handle)
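Taken together, this gives a content-addressed cache: the history is normalized, serialized with CacheEncoder, and hashed to pick a file; a later identical query replays the stored result. A minimal sketch of the round trip, assuming the `anthropic` package and a writable scratch directory (the path, ids, and block contents below are invented):

import os

from anthropic.types import TextBlock, ToolUseBlock

from sweagent.agent.model_cache import ModelCache
from sweagent.agent.model_result import AnthropicModelResult

os.makedirs("/tmp/model-cache", exist_ok=True)
os.environ["MODEL_CACHE_DIRECTORY"] = "/tmp/model-cache"  # enables the cache

cache = ModelCache()
history = [
    {
        "role": "assistant",
        "content": [
            TextBlock(type="text", text="Listing the files."),
            ToolUseBlock(type="tool_use", id="toolu_01abc", name="bash", input={"command": "ls"}),
        ],
    }
]
result = AnthropicModelResult(blocks=[TextBlock(type="text", text="done")])
cache.insert(history, result, stats_calls=[])

# A later run gets a different tool_use id from the API, but normalization
# makes the history hash to the same file, so this query is a cache hit.
history[0]["content"][1].id = "toolu_01xyz"
model_result, stats_calls = cache.query(history)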
27 changes: 27 additions & 0 deletions sweagent/agent/model_result.py
@@ -0,0 +1,27 @@
from dataclasses import dataclass

from anthropic.types import ContentBlock


@dataclass
class AnthropicModelResult:
    blocks: list[ContentBlock]

    def get_tool_use_blocks(self):
        return [block for block in self.blocks if block.type == "tool_use"]

    def get_last_tool_use(self):
        return next(reversed(self.get_tool_use_blocks()), None)

    def get_text_blocks(self):
        return [block for block in self.blocks if block.type == "text"]

    def get_non_content_blocks(self):
        return [block for block in self.blocks if block.type not in ["tool_use", "text"]]

    def __init__(self, blocks):
        self.blocks = blocks

    def __repr__(self) -> str:
        return f"AnthropicModelResult(blocks={repr(self.blocks)})"


ModelQueryResult = str | AnthropicModelResult
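For context, a short sketch of how the result type and its accessors are meant to be used (block contents are invented):

from anthropic.types import TextBlock, ToolUseBlock

from sweagent.agent.model_result import AnthropicModelResult

result = AnthropicModelResult(
    blocks=[
        TextBlock(type="text", text="I'll run ls first."),
        ToolUseBlock(type="tool_use", id="toolu_1", name="bash", input={"command": "ls"}),
    ]
)

assert result.get_text_blocks()[0].text == "I'll run ls first."
assert result.get_last_tool_use().name == "bash"  # None when there is no tool use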
(The remaining 3 of the 9 changed files did not render and are not shown.)