forked from SWE-agent/SWE-agent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_replay.py
102 lines (90 loc) · 3.57 KB
/
run_replay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import json
import os
import subprocess
import sys
from argparse import ArgumentParser
from typing import Any, Dict, List

import yaml

from sweagent.environment.utils import is_from_github_url
def process_single_traj(traj_path: str, config_file: str, data_path: str, suffix: str):
    """Replay the actions recorded in one trajectory by invoking ``run.py``.

    The assistant messages of the trajectory are written to a temporary
    replay file, the task instance matching the trajectory is extracted from
    ``data_path`` into a second temporary file, and ``run.py`` is started as
    a subprocess with ``--model_name replay``. Temporary files created here
    are removed afterwards, even when an error occurs along the way.

    Args:
        traj_path: Path to the trajectory. A ``.yaml`` file is loaded as the
            history itself; any other file is parsed as JSON and must contain
            a ``"history"`` key. The file name up to its first ``.`` is used
            as the instance id.
        config_file: Forwarded to ``run.py`` as ``--config_file``.
        data_path: A ``.json`` or ``.jsonl`` file of task instances, or a
            value for which ``is_from_github_url`` returns true. When
            ``None``, it is read from ``environment.data_path`` in the
            ``args.yaml`` located next to the trajectory.
        suffix: Forwarded to ``run.py`` as ``--suffix`` when not ``None``.

    Raises:
        ValueError: If ``data_path`` is neither a ``.json``/``.jsonl`` file
            nor accepted by ``is_from_github_url``.
    """
    replay_action_trajs_path = "temp_replay.jsonl"
    replay_task_instances_path = None
    # Only a task-instance file written by this function may be deleted later.
    created_task_instances_file = False
    is_github = False

    # Open trajectory file, extract responses as actions
    if traj_path.endswith(".yaml"):
        with open(traj_path, "r") as f:
            traj_data = {"history": yaml.safe_load(f)}
    else:
        with open(traj_path, "r") as f:
            traj_data = json.load(f)
    actions = [x["content"] for x in traj_data["history"] if x["role"] == "assistant"]
    instance_id = os.path.basename(traj_path).split(".")[0]

    def create_task_instances_tmp_file(data: List[Dict[str, Any]]) -> str:
        """Write the task instances matching ``instance_id`` to a temp file.

        Returns the path to the temporary file.
        """
        tmp_path = instance_id + ".jsonl"
        with open(tmp_path, "w") as f:
            for d in data:
                if d["instance_id"] == instance_id:
                    f.write(json.dumps(d) + "\n")
        return tmp_path

    try:
        with open(replay_action_trajs_path, "w") as f:
            f.write(json.dumps({instance_id: actions}) + "\n")

        # Get data_path from args.yaml
        if data_path is None:
            args_path = os.path.join(os.path.dirname(traj_path), "args.yaml")
            with open(args_path) as f:
                args = yaml.safe_load(f)
            data_path = args["environment"]["data_path"]

        # Identify the relevant task instance and create it
        if data_path.endswith(".jsonl"):
            with open(data_path, "r") as f:
                # Skip blank lines so a trailing newline does not break parsing.
                instances = [json.loads(line) for line in f if line.strip()]
            replay_task_instances_path = create_task_instances_tmp_file(instances)
            created_task_instances_file = True
        elif data_path.endswith(".json"):
            with open(data_path) as f:
                instances = json.load(f)
            replay_task_instances_path = create_task_instances_tmp_file(instances)
            created_task_instances_file = True
        elif is_from_github_url(data_path):
            is_github = True
            replay_task_instances_path = data_path
        else:
            raise ValueError("--data_path must be a .json or .jsonl")

        # Call run.py via subprocess, using the interpreter running this
        # script so the same environment (e.g. virtualenv) is used.
        command = [
            sys.executable or "python",
            "run.py",
            "--config_file", config_file,
            "--data_path", replay_task_instances_path,
            "--install_environment", "True",
            "--model_name", "replay",
            "--replay_path", replay_action_trajs_path,
        ]
        if is_github:
            # Not sure if this only applies to github urls for data_path
            command.extend(["--skip_existing", "False"])
        if suffix is not None:
            command.extend(["--suffix", suffix])

        subprocess.run(command)
    finally:
        # Best-effort cleanup of the temporary files written above.
        cleanup_paths = [replay_action_trajs_path]
        if created_task_instances_file:
            cleanup_paths.append(replay_task_instances_path)
        for path in cleanup_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def main(
    traj_path: str,
    config_file: str,
    data_path: str,
    suffix: str,
):
    """Replay a single trajectory; forwards every option to process_single_traj."""
    process_single_traj(
        traj_path=traj_path,
        config_file=config_file,
        data_path=data_path,
        suffix=suffix,
    )
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to main().
    parser = ArgumentParser()
    # --traj_path is required: process_single_traj calls str methods on it
    # right away, so a None default would only crash with AttributeError.
    parser.add_argument("--traj_path", help="Path to trajectory to replay", required=True)
    parser.add_argument("--config_file", help="Path to template", required=True)
    parser.add_argument(
        "--data_path",
        help="(Optional) Path to data file containing task instances ref'ed by replay trajectories",
        default=None,
    )
    parser.add_argument(
        "--suffix",
        help="(Optional) Suffix argument appended to end of traj path",
        default=None,
    )
    args = parser.parse_args()
    # The argparse destinations match main()'s parameter names one-to-one.
    main(**vars(args))