diff --git a/src/openagi/worker.py b/src/openagi/worker.py
index f226856..5115dfb 100644
--- a/src/openagi/worker.py
+++ b/src/openagi/worker.py
@@ -12,6 +12,7 @@
 from openagi.memory.memory import Memory
 from openagi.prompts.worker_task_execution import WorkerAgentTaskExecution
 from openagi.tasks.task import Task
+from openagi.prompts.task_clarification import TaskClarifier
 from openagi.utils.extraction import get_act_classes_from_json, get_last_json
 from openagi.utils.helper import get_default_id
 
@@ -30,6 +31,9 @@ class Worker(BaseModel):
         description="Memory to be used.",
         exclude=True,
     )
+    human_intervene: bool = Field(
+        default=False, description="If human intervention is required or not."
+    )
     actions: Optional[List[Any]] = Field(
         description="Actions that the Worker supports",
         default_factory=list,
@@ -100,7 +104,70 @@ def _force_output(
                 f"LLM did not produce the expected output after {self.max_iterations} iterations."
             )
         return (cont, final_output)
+
+    def human_clarification(self, worker_vars) -> Dict:
+        """Refine the task objective by asking the human clarifying questions.
+
+        Each round renders a ``TaskClarifier`` prompt from ``worker_vars`` plus the
+        running Q/A transcript, asks the LLM for the next question, and relays that
+        question to the human through ``self.input_action``. Every non-empty answer
+        is appended to ``worker_vars["objective"]``.
+
+        The loop ends when the LLM yields no question, or when the human signals
+        they are done (for example by typing ``I dont know``). The number of rounds
+        is not capped.
+
+        Args:
+            worker_vars (Dict): Variables for the worker prompt. Updated in place.
+
+        Returns:
+            Dict: The same ``worker_vars`` mapping, with the clarified objective.
+        """
+        # NOTE(review): `input_action` is not among the Worker fields shown in this
+        # patch; confirm the model declares it before enabling `human_intervene`.
+        logging.info(
+            "Initiating Human Clarification. \n"
+            "Make sure to clarify the questions, if not just type `I dont know` to stop"
+        )
+        # Includes the apostrophe-free spelling that the message above advertises.
+        stop_phrases = (
+            "don't know",
+            "dont know",
+            "no more questions",
+            "that's all",
+            "stop asking",
+        )
+        chat_history = []
+
+        while True:
+            clarifier_vars = {**worker_vars, "chat_history": "\n".join(chat_history)}
+            clarifier = TaskClarifier.from_template(variables=clarifier_vars)
+
+            response = self.llm.run(clarifier)
+            parsed_response = get_last_json(response, llm=self.llm)
+            # Guard against a non-dict result (e.g. no JSON found): nothing to ask.
+            if not isinstance(parsed_response, dict):
+                return worker_vars
+            question = str(parsed_response.get("question") or "").strip()
+            if not question:
+                return worker_vars
+
+            answer = self.input_action.execute(prompt=question)
+            human_input = "" if answer is None else str(answer).strip()
+            if human_input:
+                objective = worker_vars.get("objective") or ""
+                worker_vars["objective"] = (
+                    f"{objective} {human_input}" if objective else human_input
+                )
+
+            chat_history.append(f"Q: {question}")
+            chat_history.append(f"A: {human_input}")
+
+            # Map the typographic apostrophe to ASCII so both spellings match.
+            normalized = human_input.casefold().replace("\u2019", "'")
+            if any(phrase in normalized for phrase in stop_phrases):
+                return worker_vars
 
     def save_to_memory(self, task: Task):
         """Saves the output to the memory."""
         return self.memory.update_task(task)
@@ -127,6 +194,9 @@ def execute_task(self, task: Task, context: Any = None) -> Any:
             max_iterations=self.max_iterations,
         )
 
+        if self.human_intervene:
+            te_vars = self.human_clarification(te_vars)
+
         logging.debug("Generating base prompt...")
         base_prompt = WorkerAgentTaskExecution().from_template(te_vars)
         prompt = f"{base_prompt}\nThought:\nIteration: {iteration}\nActions:\n"