Skip to content

Commit

Permalink
reduce max_iteration, fix json extract and introduce pooling for work…
Browse files Browse the repository at this point in the history
…ers.
  • Loading branch information
lucifertrj committed Dec 3, 2024
1 parent 300a2d3 commit d913841
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 95 deletions.
15 changes: 5 additions & 10 deletions src/openagi/prompts/worker_task_execution.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from textwrap import dedent

from openagi.prompts.base import BasePrompt

WORKER_TASK_EXECUTION = dedent(
"""
You: {worker_description}
WORKER_TASK_EXECUTION = """
You are expert in: {worker_description}
# Instructions
- You run in a loop of Thought, Action, Observation. Follow the instructions below to understand the workflow and follow them in each iteration of the loop.
Expand All @@ -13,9 +10,9 @@
- Observation will be the result of running those actions. Make sure to thoroughly analyze the observation to see if it aligns with your expectations.
- On each observation, try to understand the drawbacks and mistakes and learn from them to improve further and get back on track.
- Take the context into account when you are answering the question. It will be the results or data from the past executions. If no context is provided, then you can assume that the context is empty and you can start from scratch. Use context to ensure consistency and accuracy in your responses.
- Output the answer when you feel the observations are correct and aligned with the goal. They do not have to be very accurate, but ensure they are reasonably reliable.
- The output should always be in the following format in all the iterations. Ensure the JSON format is suitable for utilization with json.loads(), enclosed in triple backticks:
- Output the answer when you feel the observations are reasonably good and aligned with the goal. They do not have to be very accurate, but ensure they are reasonably reliable.
- No Action/Output should be without json. Trying not include your thoughts as part of the action. You can skip the action if not required.
- The output needs to be in JSON ONLY:
- For Running an action:
```json
{
Expand Down Expand Up @@ -72,8 +69,6 @@
Begin!
{thought_provokes}
""".strip()
)


class WorkerAgentTaskExecution(BasePrompt):
base_prompt: str = WORKER_TASK_EXECUTION
base_prompt: str = WORKER_TASK_EXECUTION
96 changes: 34 additions & 62 deletions src/openagi/utils/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,27 @@
import json
import logging
import re
from textwrap import dedent
from typing import Dict, List, Optional, Tuple

from openagi.exception import OpenAGIException
from openagi.llms.base import LLMBaseModel


def force_json_output(resp_txt: str, llm):
def force_json_output(resp_txt: str, llm) -> str:
"""
Forces the output once the max iterations are reached.
Forces proper JSON output format in first attempt.
"""
#prompt = dedent(
# """
# Below is a JSON block. Please try to provide the output in the format shown below only
# ```json
# {"key": "value"}
# ```
# the contents between ```json and ``` will be extracted and passed to json.loads() in python to convert it to a dictionary. Make sure that it works when passed else you will be fined. If its already in the correct format, then you can return the same output in the expected output format.
# Input:
# {resp_txt}
# Output:
# """
#).strip()

prompt = dedent(
"""
Your task is to process the input JSON and provide a valid JSON output. Follow these instructions carefully:
1. The output must be enclosed in a code block using triple backticks and the 'json' language identifier, like this:
```json
{"key": "value"}
```
2. The JSON inside the code block must be valid and parseable by Python's json.loads() function.
3. Ensure there are no extra spaces, newlines, or characters outside the JSON object within the code block.
4. If the input is already in the correct format, reproduce it exactly in the output format specified above.
5. Do not include any explanations, comments, or additional text in your response. The output needs be in JSON only.
6. Verify your output carefully before submitting. Incorrect responses will result in penalties.
prompt = """
You are a JSON formatting expert. Your task is to process the input and provide a valid JSON output.
Input: {resp_txt}
Output:
"""
).strip()
FOLLOW THESE INSTRUCTIONS to convert:
- Output must be ONLY a JSON object wrapped in ```json code block
- Do not include any explanations, comments, or additional text in your response. The output needs be in JSON only.
Convert this INPUT to proper JSON:
INPUT: {resp_txt}
Output only the JSON:
""".strip()

prompt = prompt.replace("{resp_txt}", resp_txt)
return llm.run(prompt)
Expand All @@ -52,47 +32,39 @@ def get_last_json(
text: str, llm: Optional[LLMBaseModel] = None, max_iterations: int = 5
) -> Optional[Dict]:
"""
Extracts the last block of text between ```json and ``` markers from a given string.
Args:
text (str): The string from which to extract the JSON block.
llm (Optional[LLMBaseModel]): The language model instance to use for reformatting.
max_iterations (int): Maximum number of iterations to try reformatting.
Returns:
dict or None: The last JSON block as a dictionary if found and parsed, otherwise None.
Extracts valid JSON from text with improved reliability.
"""
pattern = r"```json(.*?)```"
matches = re.findall(pattern, text, flags=re.DOTALL)
# More precise JSON block pattern
pattern = r"```json\s*(\{[\s\S]*?\})\s*```"
matches = re.findall(pattern, text, re.MULTILINE)

if matches:
last_json = matches[-1].strip().replace("\n", "")
try:
last_json = matches[-1].strip()
last_json = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', last_json)
last_json = re.sub(r'\s+', ' ', last_json)
return json.loads(last_json)
except json.JSONDecodeError:
logging.error("JSON not extracted. Trying again.", exc_info=True)
pass

except json.JSONDecodeError as e:
logging.error(f"JSON parsing failed: {str(e)}", exc_info=True)
if llm:
text = force_json_output(last_json, llm)
return get_last_json(text, None, max_iterations)

if llm:
for iteration in range(1, max_iterations + 1):
logging.info(f"Iteration {iteration} to extract JSON from LLM output.")
try:
text = force_json_output(text, llm)
matches = re.findall(pattern, text, flags=re.DOTALL)
if matches:
last_json = matches[-1].strip().replace("\n", "")
json_resp = json.loads(last_json)
logging.info("JSON extracted successfully.")
return json_resp
except json.JSONDecodeError:
logging.error("JSON not extracted. Trying again.", exc_info=True)
continue
if iteration == max_iterations:
raise OpenAGIException(
"The last output is not a valid JSON. Please check the output of the last action."
)
return get_last_json(text, None, max_iterations)
except Exception as e:
logging.error(f"Attempt {iteration} failed: {str(e)}", exc_info=True)
if iteration == max_iterations:
raise OpenAGIException(
f"Failed to extract valid JSON after {max_iterations} attempts. Last error: {str(e)}"
)
return None



def get_act_classes_from_json(json_data) -> List[Tuple[str, Optional[Dict]]]:
"""
Extracts the Action class names and parameters from a JSON block.
Expand Down
63 changes: 40 additions & 23 deletions src/openagi/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
from concurrent.futures import ThreadPoolExecutor
import logging
from pathlib import Path
import re
from textwrap import dedent
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator
Expand Down Expand Up @@ -35,7 +36,7 @@ class Worker(BaseModel):
default_factory=list,
)
max_iterations: int = Field(
default=20,
default=10,
description="Maximum number of steps to achieve the objective.",
)
output_key: str = Field(
Expand All @@ -46,7 +47,7 @@ class Worker(BaseModel):
default=True,
description="If set to True, the output will be overwritten even if it exists.",
)

# Validate output_key. Should contain only alphabets and only underscore are allowed. Not alphanumeric
@field_validator("output_key")
@classmethod
Expand All @@ -70,7 +71,7 @@ def worker_doc(self):
}

def provoke_thought_obs(self, observation):
thoughts = dedent(f"""Observation: {observation}""".strip())
thoughts = f"""Observation: {observation}""".strip()
return thoughts

def should_continue(self, llm_resp: str) -> Union[bool, Optional[Dict]]:
Expand All @@ -84,7 +85,7 @@ def _force_output(
"""Force the output once the max iterations are reached."""
prompt = (
"\n".join(all_thoughts_and_obs)
+ "Based on the previous action and observation, give me the output."
+ "Based on the previous action and observation, force and give me the output."
)
output = self.llm.run(prompt)
cont, final_output = self.should_continue(output)
Expand All @@ -101,43 +102,53 @@ def _force_output(
)
return (cont, final_output)

@functools.lru_cache(maxsize=100)
def _cached_llm_run(self, prompt: str) -> str:
"""Cache LLM responses for identical prompts"""
return self.llm.run(prompt)

def save_to_memory(self, task: Task):
"""Saves the output to the memory."""
return self.memory.update_task(task)
"""Optimized memory update"""
if not hasattr(self, '_memory_buffer'):
self._memory_buffer = []
self._memory_buffer.append(task)

# Batch update memory when buffer reaches certain size
if len(self._memory_buffer) >= 5:
for buffered_task in self._memory_buffer:
self.memory.update_task(buffered_task)
self._memory_buffer.clear()
return True

def execute_task(self, task: Task, context: Any = None) -> Any:
"""Executes the specified task."""
logging.info(
f"{'>'*20} Executing Task - {task.name}[{task.id}] with worker - {self.role}[{self.id}] {'<'*20}"
)
"""Optimized task execution"""
logging.info(f"{'>'*20} Executing Task - {task.name}[{task.id}] with worker - {self.role}[{self.id}] {'<'*20}")

# Pre-compute common values
iteration = 1
task_to_execute = f"{task.description}"
worker_description = f"{self.role} - {self.instructions}"
all_thoughts_and_obs = []

logging.debug("Provoking initial thought observation...")
initial_thought_provokes = self.provoke_thought_obs(None)

# Generate base prompt once
te_vars = dict(
task_to_execute=task_to_execute,
worker_description=worker_description,
supported_actions=[action.cls_doc() for action in self.actions],
thought_provokes=initial_thought_provokes,
thought_provokes=self.provoke_thought_obs(None),
output_key=self.output_key,
context=context,
max_iterations=self.max_iterations,
)

logging.debug("Generating base prompt...")
base_prompt = WorkerAgentTaskExecution().from_template(te_vars)

# Use cached LLM run
prompt = f"{base_prompt}\nThought:\nIteration: {iteration}\nActions:\n"

logging.debug("Running LLM with prompt...")
observations = self.llm.run(prompt)
logging.info(f"LLM execution completed. Observations: {observations}")
observations = self._cached_llm_run(prompt)
all_thoughts_and_obs.append(prompt)

max_iters = self.max_iterations + 1
while iteration < max_iters:
while iteration < self.max_iterations + 1:

logging.info(f"---- Iteration {iteration} ----")
logging.debug("Checking if task should continue...")
continue_flag, output = self.should_continue(observations)
Expand Down Expand Up @@ -210,6 +221,7 @@ def execute_task(self, task: Task, context: Any = None) -> Any:
prompt = f"{base_prompt}\n" + "\n".join(all_thoughts_and_obs)
logging.debug(f"\nSTART:{'*' * 20}\n{prompt}\n{'*' * 20}:END")
pth = Path(f"{self.memory.session_id}/logs/{task.name}-{iteration}.log")

pth.parent.mkdir(parents=True, exist_ok=True)
with open(pth, "w", encoding="utf-8") as f:
f.write(f"{prompt}\n")
Expand Down Expand Up @@ -240,3 +252,8 @@ def execute_task(self, task: Task, context: Any = None) -> Any:
f"Task Execution Completed - {task.name} with worker - {self.role}[{self.id}] in {iteration} iterations"
)
return output, task

def __del__(self):
"""Cleanup thread pool on deletion"""
if hasattr(self, '_thread_pool'):
self._thread_pool.shutdown(wait=False)

0 comments on commit d913841

Please sign in to comment.