Skip to content

Commit

Permalink
Log Agent Runs asynchronously
Browse files Browse the repository at this point in the history
  • Loading branch information
ashpreetbedi committed Sep 27, 2024
1 parent aee0650 commit 10e12aa
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 60 deletions.
42 changes: 21 additions & 21 deletions cookbook/agents/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
tools=[YFinanceTools(stock_price=True)],
show_tool_calls=True,
markdown=True,
# debug_mode=True,
debug_mode=True,
# monitoring=False,
storage=PgAgentStorage(table_name="agent_sessions", db_url="postgresql+psycopg://ai:ai@localhost:5532/ai"),
)
Expand All @@ -30,23 +30,23 @@
# run: RunResponse = agent.run("What is the stock price of NVDA")
# pprint(run.content)

run_stream: Iterator[RunResponse] = agent.run(
"What is the stock price of NVDA", stream=True, stream_intermediate_steps=True
)
for chunk in run_stream:
print("---")
pprint(chunk.model_dump(exclude={"messages"}))
print("---")


# async def main():
# run: RunResponse = await agent.arun("What is the stock price of NVDA and TSLA")
# pprint(run)
# # async for chunk in await agent.arun("What is the stock price of NVDA and TSLA", stream=True):
# # print(chunk.content)
#
#
# asyncio.run(main())

# agent.print_response("What is the stock price of NVDA and TSLA?")
# agent.print_response("What is the stock price of NVDA and TSLA?", stream=True)
# run_stream: Iterator[RunResponse] = agent.run(
# "What is the stock price of NVDA", stream=True, stream_intermediate_steps=True
# )
# for chunk in run_stream:
# print("---")
# pprint(chunk.model_dump(exclude={"messages"}))
# print("---")


async def main() -> None:
    """Exercise the agent's async (non-streaming) print path."""
    await agent.aprint_response("What is the stock price of NVDA and TSLA")


asyncio.run(main())

