-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcritic.py
133 lines (109 loc) · 4.64 KB
/
critic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from voyager.prompts import load_prompt
from voyager.utils.json_utils import fix_and_parse_json
from langchain.schema import HumanMessage, SystemMessage
class CriticAgent:
    """Decide whether a task succeeded, using either an LLM or a human operator.

    In ``"auto"`` mode the verdict comes from ``self.llm``; in ``"manual"``
    mode it is typed in on stdin. Both paths return a ``(success, critique)``
    tuple.
    """

    def __init__(
        self,
        model_name="gpt-3.5-turbo",
        temperature=0,
        request_timout=120,
        mode="auto",
    ):
        """Store the LLM handle and the operating mode.

        Args:
            model_name: Stored unchanged as ``self.llm``.
            temperature: Accepted for interface compatibility; not used here.
            request_timout: Accepted for interface compatibility; not used
                here (the misspelling is part of the public signature).
            mode: ``"auto"`` (LLM verdict) or ``"manual"`` (human verdict).

        Raises:
            AssertionError: If ``mode`` is neither ``"auto"`` nor ``"manual"``.
        """
        # NOTE(review): despite its name, ``model_name`` is used as the LLM
        # object itself: ``ai_check_task_success`` calls ``self.llm.invoke``.
        # The default string has no ``invoke`` method, so "auto" mode
        # presumably requires the caller to pass a model object - confirm.
        self.llm = model_name
        assert mode in ["auto", "manual"], f"Invalid critic agent mode: {mode}"
        self.mode = mode

    def render_system_message(self):
        """Return the system message built from the ``"critic"`` prompt."""
        return SystemMessage(content=load_prompt("critic"))

    def render_human_message(self, *, events, task, context, chest_observation):
        """Build the human message describing the latest observation.

        Args:
            events: Sequence of ``(event_type, event)`` pairs. The last pair
                must have type ``"observe"`` and its payload must provide
                ``"status"`` (with ``biome``, ``timeOfDay``, ``health``,
                ``food``, ``position``, ``equipment``, ``inventoryUsed``),
                ``"voxels"`` and ``"inventory"``.
            task: Task description, rendered verbatim.
            context: Optional context; rendered as ``None`` when falsy.
            chest_observation: Pre-rendered text inserted verbatim between
                the inventory and task sections.

        Returns:
            A ``HumanMessage`` with the rendered observation, or ``None`` when
            any event has type ``"onError"``.

        Raises:
            AssertionError: If the last event is not an ``"observe"`` event.
        """
        assert events[-1][0] == "observe", "Last event must be observe"
        last_observation = events[-1][1]
        status = last_observation["status"]
        biome = status["biome"]
        time_of_day = status["timeOfDay"]
        voxels = last_observation["voxels"]
        health = status["health"]
        hunger = status["food"]
        position = status["position"]
        equipment = status["equipment"]
        inventory_used = status["inventoryUsed"]
        inventory = last_observation["inventory"]

        # Any execution error means there is nothing meaningful to critique.
        for event_type, event in events:
            if event_type == "onError":
                print(f"\033[31mCritic Agent: Error occurs {event['onError']}\033[0m")
                return None

        nearby_blocks = ", ".join(voxels) if voxels else "None"
        inventory_text = inventory if inventory else "Empty"
        context_text = context if context else "None"

        observation = "".join(
            [
                f"Biome: {biome}\n\n",
                f"Time: {time_of_day}\n\n",
                f"Nearby blocks: {nearby_blocks}\n\n",
                f"Health: {health:.1f}/20\n\n",
                f"Hunger: {hunger:.1f}/20\n\n",
                f"Position: x={position['x']:.1f}, y={position['y']:.1f}, "
                f"z={position['z']:.1f}\n\n",
                f"Equipment: {equipment}\n\n",
                f"Inventory ({inventory_used}/36): {inventory_text}\n\n",
                chest_observation,
                f"Task: {task}\n\n",
                f"Context: {context_text}\n\n",
            ]
        )

        print(f"\033[31m****Critic Agent human message****\n{observation}\033[0m")
        return HumanMessage(content=observation)

    def human_check_task_success(self):
        """Ask a human on stdin for the verdict and critique until confirmed.

        Returns:
            ``(success, critique)`` where ``success`` is a bool and
            ``critique`` is the text the operator typed.
        """
        while True:
            # Both answers are normalised the same way so "Y" / " y " work
            # for the success prompt and the confirmation prompt alike.
            success = input("Success? (y/n)").strip().lower() == "y"
            critique = input("Enter your critique:")
            print(f"Success: {success}\nCritique: {critique}")
            if input("Confirm? (y/n)").strip().lower() in ("y", ""):
                return success, critique

    def ai_check_task_success(self, messages, max_retries=5):
        """Ask the LLM for a verdict, retrying when its reply cannot be parsed.

        Args:
            messages: ``[system_message, human_message]``. A ``None`` human
                message (as returned by ``render_human_message`` on error)
                yields an immediate failure verdict.
            max_retries: Maximum number of LLM calls. Zero or a negative
                value gives up immediately without calling the LLM.

        Returns:
            ``(success, critique)``; ``(False, "")`` when there is nothing to
            judge or every attempt failed to parse.
        """
        give_up_notice = (
            "\033[31mFailed to parse Critic Agent response. "
            "Consider updating your prompt.\033[0m"
        )
        # ``<=`` rather than ``==`` so a negative budget cannot loop forever.
        if max_retries <= 0:
            print(give_up_notice)
            return False, ""
        if messages[1] is None:
            return False, ""

        for _ in range(max_retries):
            reply = self.llm.invoke(messages)
            # Chat models return a message object whose text is in
            # ``.content``; plain-text LLMs return the string itself.
            critic = getattr(reply, "content", reply)
            print(f"\033[31m****Critic Agent ai message****\n{critic}\033[0m")
            try:
                response = fix_and_parse_json(critic)
                # Explicit check instead of ``assert`` so the validation
                # still runs under ``python -O``.
                if response["success"] not in (True, False):
                    raise ValueError(
                        f"'success' must be true or false, got {response['success']!r}"
                    )
                return response["success"], response.get("critique", "")
            except Exception as e:
                print(f"\033[31mError parsing critic response: {e} Trying again!\033[0m")

        print(give_up_notice)
        return False, ""

    def check_task_success(
        self, *, events, task, context, chest_observation, max_retries=5
    ):
        """Render the observation and obtain a verdict for the current mode.

        Args:
            events: Passed to ``render_human_message``.
            task: Passed to ``render_human_message``.
            context: Passed to ``render_human_message``.
            chest_observation: Passed to ``render_human_message``.
            max_retries: LLM attempt budget, used in ``"auto"`` mode only.

        Returns:
            ``(success, critique)``.

        Raises:
            ValueError: If ``self.mode`` is neither ``"auto"`` nor ``"manual"``.
        """
        # Rendered in every mode: it also prints the observation, which is
        # what a human operator reads before answering.
        human_message = self.render_human_message(
            events=events,
            task=task,
            context=context,
            chest_observation=chest_observation,
        )

        if self.mode == "manual":
            return self.human_check_task_success()
        if self.mode == "auto":
            # The system prompt is only needed on the LLM path.
            messages = [self.render_system_message(), human_message]
            return self.ai_check_task_success(
                messages=messages, max_retries=max_retries
            )
        raise ValueError(f"Invalid critic agent mode: {self.mode}")