# evaluate_webshop.py (forked from SalesforceAIResearch/AgentLite)
import argparse
import os

from webshop_agents import WebshopAgent
from webshop_env import Webshop
from webshop_multiagent import bolaa_webagent

from agentlite.commons import TaskPackage
from agentlite.llm.agent_llms import get_llm_backend
from agentlite.llm.LLMConfig import LLMConfig

# LAM_URL is only needed for the self-hosted xlam backends; read it with
# .get() so the script still starts for other LLMs when the variable is unset.
LAM_URL = os.environ.get("LAM_URL")
webshop_env = Webshop()


# =============================== start of webshop agent designing =============================== #
def evaluate(idx: int, llm_name="gpt-3.5-turbo-16k-0613", agent_arch="react", PROMPT_DEBUG_FLAG=False):
    """Run one fixed webshop task and return (reward, sub_reward, task)."""
    if llm_name in ["xlam", "xlam_v2"]:
        # xlam models are served from a dedicated endpoint, so point the
        # backend at LAM_URL instead of the default API.
        llm_config = LLMConfig(
            {
                "llm_name": llm_name,
                "temperature": 0.0,
                "base_url": LAM_URL,
                "api_key": "EMPTY",
            }
        )
    else:
        llm_config = LLMConfig({"llm_name": llm_name, "temperature": 0.0})
    llm = get_llm_backend(llm_config)

    env_idx = f"fixed_{idx}"
    if agent_arch == "bolaa":
        # The bolaa multi-agent controller resets the environment and drives
        # the episode itself.
        agent = bolaa_webagent(session_idx=env_idx, env=webshop_env, llm=llm, PROMPT_DEBUG_FLAG=PROMPT_DEBUG_FLAG)
        task = agent.goal
        agent.run()
    else:
        # Reset the env first if not using the bolaa agent.
        action = "reset[]"
        webshop_env.step(env_idx, action)
        agent = WebshopAgent(session_idx=env_idx, env=webshop_env, llm=llm, agent_arch=agent_arch, PROMPT_DEBUG_FLAG=PROMPT_DEBUG_FLAG)
        task = webshop_env.goal
        print(f"Task: {task}")
        task_package = TaskPackage(instruction=task)
        agent(task_package)

    reward = webshop_env.reward
    sub_reward = webshop_env.sub_reward
    return reward, sub_reward, task
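

# A minimal usage sketch (hypothetical values): score a single task directly,
# assuming the webshop environment data is in place and the chosen LLM backend
# is reachable.
#
#   reward, sub_reward, task = evaluate(0, llm_name="gpt-3.5-turbo-16k-0613", agent_arch="react")
#   print(task, sub_reward, reward)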


# Use this function to resume the evaluation if a previous run broke partway through.
def get_finished_ids(file_path):
    """Return the task ids already logged in file_path, or None if it is unreadable."""
    try:
        with open(file_path, "r") as file:
            finished_ids = [int(line.split()[0]) for line in file]
        return finished_ids
    except FileNotFoundError:
        print("The reward log file was not found; evaluating all tasks.")
        return None
    except ValueError:
        print("A line in the reward log does not start with a valid task id.")
        return None
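

# The log parsed above is tab-separated, one record per task:
#
#   <task_id>\t<task instruction>\t<sub_reward>\t<reward>
#
# e.g. (illustrative values, not real results):
#
#   0\tfind me a red cotton t-shirt under 20 dollars\t0.5\t0.0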


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Test Search Agent on the webshop Benchmark"
    )
    parser.add_argument(
        "--llm",
        type=str,
        default="gpt-3.5-turbo-16k-0613",
        help="Name of the language model",
    )
    parser.add_argument(
        "--agent_arch",
        type=str,
        choices=["react", "act", "planact", "planreact", "zs", "zst", "bolaa"],
        default="react",
        help="agent reasoning type",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="debug flag",
    )
    args = parser.parse_args()

    all_task_ids = list(range(0, 251))  # the 251 fixed webshop tasks, ids 0-250
    REWARD_LOG_FILE = f"{args.llm}_{args.agent_arch}_results_webshop.csv"

    # Skip tasks that are already logged so an interrupted run can resume.
    finished_ids = get_finished_ids(REWARD_LOG_FILE)
    if finished_ids is None:
        evaluate_ids = all_task_ids
    else:
        evaluate_ids = [task_id for task_id in all_task_ids if task_id not in finished_ids]

    # Run the webshop evaluation, appending one tab-separated record per task.
    with open(REWARD_LOG_FILE, "a") as f:
        for i in evaluate_ids:
            reward, sub_reward, task = evaluate(i, llm_name=args.llm, agent_arch=args.agent_arch, PROMPT_DEBUG_FLAG=args.debug)
            f.write(f"{i}\t{task}\t{sub_reward}\t{reward}\n")

    # Re-read the full log so the average also covers tasks finished in earlier runs.
    with open(REWARD_LOG_FILE, "r") as f:
        rewards = [float(line.split("\t")[3]) for line in f]
    avg_reward = sum(rewards) / len(rewards)
    print(f"The average reward is: {avg_reward}")