Skip to content

Commit

Permalink
adds error handling for agent.get
Browse files Browse the repository at this point in the history
  • Loading branch information
BWMac committed Jan 23, 2025
1 parent 6cd5e5c commit 151b96b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
9 changes: 8 additions & 1 deletion synapseclient/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> "Agent
The existing Agent object.
Example: Get and chat with an existing agent
Retrieve an existing agent by providing the agent's registration ID and calling `get()`.
Retrieve an existing custom agent by providing the agent's registration ID and calling `get_async()`.
Then, send a prompt to the agent.
import asyncio
Expand All @@ -661,6 +661,13 @@ async def main():
asyncio.run(main())
"""
if not self.registration_id:
raise ValueError(
"Registration ID is required to retrieve a custom agent. "
"If you are trying to use the baseline agent, you do not need to "
"use `get` or `get_async`. Instead, simply create an `Agent` object "
"and start prompting `my_agent = Agent(); my_agent.prompt(...)`.",
)
agent_response = await get_agent(
registration_id=self.registration_id,
synapse_client=synapse_client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ async def test_get(self) -> None:
expected_agent = self.get_test_agent()
assert agent == expected_agent

async def test_get_no_registration_id(self) -> None:
# GIVEN an Agent with no registration id
agent = Agent()
# WHEN I get the agent, I expect a ValueError to be raised
with pytest.raises(ValueError, match="Registration ID is required"):
await agent.get_async(synapse_client=self.syn)

async def test_start_session(self) -> None:
# GIVEN an Agent with a valid agent registration id
agent = Agent(registration_id=AGENT_REGISTRATION_ID)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def test_get(self) -> None:
expected_agent = self.get_test_agent()
assert agent == expected_agent

def test_get_no_registration_id(self) -> None:
# GIVEN an Agent with no registration id
agent = Agent()
# WHEN I get the agent, I expect a ValueError to be raised
with pytest.raises(ValueError, match="Registration ID is required"):
agent.get(synapse_client=self.syn)

def test_start_session(self) -> None:
# GIVEN an Agent with a valid agent registration id
agent = Agent(registration_id=AGENT_REGISTRATION_ID).get(
Expand Down

0 comments on commit 151b96b

Please sign in to comment.