# Same question through the synchronous, streaming print path.
agent.print_response("What is the stock price of NVDA and TSLA?", stream=True)
44 changes: 29 additions & 15 deletions phi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,7 @@ async def _arun(
# Create the run_response object
self.run_response = RunResponse(run_id=str(uuid4()))

logger.debug(f"*********** Agent Run Start: {self.run_response.run_id} ***********")
logger.debug(f"*********** Async Agent Run Start: {self.run_response.run_id} ***********")

# 1. Update the Model (set defaults, add tools, etc.)
self.update_model()
Expand Down Expand Up @@ -1508,7 +1508,7 @@ async def _arun(
}
self.log_agent_run(run_id=self.run_response.run_id, run_data=run_data)

logger.debug(f"*********** Agent Run End: {self.run_response.run_id} ***********")
logger.debug(f"*********** Async Agent Run End: {self.run_response.run_id} ***********")
if stream_intermediate_steps:
yield RunResponse(
run_id=self.run_response.run_id,
Expand Down Expand Up @@ -1809,11 +1809,11 @@ def log_agent_session(self):
if not self.telemetry:
return

from phi.api.agent import create_agent_session, AgentSessionCreate
from phi.api.agent import trigger_agent_session_creation, AgentSessionCreate

try:
agent_session: AgentSession = self.agent_session or self.to_agent_session()
create_agent_session(
trigger_agent_session_creation(
session=AgentSessionCreate(
session_id=agent_session.session_id,
agent_data=agent_session.monitoring_data() if self.monitoring else agent_session.telemetry_data(),
Expand All @@ -1826,11 +1826,11 @@ def log_agent_run(self, run_id: str, run_data: Optional[Dict[str, Any]] = None)
if not self.telemetry:
return

from phi.api.agent import create_agent_run, AgentRunCreate
from phi.api.agent import trigger_agent_run_creation, AgentRunCreate

try:
agent_session: AgentSession = self.agent_session or self.to_agent_session()
create_agent_run(
trigger_agent_run_creation(
run=AgentRunCreate(
run_id=run_id,
run_data=run_data,
Expand Down Expand Up @@ -1937,9 +1937,10 @@ def print_response(
table.add_row(f"Response\n({response_timer.elapsed:.1f}s)", response_content) # type: ignore
console.print(table)

async def async_print_response(
async def aprint_response(
self,
message: Optional[Union[List, Dict, str]] = None,
*,
messages: Optional[List[Union[Dict, Message]]] = None,
stream: bool = False,
markdown: bool = False,
Expand All @@ -1953,13 +1954,15 @@ async def async_print_response(
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.box import ROUNDED
from rich.markdown import Markdown
from rich.json import JSON

if markdown:
self.markdown = True

if self.response_model is not None:
markdown = False
self.markdown = False
stream = False

if stream:
_response_content = ""
Expand All @@ -1969,8 +1972,8 @@ async def async_print_response(
response_timer = Timer()
response_timer.start()
async for resp in await self.arun(message=message, messages=messages, stream=True, **kwargs): # type: ignore #TODO: Review this
if isinstance(resp, str):
_response_content += resp
if isinstance(resp, RunResponse) and isinstance(resp.content, str):
_response_content += resp.content
response_content = Markdown(_response_content) if self.markdown else _response_content

table = Table(box=ROUNDED, border_style="blue", show_header=False)
Expand All @@ -1992,12 +1995,23 @@ async def async_print_response(

response_timer.stop()
response_content = ""
if isinstance(run_response, RunResponse) and isinstance(run_response.content, str):
response_content = (
Markdown(run_response.content)
if self.markdown
else self.convert_response_to_string(run_response.content)
)
if isinstance(run_response, RunResponse):
if isinstance(run_response.content, str):
response_content = (
Markdown(run_response.content)
if self.markdown
else self.convert_response_to_string(run_response.content)
)
elif self.response_model is not None and isinstance(run_response.content, BaseModel):
try:
response_content = JSON(run_response.content.model_dump_json(exclude_none=True), indent=2)
except Exception as e:
logger.warning(f"Failed to convert response to Markdown: {e}")
else:
try:
response_content = JSON(json.dumps(run_response.content), indent=4)
except Exception as e:
logger.warning(f"Failed to convert response to string: {e}")

table = Table(box=ROUNDED, border_style="blue", show_header=False)
if message and show_message:
Expand Down
77 changes: 57 additions & 20 deletions phi/api/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from os import getenv
from typing import Union, Dict, List

Expand All @@ -11,67 +12,103 @@
from phi.utils.log import logger


async def create_agent_session(session: AgentSessionCreate) -> None:
    """Log an Agent session to the phi API asynchronously.

    Best-effort telemetry: every failure path logs at debug/warning level and
    returns normally so the caller's run is never interrupted.

    Args:
        session: Payload describing the session to record.
    """
    if not phi_cli_settings.api_enabled:
        return

    phi_api_key = getenv(PHI_API_KEY_ENV_VAR)
    if phi_api_key is None:
        logger.warning(f"{PHI_API_KEY_ENV_VAR} not set. You can get one from https://phidata.app")
        return

    logger.debug("--**-- Logging Agent Session")
    async with api.AuthenticatedAsyncClient() as api_client:
        try:
            r: Response = await api_client.post(
                ApiRoutes.AGENT_SESSION_CREATE,
                headers={
                    "Authorization": f"Bearer {phi_api_key}",
                },
                json={"session": session.model_dump(exclude_none=True)},
            )
            if invalid_response(r):
                logger.debug(f"Invalid response: {r.status_code}, {r.text}")
                return

            response_json: Union[Dict, List] = r.json()
            if response_json is None:
                return

            logger.debug(f"Response: {response_json}")
            return
        except Exception as e:
            # Telemetry must never break the user's run.
            logger.debug(f"Could not create Agent session: {e}")
        return


def create_agent_run(run: AgentRunCreate) -> bool:
# Strong references to fire-and-forget tasks: the event loop only keeps a weak
# reference to tasks, so an unreferenced task may be garbage-collected before
# it completes (see asyncio.create_task docs).
_pending_session_tasks: set = set()


def trigger_agent_session_creation(session: AgentSessionCreate) -> None:
    """Schedule `create_agent_session` without blocking the caller.

    Inside a running event loop the coroutine is scheduled as a background
    task; otherwise it is run to completion on a temporary loop.

    Args:
        session: Payload forwarded to `create_agent_session`.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: asyncio.run creates a fresh loop,
        # runs the coroutine, and closes the loop (the previous approach left
        # a new loop open and set as current on every sync call).
        asyncio.run(create_agent_session(session))
        return

    # get_running_loop only ever returns a *running* loop, so schedule the
    # coroutine on it and keep a reference until it finishes.
    task = loop.create_task(create_agent_session(session))
    _pending_session_tasks.add(task)
    task.add_done_callback(_pending_session_tasks.discard)


async def create_agent_run(run: AgentRunCreate) -> None:
    """Log an Agent run to the phi API asynchronously.

    Best-effort telemetry: every failure path logs at debug/warning level and
    returns normally so the caller's run is never interrupted.

    Args:
        run: Payload describing the run to record.
    """
    if not phi_cli_settings.api_enabled:
        return

    phi_api_key = getenv(PHI_API_KEY_ENV_VAR)
    if phi_api_key is None:
        logger.warning(f"{PHI_API_KEY_ENV_VAR} not set. You can get one from https://phidata.app")
        return

    logger.debug("--**-- Logging Agent Run")
    async with api.AuthenticatedAsyncClient() as api_client:
        try:
            r: Response = await api_client.post(
                ApiRoutes.AGENT_RUN_CREATE,
                headers={
                    "Authorization": f"Bearer {phi_api_key}",
                },
                json={"run": run.model_dump(exclude_none=True)},
            )
            if invalid_response(r):
                logger.debug(f"Invalid response: {r.status_code}, {r.text}")
                return

            response_json: Union[Dict, List] = r.json()
            if response_json is None:
                return

            logger.debug(f"Response: {response_json}")
            return
        except Exception as e:
            # Telemetry must never break the user's run.
            logger.debug(f"Could not create Agent run: {e}")
        return


# Strong references to fire-and-forget tasks: the event loop only keeps a weak
# reference to tasks, so an unreferenced task may be garbage-collected before
# it completes (see asyncio.create_task docs).
_pending_run_tasks: set = set()


def trigger_agent_run_creation(run: AgentRunCreate) -> None:
    """Schedule `create_agent_run` without blocking the caller.

    Inside a running event loop the coroutine is scheduled as a background
    task; otherwise it is run to completion on a temporary loop.

    Args:
        run: Payload forwarded to `create_agent_run`.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: asyncio.run creates a fresh loop,
        # runs the coroutine, and closes the loop (the previous approach left
        # a new loop open and set as current on every sync call).
        asyncio.run(create_agent_run(run))
        return

    # get_running_loop only ever returns a *running* loop, so schedule the
    # coroutine on it and keep a reference until it finishes.
    task = loop.create_task(create_agent_run(run))
    _pending_run_tasks.add(task)
    task.add_done_callback(_pending_run_tasks.discard)
8 changes: 4 additions & 4 deletions phi/model/openai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ async def aresponse(self, messages: List[Message]) -> ModelResponse:
Returns:
ModelResponse: The model response from the API.
"""
logger.debug("---------- OpenAI Response Start ----------")
logger.debug("---------- Async OpenAI Response Start ----------")
self._log_messages(messages)
model_response = ModelResponse()

Expand Down Expand Up @@ -595,7 +595,7 @@ async def aresponse(self, messages: List[Message]) -> ModelResponse:
if assistant_message.content is not None:
model_response.content = assistant_message.get_content_string()

logger.debug("---------- OpenAI Async Response End ----------")
logger.debug("---------- Async OpenAI Response End ----------")
return model_response

def _update_stream_metrics(self, stream_data: StreamData, assistant_message: Message):
Expand Down Expand Up @@ -758,7 +758,7 @@ async def aresponse_stream(self, messages: List[Message]) -> Any:
Returns:
Any: An asynchronous iterator of chat completion chunks.
"""
logger.debug("---------- OpenAI Async Response Start ----------")
logger.debug("---------- Async OpenAI Response Start ----------")
self._log_messages(messages)

stream_data: StreamData = StreamData()
Expand Down Expand Up @@ -846,7 +846,7 @@ async def aresponse_stream(self, messages: List[Message]) -> Any:

async for content in self.aresponse_stream(messages=messages):
yield content
logger.debug("---------- OpenAI Async Response End ----------")
logger.debug("---------- Async OpenAI Response End ----------")

def _build_tool_calls(self, tool_calls_data: List[ChoiceDeltaToolCall]) -> List[Dict[str, Any]]:
"""
Expand Down

0 comments on commit 10e12aa

Please sign in to comment.