
Commit 96ef5e3: Update Agent class
ashpreetbedi committed Sep 26, 2024
1 parent 6eaec74 commit 96ef5e3
Showing 5 changed files with 125 additions and 32 deletions.
cookbook/agents/telemetry.py (4 changes: 2 additions & 2 deletions)
@@ -9,7 +9,7 @@
 agent = Agent(
     model=OpenAIChat(id="gpt-4o"),
     tools=[YFinanceTools(stock_price=True)],
-    show_tool_calls=True,
+    # show_tool_calls=True,
     markdown=True,
     # debug_mode=True,
     # monitoring=False,
@@ -27,7 +27,7 @@
 # print(m)
 # print("---")

-run: Iterator[RunResponse] = agent.run("What is the stock price of NVDA", stream=True)
+run: Iterator[RunResponse] = agent.run("What is the stock price of NVDA", stream=True, stream_intermediate_steps=True)
 for chunk in run:
     print("---")
     pprint(chunk.model_dump(exclude={"messages"}))
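Usage note: with stream_intermediate_steps=True, every streamed chunk carries an event string, so callers can separate status updates from answer text. A minimal sketch of consuming the stream (imports assumed to match the cookbook file; the prompt and handling are illustrative, not part of this commit):

from typing import Iterator

from phi.agent import Agent
from phi.agent.response import RunResponse, RunResponseEvent
from phi.model.openai import OpenAIChat

agent = Agent(model=OpenAIChat(id="gpt-4o"), markdown=True)

run: Iterator[RunResponse] = agent.run(
    "What is the stock price of NVDA", stream=True, stream_intermediate_steps=True
)
for chunk in run:
    if chunk.event == RunResponseEvent.agent_response.value:
        print(chunk.content, end="")  # streamed answer text
    else:
        # RunStarted / ToolCall / UpdatingMemory / RunCompleted status events
        print(f"[{chunk.event}] {chunk.content}")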
phi/agent/agent.py (110 changes: 89 additions & 21 deletions)
@@ -23,11 +23,11 @@

 from phi.document import Document
 from phi.agent.session import AgentSession
-from phi.agent.response import RunResponse, RunEvent
+from phi.agent.response import RunResponse, RunResponseEvent
 from phi.knowledge.agent import AgentKnowledge
 from phi.model import Model
 from phi.model.message import Message, MessageContext
-from phi.model.response import ModelResponse
+from phi.model.response import ModelResponse, ModelResponseEvent
 from phi.memory.agent import AgentMemory, MemoryRetrieval, Memory  # noqa: F401
 from phi.prompt.template import PromptTemplate
 from phi.storage.agent import AgentStorage
@@ -883,6 +883,7 @@ def _run(
         stream: bool = False,
         images: Optional[List[Union[str, Dict]]] = None,
         messages: Optional[List[Union[Dict, Message]]] = None,
+        stream_intermediate_steps: bool = False,
         **kwargs: Any,
     ) -> Iterator[RunResponse]:
         """Run the Agent with a message and return the response.
@@ -900,6 +901,9 @@
         6. Save session to storage
         7. Save output to file if save_output_to_file is set
         """
+        # Evaluate if streaming is enabled
+        stream_agent_response = stream and self.streamable
+        stream_intermediate_steps = stream_intermediate_steps and stream_agent_response
         # Create the run_response object
         self.run_response = RunResponse(run_id=str(uuid4()))

@@ -908,12 +912,11 @@
         # 1. Update the Model (set defaults, add tools, etc.)
         self.update_model()
         self.run_response.model = self.model.id if self.model is not None else None
-        if stream and self.streamable:
+        if stream_intermediate_steps:
             yield RunResponse(
                 run_id=self.run_response.run_id,
                 content="Run started",
-                model=self.run_response.model,
-                event=RunEvent.run_start.value,
+                event=RunResponseEvent.run_started.value,
             )

         # 2. Read existing session from storage
@@ -977,14 +980,22 @@
         # 4. Generate a response from the Model (includes running function calls)
         model_response: ModelResponse
         self.model = cast(Model, self.model)
-        if stream and self.streamable:
+        if stream_agent_response:
             model_response = ModelResponse(content="")
             for model_response_chunk in self.model.response_stream(messages=messages_for_model):
-                if model_response_chunk.content is not None and model_response.content is not None:
-                    model_response.content += model_response_chunk.content
-                    self.run_response.content = model_response_chunk.content
-                    self.run_response.messages = messages_for_model
-                    yield self.run_response
+                if model_response_chunk.event == ModelResponseEvent.assistant_response.value:
+                    if model_response_chunk.content is not None and model_response.content is not None:
+                        model_response.content += model_response_chunk.content
+                        self.run_response.content = model_response_chunk.content
+                        self.run_response.messages = messages_for_model
+                        yield self.run_response
+                elif model_response_chunk.event == ModelResponseEvent.tool_call.value:
+                    if stream_intermediate_steps:
+                        yield RunResponse(
+                            run_id=self.run_response.run_id,
+                            content=model_response_chunk.content,
+                            event=RunResponseEvent.tool_call.value,
+                        )
         else:
             model_response = self.model.response(messages=messages_for_model)
             self.run_response.content = model_response.content
@@ -993,6 +1004,12 @@
         # Add the model metrics to the run_response
         self.run_response.metrics = self.model.metrics if self.model else None
         # 5. Update Memory
+        if stream_intermediate_steps:
+            yield RunResponse(
+                run_id=self.run_response.run_id,
+                content="Updating memory",
+                event=RunResponseEvent.updating_memory.value,
+            )
         # Add the user message to the chat history
         if message is not None:
             user_message_for_chat_history = None
@@ -1038,7 +1055,7 @@

         # Update the run_response
         # Update content if streaming as run_response will only contain the last chunk
-        if stream:
+        if stream_agent_response:
             self.run_response.content = model_response.content
         # Add tools from this run to the run_response
         for _run_message in run_messages:
@@ -1112,9 +1129,15 @@ def _run(
             self.log_agent_run(run_id=self.run_response.run_id, run_data=run_data)

         logger.debug(f"*********** Agent Run End: {self.run_response.run_id} ***********")
+        if stream_intermediate_steps:
+            yield RunResponse(
+                run_id=self.run_response.run_id,
+                content="Run completed",
+                event=RunResponseEvent.run_completed.value,
+            )

-        # -*- Yield final response if not streaming
-        if not stream:
+        # -*- Yield final response if not streaming so that run() can get the response
+        if not stream_agent_response:
             yield self.run_response

     @overload
@@ -1125,6 +1148,7 @@ def run(
         stream: Literal[False] = False,
         images: Optional[List[Union[str, Dict]]] = None,
         messages: Optional[List[Union[Dict, Message]]] = None,
+        stream_intermediate_steps: bool = False,
         **kwargs: Any,
     ) -> RunResponse: ...

@@ -1136,6 +1160,7 @@
         stream: Literal[True] = True,
         images: Optional[List[Union[str, Dict]]] = None,
         messages: Optional[List[Union[Dict, Message]]] = None,
+        stream_intermediate_steps: bool = False,
         **kwargs: Any,
     ) -> Iterator[RunResponse]: ...

@@ -1146,6 +1171,7 @@
         stream: bool = False,
         images: Optional[List[Union[str, Dict]]] = None,
         messages: Optional[List[Union[Dict, Message]]] = None,
+        stream_intermediate_steps: bool = False,
         **kwargs: Any,
     ) -> Union[RunResponse, Iterator[RunResponse]]:
         """Run the Agent with a message and return the response."""
@@ -1155,7 +1181,14 @@
             # Set stream=False and run the agent
             logger.debug("Setting stream=False as response_model is set")
             run_response: RunResponse = next(
-                self._run(message=message, stream=False, images=images, messages=messages, **kwargs)
+                self._run(
+                    message=message,
+                    stream=False,
+                    images=images,
+                    messages=messages,
+                    stream_intermediate_steps=stream_intermediate_steps,
+                    **kwargs,
+                )
             )

             # If the model natively supports structured outputs, the content is already in the structured format
@@ -1193,10 +1226,24 @@
                 return run_response
         else:
             if stream and self.streamable:
-                resp = self._run(message=message, stream=True, images=images, messages=messages, **kwargs)
+                resp = self._run(
+                    message=message,
+                    stream=True,
+                    images=images,
+                    messages=messages,
+                    stream_intermediate_steps=stream_intermediate_steps,
+                    **kwargs,
+                )
                 return resp
             else:
-                resp = self._run(message=message, stream=False, images=images, messages=messages, **kwargs)
+                resp = self._run(
+                    message=message,
+                    stream=False,
+                    images=images,
+                    messages=messages,
+                    stream_intermediate_steps=stream_intermediate_steps,
+                    **kwargs,
+                )
                 return next(resp)

     async def _arun(
@@ -1206,6 +1253,7 @@
         stream: bool = False,
         images: Optional[List[Union[str, Dict]]] = None,
         messages: Optional[List[Union[Dict, Message]]] = None,
+        stream_intermediate_steps: bool = False,
         **kwargs: Any,
     ) -> AsyncIterator[RunResponse]:
         """Async Run the Agent with a message and return the response.
@@ -1236,7 +1284,7 @@
                 run_id=self.run_response.run_id,
                 content="Run started",
                 model=self.run_response.model,
-                event=RunEvent.run_start,
+                event=RunResponseEvent.run_started.value,
             )

         # 2. Read existing session from storage
@@ -1448,6 +1496,7 @@ async def arun(
         stream: bool = False,
         images: Optional[List[Union[str, Dict]]] = None,
         messages: Optional[List[Union[Dict, Message]]] = None,
+        stream_intermediate_steps: bool = False,
         **kwargs: Any,
     ) -> Any:
         """Async Run the Agent with a message and return the response."""
@@ -1457,7 +1506,12 @@
             # Set stream=False and run the agent
             logger.debug("Setting stream=False as response_model is set")
             run_response = await self._arun(
-                message=message, stream=False, images=images, messages=messages, **kwargs
+                message=message,
+                stream=False,
+                images=images,
+                messages=messages,
+                stream_intermediate_steps=stream_intermediate_steps,
+                **kwargs,
             ).__anext__()

             # If the model natively supports structured outputs, the content is already in the structured format
@@ -1495,10 +1549,24 @@
                 return run_response
         else:
            if stream and self.streamable:
-                resp = self._arun(message=message, stream=True, images=images, messages=messages, **kwargs)
+                resp = self._arun(
+                    message=message,
+                    stream=True,
+                    images=images,
+                    messages=messages,
+                    stream_intermediate_steps=stream_intermediate_steps,
+                    **kwargs,
+                )
                 return resp
             else:
-                resp = self._arun(message=message, stream=False, images=images, messages=messages, **kwargs)
+                resp = self._arun(
+                    message=message,
+                    stream=False,
+                    images=images,
+                    messages=messages,
+                    stream_intermediate_steps=stream_intermediate_steps,
+                    **kwargs,
+                )
                 return await resp.__anext__()

     def rename(self, name: str) -> None:
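Review note: the double gating at the top of _run is the heart of this change; intermediate-step events are emitted only when streaming is active and the caller opted in. A condensed, standalone sketch of the dispatch loop (dispatch_run_events and its arguments are stand-ins for illustration, not the actual method):

from phi.agent.response import RunResponse, RunResponseEvent
from phi.model.response import ModelResponse, ModelResponseEvent

def dispatch_run_events(chunks, run_id, stream, streamable, stream_intermediate_steps):
    # `chunks` stands in for self.model.response_stream(...)
    stream_agent_response = stream and streamable
    # Intermediate steps are only streamed when the response itself streams
    stream_intermediate_steps = stream_intermediate_steps and stream_agent_response

    if stream_intermediate_steps:
        yield RunResponse(run_id=run_id, content="Run started", event=RunResponseEvent.run_started.value)
    for chunk in chunks:
        if chunk.event == ModelResponseEvent.assistant_response.value:
            yield RunResponse(run_id=run_id, content=chunk.content)  # default event: agent_response
        elif chunk.event == ModelResponseEvent.tool_call.value and stream_intermediate_steps:
            yield RunResponse(run_id=run_id, content=chunk.content, event=RunResponseEvent.tool_call.value)
    if stream_intermediate_steps:
        yield RunResponse(run_id=run_id, content="Run completed", event=RunResponseEvent.run_completed.value)

# Example with made-up chunks:
steps = dispatch_run_events(
    chunks=[
        ModelResponse(content="get_stock_price(symbol=NVDA)", event=ModelResponseEvent.tool_call.value),
        ModelResponse(content="NVDA is trading at ..."),
    ],
    run_id="run-1", stream=True, streamable=True, stream_intermediate_steps=True,
)
for step in steps:
    print(step.event, "->", step.content)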
phi/agent/response.py (13 changes: 7 additions & 6 deletions)
@@ -7,13 +7,14 @@
 from phi.model.message import Message, MessageContext


-class RunEvent(str, Enum):
+class RunResponseEvent(str, Enum):
     """Events that can be sent by the Agent.run() method"""

-    run_start = "RunStart"
-    intermediate_step = "IntermediateStep"
     agent_response = "AgentResponse"
-    run_end = "RunEnd"
+    run_completed = "RunCompleted"
+    run_started = "RunStarted"
+    tool_call = "ToolCall"
+    updating_memory = "UpdatingMemory"


 class RunResponse(BaseModel):
@@ -27,7 +28,7 @@ class RunResponse(BaseModel):
     tools: Optional[List[Dict[str, Any]]] = None
     context: Optional[List[MessageContext]] = None
     model: Optional[str] = None
-    event: str = RunEvent.agent_response.value
+    event: str = RunResponseEvent.agent_response.value
     created_at: int = Field(default_factory=lambda: int(time()))

-    model_config = ConfigDict(arbitrary_types_allowed=True, use_enum_values=True)
+    model_config = ConfigDict(arbitrary_types_allowed=True)
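Since event is declared as str and every assignment already passes the enum's .value, dropping use_enum_values from model_config should not change how the event field serializes. An illustrative check:

from phi.agent.response import RunResponse, RunResponseEvent

r = RunResponse(run_id="run-1", event=RunResponseEvent.tool_call.value)
assert r.event == "ToolCall"                  # stored as a plain string
assert r.model_dump()["event"] == "ToolCall"  # serializes identically either way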
phi/model/openai/chat.py (21 changes: 18 additions & 3 deletions)
@@ -5,7 +5,7 @@

 from phi.model.base import Model
 from phi.model.message import Message
-from phi.model.response import ModelResponse
+from phi.model.response import ModelResponse, ModelResponseEvent
 from phi.tools.function import FunctionCall
 from phi.utils.log import logger
 from phi.utils.timer import Timer
@@ -677,9 +677,19 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
                        continue
                    function_calls_to_run.append(_function_call)

+                # Yield intermediate messages for tool calls
+                for _f in function_calls_to_run:
+                    yield ModelResponse(content=_f.get_call_str(), event=ModelResponseEvent.tool_call.value)
+
+                # Yield tool call
                 if self.show_tool_calls:
-                    for _f in function_calls_to_run:
-                        yield ModelResponse(content=f"\n - Running: {_f.get_call_str()}\n\n")
+                    if len(function_calls_to_run) == 1:
+                        yield ModelResponse(content=f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n")
+                    elif len(function_calls_to_run) > 1:
+                        yield ModelResponse(content="\nRunning:")
+                        for _f in function_calls_to_run:
+                            yield ModelResponse(content=f"\n - {_f.get_call_str()}")
+                        yield ModelResponse(content="\n\n")

                 function_call_results = self.run_function_calls(function_calls_to_run)
                 if len(function_call_results) > 0:
@@ -765,6 +775,11 @@ async def aresponse_stream(self, messages: List[Message]) -> Any:
                        continue
                    function_calls_to_run.append(_function_call)

+                # Yield intermediate messages for tool calls
+                for _f in function_calls_to_run:
+                    yield ModelResponse(content=_f.get_call_str(), event=ModelResponseEvent.tool_call.value)
+
+                # Yield tool call
                 if self.show_tool_calls:
                     if len(function_calls_to_run) == 1:
                         yield ModelResponse(content=f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n")
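Note that the ModelResponseEvent.tool_call chunks are now yielded before, and independently of, the show_tool_calls check, so the Agent always receives them and can decide whether to surface them as intermediate steps. The user-facing show_tool_calls text also formats a single call differently from several; roughly this (the call strings are made up for illustration):

calls = ["get_stock_price(symbol=NVDA)", "get_stock_price(symbol=AMD)"]  # hypothetical

if len(calls) == 1:
    print(f"\n - Running: {calls[0]}\n\n", end="")
elif len(calls) > 1:
    print("\nRunning:", end="")
    for call in calls:
        print(f"\n - {call}", end="")
    print("\n\n", end="")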
phi/model/response.py (9 changes: 9 additions & 0 deletions)
@@ -1,9 +1,18 @@
+from enum import Enum
 from typing import Optional

 from pydantic import BaseModel


+class ModelResponseEvent(str, Enum):
+    """Events that can be sent by the Model.response() method"""
+
+    tool_call = "ToolCall"
+    assistant_response = "ModelResponse"
+
+
 class ModelResponse(BaseModel):
     """Response returned by Model.response()"""

     content: Optional[str] = None
+    event: str = ModelResponseEvent.assistant_response.value

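Consumers of Model.response_stream() can branch on the new event field; a minimal sketch with made-up chunks:

from phi.model.response import ModelResponse, ModelResponseEvent

chunks = [
    ModelResponse(content="get_weather(city=Paris)", event=ModelResponseEvent.tool_call.value),
    ModelResponse(content="It is sunny in Paris."),  # event defaults to assistant_response
]
for chunk in chunks:
    if chunk.event == ModelResponseEvent.tool_call.value:
        print(f"tool call -> {chunk.content}")
    else:
        print(f"assistant -> {chunk.content}")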