Skip to content

Commit

Permalink
test: add test for agent async fn to increase coverage (#16331)
Browse files Browse the repository at this point in the history
-test: add unit test for async fn custom agent to increase coverage
  • Loading branch information
jzhao62 authored Oct 2, 2024
1 parent 0491ab7 commit 2afa325
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions llama-index-core/tests/agent/custom/test_simple_function.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Test simple function agent."""

from typing import Any, Dict, Tuple

from llama_index.core.agent.custom.simple_function import FnAgentWorker
import pytest
from llama_index.core.agent.custom.simple_function import FnAgentWorker


def mock_foo_fn_no_state_param() -> Tuple[None, bool]:
"""Mock agent input function without a state."""
return None, True


def mock_foo_fn(state: dict) -> Tuple[Dict[str, Any], bool]:
Expand All @@ -25,6 +29,11 @@ def mock_foo_fn(state: dict) -> Tuple[Dict[str, Any], bool]:
return state, is_done


async def async_mock_foo_fn(state: dict) -> Tuple[Dict[str, Any], bool]:
"""Mock async agent input function."""
return mock_foo_fn(state)


def test_fn_agent() -> None:
"""Test function agent."""
agent = FnAgentWorker(fn=mock_foo_fn, initial_state={"max_count": 5}).as_agent()
Expand All @@ -34,3 +43,32 @@ def test_fn_agent() -> None:
with pytest.raises(ValueError):
agent = FnAgentWorker(fn=mock_foo_fn).as_agent()
response = agent.query("hello")


def test_fn_agent_init() -> None:
"""Test function agent init."""
with pytest.raises(ValueError) as error_info:
FnAgentWorker(fn=mock_foo_fn_no_state_param).as_agent()

assert (
str(error_info.value)
== "StatefulFnComponent must have 'state' as required parameters"
)


@pytest.mark.asyncio()
async def test_run_async_step() -> None:
"""Test run async step."""
agent_without_async_fn = FnAgentWorker(
fn=mock_foo_fn, asnc_fn=None, initial_state={"max_count": 5}
).as_agent()

response = await agent_without_async_fn.aquery("hello")
assert str(response) == "hello:foo:foo:foo:foo:foo"

agent_with_async_fn = FnAgentWorker(
fn=mock_foo_fn, async_fn=async_mock_foo_fn, initial_state={"max_count": 5}
).as_agent()

response = await agent_with_async_fn.aquery("hello")
assert str(response) == "hello:foo:foo:foo:foo:foo"

0 comments on commit 2afa325

Please sign in to comment.