Skip to content

Commit

Permalink
fix(agent): Fix agent message loss bug (#1283)
Browse files Browse the repository at this point in the history
  • Loading branch information
yhjun1026 authored Mar 14, 2024
1 parent adaa68e commit a207640
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 91 deletions.
4 changes: 2 additions & 2 deletions dbgpt/agent/actions/plugin_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ async def a_run(
if not resource_plugin_client:
raise ValueError("No implementation of the use of plug-in resources!")
response_success = True
status = Status.TODO.value
status = Status.RUNNING.value
tool_result = ""
err_msg = None
try:
status = Status.RUNNING.value
tool_result = await resource_plugin_client.a_execute_command(
param.tool_name, param.args, plugin_generator
)
Expand Down
47 changes: 47 additions & 0 deletions dbgpt/agent/agents/agent_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

from dbgpt.agent.resource.resource_loader import ResourceLoader
from dbgpt.core import LLMClient
from dbgpt.util.annotations import PublicAPI

from ..memory.gpts_memory import GptsMemory


class Agent(ABC):
async def a_send(
Expand Down Expand Up @@ -72,6 +78,8 @@ async def a_review(
async def a_act(
self,
message: Optional[str],
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
**kwargs,
) -> Union[str, Dict, None]:
"""
Expand Down Expand Up @@ -101,3 +109,42 @@ async def a_verify(
Returns:
"""


@dataclasses.dataclass
class AgentContext:
    """Conversation-level configuration shared by agents.

    Plain data holder; ``to_dict`` produces a serializable snapshot of
    every field.
    """

    conv_id: str  # unique id of the conversation this context belongs to
    # Fixed: these two fields default to None, so the annotation must be
    # Optional[str], not a bare str.
    gpts_app_name: Optional[str] = None  # presumably the GPTs app display name — confirm with caller
    language: Optional[str] = None  # reply language; None means unspecified
    max_chat_round: Optional[int] = 100  # upper bound on chat rounds
    max_retry_round: Optional[int] = 10  # upper bound on retry rounds
    max_new_tokens: Optional[int] = 1024  # NOTE(review): looks like an LLM generation token limit — confirm
    temperature: Optional[float] = 0.5  # LLM sampling temperature
    allow_format_str_template: Optional[bool] = False  # allow %-style prompt templates when True

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
        return dataclasses.asdict(self)


@dataclasses.dataclass
@PublicAPI(stability="beta")
class AgentGenerateContext:
    """A class to represent the input of an Agent."""

    message: Optional[Dict]  # current message payload, if any
    sender: Agent  # agent that originated the message
    reviewer: Agent  # agent acting as reviewer for the reply
    silent: Optional[bool] = False  # presumably suppresses output when True — confirm with caller

    rely_messages: List[Dict] = dataclasses.field(default_factory=list)  # messages this turn relies on
    final: Optional[bool] = True  # whether this is the final step — TODO confirm semantics

    memory: Optional[GptsMemory] = None  # conversation memory store, if provided
    agent_context: Optional[AgentContext] = None  # shared conversation configuration
    resource_loader: Optional[ResourceLoader] = None  # loader for agent resources
    llm_client: Optional[LLMClient] = None  # LLM client used for generation

    # Fixed: the field defaults to None, so the annotation must be
    # Optional[int], not a bare int.
    round_index: Optional[int] = None

    def to_dict(self) -> Dict:
        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
        return dataclasses.asdict(self)
137 changes: 85 additions & 52 deletions dbgpt/agent/agents/base_agent_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from pydantic import BaseModel, Field

from dbgpt.agent.actions.action import Action, ActionOutput
from dbgpt.agent.agents.agent import AgentContext
from dbgpt.agent.agents.agent_new import Agent
from dbgpt.agent.agents.agent_new import Agent, AgentContext
from dbgpt.agent.agents.llm.llm import LLMConfig, LLMStrategyType
from dbgpt.agent.agents.llm.llm_client import AIWrapper
from dbgpt.agent.agents.role import Role
Expand All @@ -31,7 +30,7 @@ class ConversableAgent(Role, Agent):
llm_config: Optional[LLMConfig] = None
memory: GptsMemory = Field(default_factory=GptsMemory)
resource_loader: Optional[ResourceLoader] = None
max_retry_count: int = 10
max_retry_count: int = 3
consecutive_auto_reply_counter: int = 0
llm_client: Optional[AIWrapper] = None
oai_system_message: List[Dict] = Field(default_factory=list)
Expand Down Expand Up @@ -178,54 +177,75 @@ async def a_generate_reply(
logger.info(
f"generate agent reply!sender={sender}, rely_messages_len={rely_messages}"
)
try:
reply_message = self._init_reply_message(recive_message=recive_message)
await self._system_message_assembly(
recive_message["content"], reply_message.get("context", None)
)

reply_message = self._init_reply_message(recive_message=recive_message)
await self._system_message_assembly(
recive_message["content"], reply_message.get("context", None)
)
fail_reason = None
current_retry_counter = 0
is_sucess = True
while current_retry_counter < self.max_retry_count:
if current_retry_counter > 0:
retry_message = self._init_reply_message(
recive_message=recive_message
)
retry_message["content"] = fail_reason
retry_message["current_goal"] = recive_message.get(
"current_goal", None
)
# The current message is a self-optimized message that needs to be recorded.
# It is temporarily set to be initiated by the originating end to facilitate the organization of historical memory context.
await sender.a_send(
retry_message, self, reviewer, request_reply=False
)

fail_reason = None
current_retry_counter = 0
is_sucess = True
while current_retry_counter < self.max_retry_count:
if current_retry_counter > 0:
retry_message = self._init_reply_message(recive_message=recive_message)
retry_message["content"] = fail_reason
# The current message is a self-optimized message that needs to be recorded.
# It is temporarily set to be initiated by the originating end to facilitate the organization of historical memory context.
await sender.a_send(retry_message, self, reviewer, request_reply=False)

# 1.Think about how to do things
llm_reply, model_name = await self.a_thinking(
self._load_thinking_messages(recive_message, sender, rely_messages)
)
reply_message["model_name"] = model_name
reply_message["content"] = llm_reply

# 2.Review whether what is being done is legal
approve, comments = await self.a_review(llm_reply, self)
reply_message["review_info"] = {"approve": approve, "comments": comments}

# 3.Act based on the results of your thinking
act_extent_param = self.prepare_act_param()
act_out: ActionOutput = await self.a_act(
message=llm_reply,
**act_extent_param,
)
reply_message["action_report"] = act_out.dict()

# 4.Reply information verification
check_paas, reason = await self.a_verify(reply_message, sender, reviewer)
is_sucess = check_paas
# 5.Optimize wrong answers myself
if not check_paas:
current_retry_counter += 1
# Send error messages and issue new problem-solving instructions
await self.a_send(reply_message, sender, reviewer, request_reply=False)
fail_reason = reason
else:
break
return is_sucess, reply_message
# 1.Think about how to do things
llm_reply, model_name = await self.a_thinking(
self._load_thinking_messages(recive_message, sender, rely_messages)
)
reply_message["model_name"] = model_name
reply_message["content"] = llm_reply

# 2.Review whether what is being done is legal
approve, comments = await self.a_review(llm_reply, self)
reply_message["review_info"] = {
"approve": approve,
"comments": comments,
}

# 3.Act based on the results of your thinking
act_extent_param = self.prepare_act_param()
act_out: ActionOutput = await self.a_act(
message=llm_reply,
sender=sender,
reviewer=reviewer,
**act_extent_param,
)
reply_message["action_report"] = act_out.dict()

# 4.Reply information verification
check_paas, reason = await self.a_verify(
reply_message, sender, reviewer
)
is_sucess = check_paas
# 5.Optimize wrong answers myself
if not check_paas:
current_retry_counter += 1
# Send error messages and issue new problem-solving instructions
if current_retry_counter < self.max_retry_count:
await self.a_send(
reply_message, sender, reviewer, request_reply=False
)
fail_reason = reason
else:
break
return is_sucess, reply_message

except Exception as e:
logger.exception("Generate reply exception!")
return False, {"content": str(e)}

async def a_thinking(
self, messages: Optional[List[Dict]], prompt: Optional[str] = None
Expand Down Expand Up @@ -265,7 +285,13 @@ async def a_review(
) -> Tuple[bool, Any]:
return True, None

async def a_act(self, message: Optional[str], **kwargs) -> Optional[ActionOutput]:
async def a_act(
self,
message: Optional[str],
sender: Optional[ConversableAgent] = None,
reviewer: Optional[ConversableAgent] = None,
**kwargs,
) -> Optional[ActionOutput]:
last_out = None
for action in self.actions:
# Select the resources required by acton
Expand Down Expand Up @@ -335,6 +361,7 @@ async def a_initiate_chat(
#######################################################################

def _init_actions(self, actions: List[Action] = None):
self.actions = []
for idx, action in enumerate(actions):
if not isinstance(action, Action):
self.actions.append(action())
Expand Down Expand Up @@ -426,7 +453,9 @@ async def _system_message_assembly(
for item in self.resources:
resource_client = self.resource_loader.get_resesource_api(item.type)
resource_prompt_list.append(
await resource_client.get_resource_prompt(item, qustion)
await resource_client.get_resource_prompt(
self.agent_context.conv_id, item, qustion
)
)
if context is None:
context = {}
Expand Down Expand Up @@ -525,7 +554,11 @@ def _convert_to_ai_message(
content = item.content
if item.action_report:
action_out = ActionOutput.from_dict(json.loads(item.action_report))
if action_out is not None and action_out.content is not None:
if (
action_out is not None
and action_out.is_exe_success
and action_out.content is not None
):
content = action_out.content
oai_messages.append(
{
Expand Down
Loading

0 comments on commit a207640

Please sign in to comment.