From d9138419e545bbf9103b81e7bf0b1ab0161542d0 Mon Sep 17 00:00:00 2001 From: lucifertrj Date: Tue, 3 Dec 2024 18:44:41 +0530 Subject: [PATCH] reduce max_iteration, fix json extract and introduce pooling for workers. --- src/openagi/prompts/worker_task_execution.py | 15 +-- src/openagi/utils/extraction.py | 96 +++++++------------- src/openagi/worker.py | 63 ++++++++----- 3 files changed, 79 insertions(+), 95 deletions(-) diff --git a/src/openagi/prompts/worker_task_execution.py b/src/openagi/prompts/worker_task_execution.py index ee495e5..440d5d7 100644 --- a/src/openagi/prompts/worker_task_execution.py +++ b/src/openagi/prompts/worker_task_execution.py @@ -1,10 +1,7 @@ -from textwrap import dedent - from openagi.prompts.base import BasePrompt -WORKER_TASK_EXECUTION = dedent( - """ -You: {worker_description} +WORKER_TASK_EXECUTION = """ +You are expert in: {worker_description} # Instructions - You run in a loop of Thought, Action, Observation. Follow the instructions below to understand the workflow and follow them in each iteration of the loop. @@ -13,9 +10,9 @@ - Observation will be the result of running those actions. Make sure to thoroughly analyze the observation to see if it aligns with your expectations. - On each observation, try to understand the drawbacks and mistakes and learn from them to improve further and get back on track. - Take the context into account when you are answering the question. It will be the results or data from the past executions. If no context is provided, then you can assume that the context is empty and you can start from scratch. Use context to ensure consistency and accuracy in your responses. -- Output the answer when you feel the observations are correct and aligned with the goal. They do not have to be very accurate, but ensure they are reasonably reliable. -- The output should always be in the following format in all the iterations. Ensure the JSON format is suitable for utilization with json.loads(), enclosed in triple backticks: +- Output the answer when you feel the observations are reasonably good and aligned with the goal. They do not have to be very accurate, but ensure they are reasonably reliable. - No Action/Output should be without json. Trying not include your thoughts as part of the action. You can skip the action if not required. +- The output needs to be in JSON ONLY: - For Running an action: ```json { @@ -72,8 +69,6 @@ Begin! {thought_provokes} """.strip() -) - class WorkerAgentTaskExecution(BasePrompt): - base_prompt: str = WORKER_TASK_EXECUTION + base_prompt: str = WORKER_TASK_EXECUTION \ No newline at end of file diff --git a/src/openagi/utils/extraction.py b/src/openagi/utils/extraction.py index 2bee55e..bad4262 100644 --- a/src/openagi/utils/extraction.py +++ b/src/openagi/utils/extraction.py @@ -2,47 +2,27 @@ import json import logging import re -from textwrap import dedent from typing import Dict, List, Optional, Tuple from openagi.exception import OpenAGIException from openagi.llms.base import LLMBaseModel -def force_json_output(resp_txt: str, llm): +def force_json_output(resp_txt: str, llm) -> str: """ - Forces the output once the max iterations are reached. + Forces proper JSON output format in first attempt. """ - #prompt = dedent( - # """ - # Below is a JSON block. Please try to provide the output in the format shown below only - # ```json - # {"key": "value"} - # ``` - # the contents between ```json and ``` will be extracted and passed to json.loads() in python to convert it to a dictionary. Make sure that it works when passed else you will be fined. If its already in the correct format, then you can return the same output in the expected output format. - # Input: - # {resp_txt} - # Output: - # """ - #).strip() - - prompt = dedent( - """ - Your task is to process the input JSON and provide a valid JSON output. Follow these instructions carefully: - 1. The output must be enclosed in a code block using triple backticks and the 'json' language identifier, like this: - ```json - {"key": "value"} - ``` - 2. The JSON inside the code block must be valid and parseable by Python's json.loads() function. - 3. Ensure there are no extra spaces, newlines, or characters outside the JSON object within the code block. - 4. If the input is already in the correct format, reproduce it exactly in the output format specified above. - 5. Do not include any explanations, comments, or additional text in your response. The output needs be in JSON only. - 6. Verify your output carefully before submitting. Incorrect responses will result in penalties. + prompt = """ + You are a JSON formatting expert. Your task is to process the input and provide a valid JSON output. - Input: {resp_txt} - Output: - """ - ).strip() + FOLLOW THESE INSTRUCTIONS to convert: + - Output must be ONLY a JSON object wrapped in ```json code block + - Do not include any explanations, comments, or additional text in your response. The output needs be in JSON only. + + Convert this INPUT to proper JSON: + INPUT: {resp_txt} + Output only the JSON: + """.strip() prompt = prompt.replace("{resp_txt}", resp_txt) return llm.run(prompt) @@ -52,47 +32,39 @@ def get_last_json( text: str, llm: Optional[LLMBaseModel] = None, max_iterations: int = 5 ) -> Optional[Dict]: """ - Extracts the last block of text between ```json and ``` markers from a given string. - - Args: - text (str): The string from which to extract the JSON block. - llm (Optional[LLMBaseModel]): The language model instance to use for reformatting. - max_iterations (int): Maximum number of iterations to try reformatting. - - Returns: - dict or None: The last JSON block as a dictionary if found and parsed, otherwise None. + Extracts valid JSON from text with improved reliability. """ - pattern = r"```json(.*?)```" - matches = re.findall(pattern, text, flags=re.DOTALL) + # More precise JSON block pattern + pattern = r"```json\s*(\{[\s\S]*?\})\s*```" + matches = re.findall(pattern, text, re.MULTILINE) + if matches: - last_json = matches[-1].strip().replace("\n", "") try: + last_json = matches[-1].strip() + last_json = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', last_json) + last_json = re.sub(r'\s+', ' ', last_json) return json.loads(last_json) - except json.JSONDecodeError: - logging.error("JSON not extracted. Trying again.", exc_info=True) - pass - + except json.JSONDecodeError as e: + logging.error(f"JSON parsing failed: {str(e)}", exc_info=True) + if llm: + text = force_json_output(last_json, llm) + return get_last_json(text, None, max_iterations) + if llm: for iteration in range(1, max_iterations + 1): - logging.info(f"Iteration {iteration} to extract JSON from LLM output.") try: text = force_json_output(text, llm) - matches = re.findall(pattern, text, flags=re.DOTALL) - if matches: - last_json = matches[-1].strip().replace("\n", "") - json_resp = json.loads(last_json) - logging.info("JSON extracted successfully.") - return json_resp - except json.JSONDecodeError: - logging.error("JSON not extracted. Trying again.", exc_info=True) - continue - if iteration == max_iterations: - raise OpenAGIException( - "The last output is not a valid JSON. Please check the output of the last action." - ) + return get_last_json(text, None, max_iterations) + except Exception as e: + logging.error(f"Attempt {iteration} failed: {str(e)}", exc_info=True) + if iteration == max_iterations: + raise OpenAGIException( + f"Failed to extract valid JSON after {max_iterations} attempts. Last error: {str(e)}" + ) return None + def get_act_classes_from_json(json_data) -> List[Tuple[str, Optional[Dict]]]: """ Extracts the Action class names and parameters from a JSON block. diff --git a/src/openagi/worker.py b/src/openagi/worker.py index f226856..e89b0a5 100644 --- a/src/openagi/worker.py +++ b/src/openagi/worker.py @@ -1,7 +1,8 @@ +import functools +from concurrent.futures import ThreadPoolExecutor import logging from pathlib import Path import re -from textwrap import dedent from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -35,7 +36,7 @@ class Worker(BaseModel): default_factory=list, ) max_iterations: int = Field( - default=20, + default=10, description="Maximum number of steps to achieve the objective.", ) output_key: str = Field( @@ -46,7 +47,7 @@ class Worker(BaseModel): default=True, description="If set to True, the output will be overwritten even if it exists.", ) - + # Validate output_key. Should contain only alphabets and only underscore are allowed. Not alphanumeric @field_validator("output_key") @classmethod @@ -70,7 +71,7 @@ def worker_doc(self): } def provoke_thought_obs(self, observation): - thoughts = dedent(f"""Observation: {observation}""".strip()) + thoughts = f"""Observation: {observation}""".strip() return thoughts def should_continue(self, llm_resp: str) -> Union[bool, Optional[Dict]]: @@ -84,7 +85,7 @@ def _force_output( """Force the output once the max iterations are reached.""" prompt = ( "\n".join(all_thoughts_and_obs) - + "Based on the previous action and observation, give me the output." + + "Based on the previous action and observation, force and give me the output." ) output = self.llm.run(prompt) cont, final_output = self.should_continue(output) @@ -101,43 +102,53 @@ def _force_output( ) return (cont, final_output) + @functools.lru_cache(maxsize=100) + def _cached_llm_run(self, prompt: str) -> str: + """Cache LLM responses for identical prompts""" + return self.llm.run(prompt) + def save_to_memory(self, task: Task): - """Saves the output to the memory.""" - return self.memory.update_task(task) + """Optimized memory update""" + if not hasattr(self, '_memory_buffer'): + self._memory_buffer = [] + self._memory_buffer.append(task) + + # Batch update memory when buffer reaches certain size + if len(self._memory_buffer) >= 5: + for buffered_task in self._memory_buffer: + self.memory.update_task(buffered_task) + self._memory_buffer.clear() + return True def execute_task(self, task: Task, context: Any = None) -> Any: - """Executes the specified task.""" - logging.info( - f"{'>'*20} Executing Task - {task.name}[{task.id}] with worker - {self.role}[{self.id}] {'<'*20}" - ) + """Optimized task execution""" + logging.info(f"{'>'*20} Executing Task - {task.name}[{task.id}] with worker - {self.role}[{self.id}] {'<'*20}") + + # Pre-compute common values iteration = 1 task_to_execute = f"{task.description}" worker_description = f"{self.role} - {self.instructions}" all_thoughts_and_obs = [] - - logging.debug("Provoking initial thought observation...") - initial_thought_provokes = self.provoke_thought_obs(None) + + # Generate base prompt once te_vars = dict( task_to_execute=task_to_execute, worker_description=worker_description, supported_actions=[action.cls_doc() for action in self.actions], - thought_provokes=initial_thought_provokes, + thought_provokes=self.provoke_thought_obs(None), output_key=self.output_key, context=context, max_iterations=self.max_iterations, ) - - logging.debug("Generating base prompt...") base_prompt = WorkerAgentTaskExecution().from_template(te_vars) + + # Use cached LLM run prompt = f"{base_prompt}\nThought:\nIteration: {iteration}\nActions:\n" - - logging.debug("Running LLM with prompt...") - observations = self.llm.run(prompt) - logging.info(f"LLM execution completed. Observations: {observations}") + observations = self._cached_llm_run(prompt) all_thoughts_and_obs.append(prompt) - max_iters = self.max_iterations + 1 - while iteration < max_iters: + while iteration < self.max_iterations + 1: + logging.info(f"---- Iteration {iteration} ----") logging.debug("Checking if task should continue...") continue_flag, output = self.should_continue(observations) @@ -210,6 +221,7 @@ def execute_task(self, task: Task, context: Any = None) -> Any: prompt = f"{base_prompt}\n" + "\n".join(all_thoughts_and_obs) logging.debug(f"\nSTART:{'*' * 20}\n{prompt}\n{'*' * 20}:END") pth = Path(f"{self.memory.session_id}/logs/{task.name}-{iteration}.log") + pth.parent.mkdir(parents=True, exist_ok=True) with open(pth, "w", encoding="utf-8") as f: f.write(f"{prompt}\n") @@ -240,3 +252,8 @@ def execute_task(self, task: Task, context: Any = None) -> Any: f"Task Execution Completed - {task.name} with worker - {self.role}[{self.id}] in {iteration} iterations" ) return output, task + + def __del__(self): + """Cleanup thread pool on deletion""" + if hasattr(self, '_thread_pool'): + self._thread_pool.shutdown(wait=False)