From 3aa56d458ebe6254e6cca631c226f6d4a266ddb2 Mon Sep 17 00:00:00 2001 From: yitong <2969413251@qq.com> Date: Wed, 25 Oct 2023 19:24:03 +0800 Subject: [PATCH] fix: make to_message asynchronous to accelerate compressing chat history of multiple agents --- agentverse/agents/tasksolving_agent/critic.py | 2 +- .../agents/tasksolving_agent/evaluator.py | 20 ++++-- .../agents/tasksolving_agent/executor.py | 69 +------------------ .../agents/tasksolving_agent/role_assigner.py | 14 ++-- agentverse/agents/tasksolving_agent/solver.py | 14 ++-- .../environments/tasksolving_env/basic.py | 4 +- .../tasksolving_env/rules/base.py | 8 +-- .../rules/decision_maker/brainstorming.py | 2 +- .../rules/decision_maker/central.py | 2 +- .../rules/decision_maker/concurrent.py | 2 +- .../rules/decision_maker/dynamic.py | 4 +- .../rules/decision_maker/horizontal.py | 2 +- .../rules/decision_maker/horizontal_tool.py | 4 +- .../rules/decision_maker/vertical.py | 2 +- .../decision_maker/vertical_solver_first.py | 4 +- .../tasksolving_env/rules/evaluator/base.py | 2 +- .../tasksolving_env/rules/evaluator/basic.py | 8 +-- .../rules/executor/code_test.py | 30 +------- .../rules/executor/tool_using.py | 4 +- .../rules/role_assigner/base.py | 4 +- .../rules/role_assigner/role_description.py | 8 +-- agentverse/memory/chat_history.py | 20 +++--- 22 files changed, 75 insertions(+), 154 deletions(-) diff --git a/agentverse/agents/tasksolving_agent/critic.py b/agentverse/agents/tasksolving_agent/critic.py index ad0541c44..89850dd61 100644 --- a/agentverse/agents/tasksolving_agent/critic.py +++ b/agentverse/agents/tasksolving_agent/critic.py @@ -88,7 +88,7 @@ async def astep( max_send_token -= prompt_token - history = self.memory.to_messages( + history = await self.memory.to_messages( self.name, start_index=-self.max_history, max_send_token=max_send_token, diff --git a/agentverse/agents/tasksolving_agent/evaluator.py b/agentverse/agents/tasksolving_agent/evaluator.py index d9c1375e7..78275b492 100644 --- a/agentverse/agents/tasksolving_agent/evaluator.py +++ b/agentverse/agents/tasksolving_agent/evaluator.py @@ -26,6 +26,17 @@ def step( task_description: str, all_role_description: str, ) -> EvaluatorMessage: + pass + # return parsed_response + + async def astep( + self, + solution: str, + result: str, + task_description: str, + all_role_description: str, + ) -> EvaluatorMessage: + """Asynchronous version of step""" logger.debug("", self.name, Fore.MAGENTA) prepend_prompt, append_prompt, prompt_token = self.get_all_prompts( solution=solution, @@ -49,7 +60,7 @@ def step( max_send_token -= prompt_token - history = self.memory.to_messages( + history = await self.memory.to_messages( self.name, max_send_token=max_send_token, model=model_name, @@ -57,7 +68,7 @@ def step( parsed_response = None for i in range(self.max_retry): try: - response = self.llm.generate_response( + response = await self.llm.agenerate_response( prepend_prompt, history, append_prompt ) parsed_response = self.output_parser.parse(response) @@ -78,11 +89,6 @@ def step( advice=parsed_response[1] if parsed_response is not None else "", ) return message - # return parsed_response - - async def astep(self, solution: str) -> EvaluatorMessage: - """Asynchronous version of step""" - pass def _fill_prompt_template(self, solution: str, task_description: str) -> str: """Fill the placeholders in the prompt template diff --git a/agentverse/agents/tasksolving_agent/executor.py b/agentverse/agents/tasksolving_agent/executor.py index bd8b07c4c..6e4ebc4ee 100644 --- a/agentverse/agents/tasksolving_agent/executor.py +++ b/agentverse/agents/tasksolving_agent/executor.py @@ -23,72 +23,7 @@ class ExecutorAgent(BaseAgent): def step( self, task_description: str, solution: str, tools: List[dict] = [], **kwargs ) -> ExecutorMessage: - logger.debug("", self.name, Fore.MAGENTA) - prepend_prompt, append_prompt, prompt_token = self.get_all_prompts( - task_description=task_description, - solution=solution, - agent_name=self.name, - **kwargs, - ) - - model_name = self.llm.args.model - - if model_name.startswith("gpt-3.5-turbo"): - tokens_per_message = 4 - else: - tokens_per_message = 3 - - max_send_token = self.llm.send_token_limit(model_name) - if len(prepend_prompt) > 0: - max_send_token -= tokens_per_message - if (len(append_prompt)) > 0: - max_send_token -= tokens_per_message - - max_send_token -= prompt_token - - history = self.memory.to_messages( - self.name, - start_index=-self.max_history, - max_send_token=max_send_token, - model=model_name, - ) - parsed_response = None - for i in range(self.max_retry): - try: - response = self.llm.generate_response( - prepend_prompt, history, append_prompt, tools - ) - parsed_response = self.output_parser.parse(response) - break - except (KeyboardInterrupt, bdb.BdbQuit): - raise - except Exception as e: - logger.error(e) - logger.warn("Retrying...") - continue - - if parsed_response is None: - logger.error(f"{self.name} failed to generate valid response.") - if isinstance(parsed_response, AgentFinish): - message = ExecutorMessage( - content=parsed_response.return_values["output"], - sender=self.name, - sender_agent=self, - ) - elif isinstance(parsed_response, AgentAction): - message = ExecutorMessage( - content=parsed_response.log, - sender=self.name, - sender_agent=self, - tool_name=parsed_response.tool, - tool_input=parsed_response.tool_input, - ) - else: - raise ValueError( - f"Error response type: {type(parsed_response)}. Only support \ - AgentFinish and AgentAction. Modify your output parser." - ) - return message + pass async def astep( self, task_description: str, solution: str, tools: List[dict] = [], **kwargs @@ -116,7 +51,7 @@ async def astep( max_send_token -= prompt_token - history = self.memory.to_messages( + history = await self.memory.to_messages( self.name, start_index=-self.max_history, max_send_token=max_send_token, diff --git a/agentverse/agents/tasksolving_agent/role_assigner.py b/agentverse/agents/tasksolving_agent/role_assigner.py index ba752eee0..8fb3271a4 100644 --- a/agentverse/agents/tasksolving_agent/role_assigner.py +++ b/agentverse/agents/tasksolving_agent/role_assigner.py @@ -22,6 +22,12 @@ class RoleAssignerAgent(BaseAgent): def step( self, advice: str, task_description: str, cnt_critic_agents: int ) -> RoleAssignerMessage: + pass + + async def astep( + self, advice: str, task_description: str, cnt_critic_agents: int + ) -> RoleAssignerMessage: + """Asynchronous version of step""" logger.debug("", self.name, Fore.MAGENTA) prepend_prompt, append_prompt, prompt_token = self.get_all_prompts( advice=advice, @@ -44,13 +50,13 @@ def step( max_send_token -= prompt_token - history = self.memory.to_messages( + history = await self.memory.to_messages( self.name, max_send_token=max_send_token, model=model_name ) parsed_response = None for i in range(self.max_retry): try: - response = self.llm.generate_response( + response = await self.llm.agenerate_response( prepend_prompt, history, append_prompt ) parsed_response = self.output_parser.parse(response) @@ -76,10 +82,6 @@ def step( ) return message - async def astep(self, env_description: str = "") -> RoleAssignerMessage: - """Asynchronous version of step""" - pass - def _fill_prompt_template( self, advice, task_description: str, cnt_critic_agents: int ) -> str: diff --git a/agentverse/agents/tasksolving_agent/solver.py b/agentverse/agents/tasksolving_agent/solver.py index d65e397ae..dd95ea55f 100644 --- a/agentverse/agents/tasksolving_agent/solver.py +++ b/agentverse/agents/tasksolving_agent/solver.py @@ -25,6 +25,12 @@ class SolverAgent(BaseAgent): def step( self, former_solution: str, advice: str, task_description: str = "", **kwargs ) -> SolverMessage: + pass + + async def astep( + self, former_solution: str, advice: str, task_description: str = "", **kwargs + ) -> SolverMessage: + """Asynchronous version of step""" logger.debug("", self.name, Fore.MAGENTA) # prompt = self._fill_prompt_template( # former_solution, critic_opinions, advice, task_description @@ -52,7 +58,7 @@ def step( max_send_token -= prompt_token - history = self.memory.to_messages( + history = await self.memory.to_messages( self.name, start_index=-self.max_history, max_send_token=max_send_token, @@ -61,7 +67,7 @@ def step( parsed_response = None for i in range(self.max_retry): try: - response = self.llm.generate_response( + response = await self.llm.agenerate_response( prepend_prompt, history, append_prompt ) parsed_response = self.output_parser.parse(response) @@ -85,10 +91,6 @@ def step( ) return message - async def astep(self, env_description: str = "") -> SolverMessage: - """Asynchronous version of step""" - pass - def _fill_prompt_template( self, former_solution: str, diff --git a/agentverse/environments/tasksolving_env/basic.py b/agentverse/environments/tasksolving_env/basic.py index 8e4631a24..457430311 100644 --- a/agentverse/environments/tasksolving_env/basic.py +++ b/agentverse/environments/tasksolving_env/basic.py @@ -51,7 +51,7 @@ async def step( logger.info(f"Loop Round {self.cnt_turn}") # ================== EXPERT RECRUITMENT ================== - agents = self.rule.role_assign( + agents = await self.rule.role_assign( self.task_description, self.agents, self.cnt_turn, advice ) description = "\n".join([agent.role_description for agent in agents]) @@ -79,7 +79,7 @@ async def step( # ================== EXECUTION ================== # ================== EVALUATION ================== - score, advice = self.rule.evaluate( + score, advice = await self.rule.evaluate( self.task_description, self.agents, plan, result ) logs.append( diff --git a/agentverse/environments/tasksolving_env/rules/base.py b/agentverse/environments/tasksolving_env/rules/base.py index 2e25c017c..a1fcb0dc9 100644 --- a/agentverse/environments/tasksolving_env/rules/base.py +++ b/agentverse/environments/tasksolving_env/rules/base.py @@ -68,7 +68,7 @@ def build_components(config: Dict, registry): **kwargs, ) - def role_assign( + async def role_assign( self, task_description: str, agents: List[BaseAgent], @@ -79,7 +79,7 @@ def role_assign( if self.role_assign_only_once and cnt_turn > 0: agents = [agents[AGENT_TYPES.SOLVER]] + agents[AGENT_TYPES.CRITIC] else: - agents = self.role_assigner.step( + agents = await self.role_assigner.astep( role_assigner=agents[AGENT_TYPES.ROLE_ASSIGNMENT], group_members=[agents[AGENT_TYPES.SOLVER]] + agents[AGENT_TYPES.CRITIC], advice=advice, @@ -137,7 +137,7 @@ async def execute( agents[AGENT_TYPES.SOLVER].add_message_to_memory(results) return results - def evaluate( + async def evaluate( self, task_description: str, agents: List[BaseAgent], @@ -162,7 +162,7 @@ def evaluate( # logger.error("Bad response from human evaluator!") # return ([comprehensiveness, detailedness, feasibility, novelty], advice) # else: - evaluation = self.evaluator.step( + evaluation = await self.evaluator.astep( agent=agents[AGENT_TYPES.EVALUATION], solution=solution, result=result, diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/brainstorming.py b/agentverse/environments/tasksolving_env/rules/decision_maker/brainstorming.py index a6db1a5f6..39bd35ae9 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/brainstorming.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/brainstorming.py @@ -53,7 +53,7 @@ async def astep( Fore.YELLOW, ) - result = agents[0].step(previous_plan, advice, task_description) + result = await agents[0].astep(previous_plan, advice, task_description) for agent in agents: agent.memory.reset() self.broadcast_messages( diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/central.py b/agentverse/environments/tasksolving_env/rules/decision_maker/central.py index 5d7bf5702..1d682e3da 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/central.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/central.py @@ -47,7 +47,7 @@ async def astep( ), ) agents[1].add_message_to_memory([result]) - result = agents[0].step( + result = await agents[0].astep( previous_plan, advice, task_description, chat_record=result.content ) return [result] diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/concurrent.py b/agentverse/environments/tasksolving_env/rules/decision_maker/concurrent.py index cc34e00b7..f3979b851 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/concurrent.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/concurrent.py @@ -59,7 +59,7 @@ async def astep( last_reviews = nonempty_reviews agents[0].add_message_to_memory(last_reviews) - result = agents[0].step(previous_plan, advice, task_description) + result = await agents[0].astep(previous_plan, advice, task_description) # agents[0].add_message_to_memory([result]) self.broadcast_messages(agents, [result]) return [result] diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/dynamic.py b/agentverse/environments/tasksolving_env/rules/decision_maker/dynamic.py index d6b6d72fe..c1fccf923 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/dynamic.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/dynamic.py @@ -56,7 +56,7 @@ async def astep( # Fore.YELLOW, # ) - previous_sentence = manager.step( + previous_sentence = await manager.astep( previous_plan, review, advice, task_description, previous_sentence ) reviews.append(previous_sentence) @@ -76,7 +76,7 @@ async def astep( nonempty_reviews.append(review) agents[0].add_message_to_memory(nonempty_reviews) - result = agents[0].step(previous_plan, advice, task_description) + result = await agents[0].astep(previous_plan, advice, task_description) return [result] diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal.py b/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal.py index b2a8c5703..ef9f43108 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal.py @@ -50,7 +50,7 @@ async def astep( Fore.YELLOW, ) - result = agents[0].step(previous_plan, advice, task_description) + result = await agents[0].astep(previous_plan, advice, task_description) return [result] def reset(self): diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal_tool.py b/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal_tool.py index 5cea85eab..b30e88046 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal_tool.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/horizontal_tool.py @@ -77,7 +77,9 @@ async def astep( if end_flag: break - result: SolverMessage = agents[0].step(previous_plan, advice, task_description) + result: SolverMessage = await agents[0].astep( + previous_plan, advice, task_description + ) result_list = [] for res in result.content: res_tmp = deepcopy(result) diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/vertical.py b/agentverse/environments/tasksolving_env/rules/decision_maker/vertical.py index d8adf594d..0bcc4688f 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/vertical.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/vertical.py @@ -50,7 +50,7 @@ async def astep( if not review.is_agree and review.content != "": nonempty_reviews.append(review) agents[0].add_message_to_memory(nonempty_reviews) - result = agents[0].step(previous_plan, advice, task_description) + result = await agents[0].astep(previous_plan, advice, task_description) agents[0].add_message_to_memory([result]) return [result] diff --git a/agentverse/environments/tasksolving_env/rules/decision_maker/vertical_solver_first.py b/agentverse/environments/tasksolving_env/rules/decision_maker/vertical_solver_first.py index 97114f455..c3c38af8e 100644 --- a/agentverse/environments/tasksolving_env/rules/decision_maker/vertical_solver_first.py +++ b/agentverse/environments/tasksolving_env/rules/decision_maker/vertical_solver_first.py @@ -38,7 +38,7 @@ async def astep( self.broadcast_messages( agents, [Message(content=advice, sender="Evaluator")] ) - previous_plan = agents[0].step(previous_plan, advice, task_description) + previous_plan = await agents[0].astep(previous_plan, advice, task_description) self.broadcast_messages(agents, [previous_plan]) logger.info("", f"Initial Plan:\n{previous_plan.content}", Fore.BLUE) for i in range(self.max_inner_turns): @@ -65,7 +65,7 @@ async def astep( logger.info("", "Consensus Reached!.", Fore.GREEN) break self.broadcast_messages(agents, nonempty_reviews) - previous_plan = agents[0].step(previous_plan, advice, task_description) + previous_plan = await agents[0].astep(previous_plan, advice, task_description) logger.info("", f"Updated Plan:\n{previous_plan.content}", Fore.BLUE) self.broadcast_messages(agents, [previous_plan]) result = previous_plan diff --git a/agentverse/environments/tasksolving_env/rules/evaluator/base.py b/agentverse/environments/tasksolving_env/rules/evaluator/base.py index f3d72ad98..de83528c1 100644 --- a/agentverse/environments/tasksolving_env/rules/evaluator/base.py +++ b/agentverse/environments/tasksolving_env/rules/evaluator/base.py @@ -20,7 +20,7 @@ class BaseEvaluator(BaseModel): """ @abstractmethod - def step( + def astep( self, agent: EvaluatorAgent, solution: List[SolverMessage], diff --git a/agentverse/environments/tasksolving_env/rules/evaluator/basic.py b/agentverse/environments/tasksolving_env/rules/evaluator/basic.py index a7738fe21..847234f03 100644 --- a/agentverse/environments/tasksolving_env/rules/evaluator/basic.py +++ b/agentverse/environments/tasksolving_env/rules/evaluator/basic.py @@ -14,7 +14,7 @@ class BasicEvaluator(BaseEvaluator): cnt_agents: int = 0 - def step( + async def astep( self, agent: EvaluatorAgent, solution: List[SolverMessage], @@ -27,7 +27,7 @@ def step( flatten_solution = "\n".join([s.content for s in solution]) flatten_result = "\n".join([r.content for r in result]) flatten_all_role_description = "\n".join(all_role_description) - evaluation = agent.step( + evaluation = await agent.astep( flatten_solution, flatten_result, task_description, @@ -40,7 +40,7 @@ def step( class BasicEvaluator(BaseEvaluator): cnt_agents: int = 0 - def step( + async def astep( self, agent: EvaluatorAgent, solution: List[SolverMessage], @@ -54,7 +54,7 @@ def step( flatten_result = "\n".join([r.content for r in result]) flatten_all_role_description = "\n".join(all_role_description) agent.add_message_to_memory(result) - evaluation = agent.step( + evaluation = await agent.astep( flatten_solution, flatten_result, task_description, diff --git a/agentverse/environments/tasksolving_env/rules/executor/code_test.py b/agentverse/environments/tasksolving_env/rules/executor/code_test.py index 121aabc67..1a60b720e 100644 --- a/agentverse/environments/tasksolving_env/rules/executor/code_test.py +++ b/agentverse/environments/tasksolving_env/rules/executor/code_test.py @@ -71,35 +71,7 @@ def step( *args, **kwargs, ) -> Any: - solution = solution[0].content - os.makedirs("tmp", exist_ok=True) - self.write_to_file("tmp/main.py", solution) - manager = multiprocessing.Manager() - result = manager.list() - if task_description not in self.has_test: - response = agent.step(task_description, solution).content - self.write_to_file(response["file_path"], response["code"]) - self.has_test[task_description] = f"python {response['file_path']}" - p = multiprocessing.Process( - target=execute_command, args=(f"python {response['file_path']}", result) - ) - p.start() - p.join(timeout=self.timeout + 1) - if p.is_alive(): - p.kill() - # result = execute_command(f"python {response['file_path']}") - else: - # result = execute_command(self.has_test[task_description]) - p = multiprocessing.Process( - target=execute_command, args=(self.has_test[task_description], result) - ) - p.start() - p.join(timeout=self.timeout + 1) - if p.is_alive(): - p.kill() - if not result: - result.append("Execution timed out.") - return [ExecutorMessage(content=result[0], sender="Code Tester")] + pass def write_to_file(self, file_name, file_content): # TODO: generalize this method to a common tool diff --git a/agentverse/environments/tasksolving_env/rules/executor/tool_using.py b/agentverse/environments/tasksolving_env/rules/executor/tool_using.py index 5f177fc1c..62bc30429 100644 --- a/agentverse/environments/tasksolving_env/rules/executor/tool_using.py +++ b/agentverse/environments/tasksolving_env/rules/executor/tool_using.py @@ -258,8 +258,8 @@ async def _summarize_webpage(webpage, question): ], function_call={"name": "parse_web_text"}, ) - except e: - logger.error(e) + except Exception as e: + logger.error("Failed to call the tool. Exception: " + str(e)) continue arguments = ast.literal_eval( JsonRepair( diff --git a/agentverse/environments/tasksolving_env/rules/role_assigner/base.py b/agentverse/environments/tasksolving_env/rules/role_assigner/base.py index 726abf52a..22a3127c7 100644 --- a/agentverse/environments/tasksolving_env/rules/role_assigner/base.py +++ b/agentverse/environments/tasksolving_env/rules/role_assigner/base.py @@ -19,7 +19,7 @@ class BaseRoleAssigner(BaseModel): """ @abstractmethod - def step( + def astep( self, role_assigner: RoleAssignerAgent, group_members: List[CriticAgent], @@ -40,7 +40,7 @@ class DummyRoleAssigner(BaseRoleAssigner): The base class of role assignment class. """ - def step( + def astep( self, role_assigner: RoleAssignerAgent, group_members: List[CriticAgent], diff --git a/agentverse/environments/tasksolving_env/rules/role_assigner/role_description.py b/agentverse/environments/tasksolving_env/rules/role_assigner/role_description.py index 1d7490c83..b28b64408 100644 --- a/agentverse/environments/tasksolving_env/rules/role_assigner/role_description.py +++ b/agentverse/environments/tasksolving_env/rules/role_assigner/role_description.py @@ -16,7 +16,7 @@ class DescriptionAssigner(BaseRoleAssigner): Generates descriptions for each agent. """ - def step( + async def astep( self, role_assigner: RoleAssignerAgent, group_members: List[CriticAgent], @@ -28,7 +28,7 @@ def step( assert task_description != "" assert len(group_members) > 0 - roles = role_assigner.step(advice, task_description, len(group_members)) + roles = await role_assigner.astep(advice, task_description, len(group_members)) if len(roles.content) < len(group_members): raise ValueError( f"Number of roles ({len(roles.content)}) and number of group members ({len(group_members)}) do not match." @@ -50,7 +50,7 @@ class DescriptionNameAssigner(BaseRoleAssigner): Generates description and name for each agent. """ - def step( + async def astep( self, role_assigner: RoleAssignerAgent, group_members: List[CriticAgent], @@ -63,7 +63,7 @@ def step( assert len(group_members) > 0 # roles: [{'name': 'xxx', 'description': 'xxx'}, ...] - roles = role_assigner.step(advice, task_description, len(group_members)) + roles = await role_assigner.astep(advice, task_description, len(group_members)) if len(group_members) < 2: pass diff --git a/agentverse/memory/chat_history.py b/agentverse/memory/chat_history.py index f05624bf0..f2188c065 100644 --- a/agentverse/memory/chat_history.py +++ b/agentverse/memory/chat_history.py @@ -78,7 +78,7 @@ def to_string(self, add_sender_prefix: bool = False) -> str: else: return "\n".join([message.content for message in self.messages]) - def to_messages( + async def to_messages( self, my_name: str = "", start_index: int = 0, @@ -141,7 +141,9 @@ def to_messages( prompt, messages, max_send_token, model ) if trimmed_history: - new_summary_msg, _ = self.trim_messages(list(prompt), model, messages) + new_summary_msg, _ = await self.trim_messages( + list(prompt), model, messages + ) prompt.append(new_summary_msg) messages = prompt return messages @@ -149,7 +151,7 @@ def to_messages( def reset(self) -> None: self.messages = [] - def trim_messages( + async def trim_messages( self, current_message_chain: list[dict], model: str, history: List[dict] ) -> tuple[dict, list[dict]]: new_messages_not_in_chain = [ @@ -159,7 +161,7 @@ def trim_messages( if not new_messages_not_in_chain: return self.summary_message(), [] - new_summary_message = self.update_running_summary( + new_summary_message = await self.update_running_summary( new_events=new_messages_not_in_chain, model=model ) @@ -168,7 +170,7 @@ def trim_messages( return new_summary_message, new_messages_not_in_chain - def update_running_summary( + async def update_running_summary( self, new_events: list[Message], model: str = "gpt-3.5-turbo", @@ -217,7 +219,7 @@ def update_running_summary( batch_tlength + event_tlength > max_input_tokens - prompt_template_length - summary_tlength ): - self._update_summary_with_batch(batch, model, max_summary_length) + await self._update_summary_with_batch(batch, model, max_summary_length) summary_tlength = count_string_tokens(self.summary, model) batch = [event] batch_tlength = event_tlength @@ -226,18 +228,18 @@ def update_running_summary( batch_tlength += event_tlength if batch: - self._update_summary_with_batch(batch, model, max_summary_length) + await self._update_summary_with_batch(batch, model, max_summary_length) return self.summary_message() - def _update_summary_with_batch( + async def _update_summary_with_batch( self, new_events_batch: list[dict], model: str, max_summary_length: int ) -> None: prompt = self.SUMMARIZATION_PROMPT.format( summary=self.summary, new_events=new_events_batch ) - self.summary = openai.ChatCompletion.create( + self.summary = await openai.ChatCompletion.acreate( messages=[{"role": "user", "content": prompt}], model=model, max_tokens=max_summary_length,