From 14a8f34b61d3666ac0c19aa7d4ecb61546e87f7a Mon Sep 17 00:00:00 2001 From: chenweize1998 Date: Wed, 25 Oct 2023 21:05:56 +0800 Subject: [PATCH] fix: bug in simulation gui. [ci skip] --- agentverse/gui.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/agentverse/gui.py b/agentverse/gui.py index 2bc4140a0..97cfac9c7 100644 --- a/agentverse/gui.py +++ b/agentverse/gui.py @@ -41,7 +41,7 @@ def __init__(self, task: str, tasks_dir: str, ui_kwargs: Dict[str, str]): self.messages = [] self.task = task self.ui_kwargs = ui_kwargs - if task == "pipeline_brainstorming": + if task == "tasksolving/brainstorming": self.backend = TaskSolving.from_task(task, tasks_dir) else: self.backend = Simulation.from_task(task, tasks_dir) @@ -60,9 +60,9 @@ def __init__(self, task: str, tasks_dir: str, ui_kwargs: Dict[str, str]): def get_avatar(self, idx): if idx == -1: img = cv2.imread(f"{IMG_PATH}/db_diag/-1.png") - elif self.task == "prisoner_dilemma": + elif self.task == "simulation/prisoner_dilemma": img = cv2.imread(f"{IMG_PATH}/prison/{idx}.png") - elif self.task == "db_diag": + elif self.task == "simulation/db_diag": img = cv2.imread(f"{IMG_PATH}/db_diag/{idx}.png") elif "sde" in self.task: img = cv2.imread(f"{IMG_PATH}/sde/{idx}.png") @@ -164,9 +164,9 @@ def reset(self, stu_num=0): self.backend.reset() self.turns_remain = self.backend.environment.max_turns - if self.task == "prisoner_dilemma": + if self.task == "simulation/prisoner_dilemma": background = cv2.imread(f"{IMG_PATH}/prison/case_1.png") - elif self.task == "db_diag": + elif self.task == "simulation/db_diag": background = cv2.imread(f"{IMG_PATH}/db_diag/background.png") elif "sde" in self.task: background = cv2.imread(f"{IMG_PATH}/sde/background.png") @@ -201,7 +201,7 @@ def gen_img(self, data: List[Dict]): # if len(data) != self.stu_num: if len(data) != self.stu_num + 1: raise gr.Error("data length is not equal to the total number of students.") - if self.task == "prisoner_dilemma": + if self.task == "simulation/prisoner_dilemma": img = cv2.imread(f"{IMG_PATH}/speaking.png", cv2.IMREAD_UNCHANGED) if ( len(self.messages) < 2 @@ -219,7 +219,7 @@ def gen_img(self, data: List[Dict]): cover_img(background, img, (550, 480)) if data[2]["message"] != "": cover_img(background, img, (550, 880)) - elif self.task == "db_diag": + elif self.task == "simulation/db_diag": background = cv2.imread(f"{IMG_PATH}/db_diag/background.png") img = cv2.imread(f"{IMG_PATH}/db_diag/speaking.png", cv2.IMREAD_UNCHANGED) if data[0]["message"] != "": @@ -262,7 +262,9 @@ def gen_img(self, data: List[Dict]): img = cv2.imread(f"{IMG_PATH}/hand.png", cv2.IMREAD_UNCHANGED) cover_img(background, img, (h_begin - 90, w_begin + 10)) elif data[stu_cnt]["message"] not in ["", "[RaiseHand]"]: - img = cv2.imread(f"{IMG_PATH}/speaking.png", cv2.IMREAD_UNCHANGED) + img = cv2.imread( + f"{IMG_PATH}/speaking.png", cv2.IMREAD_UNCHANGED + ) cover_img(background, img, (h_begin - 90, w_begin + 10)) else: @@ -274,7 +276,7 @@ def return_format(self, messages: List[Message]): _format = [{"message": "", "sender": idx} for idx in range(len(self.agent_id))] for message in messages: - if self.task == "db_diag": + if self.task == "simulation/db_diag": content_json: dict = message.content content_json[ "diagnose" @@ -342,7 +344,7 @@ def gen_message(self): avatar = self.get_avatar(-1) else: avatar = self.get_avatar((sender - 1) % 11 + 1) - if self.task == "db_diag": + if self.task == "simulation/db_diag": msg_json = json.loads(msg) self.solution_status = [False] * self.tot_solutions msg = msg_json["diagnose"] @@ -421,7 +423,7 @@ def submit(self, message: str): return self.gen_img([{"message": ""}] * len(self.agent_id)), self.gen_message() def launch(self, single_agent=False, discussion_mode=False): - if self.task == "pipeline_brainstorming": + if self.task == "tasksolving/brainstorming": with gr.Blocks() as demo: chatbot = gr.Chatbot(height=800, show_label=False) msg = gr.Textbox(label="Input") @@ -478,7 +480,7 @@ def respond(message, chat_history): # stu_num = gr.Number(label="Student Number", precision=0) # stu_num = self.stu_num - if self.task == "db_diag": + if self.task == "simulation/db_diag": user_msg = gr.Textbox() submit_btn = gr.Button("Submit", variant="primary")