Skip to content

Commit

Permalink
Sort tool methods by the order they appear in the source code since t…
Browse files Browse the repository at this point in the history
…his can be meaningful
  • Loading branch information
fjsj committed Jun 17, 2024
1 parent ce8fe42 commit a2722a9
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 13 deletions.
7 changes: 6 additions & 1 deletion django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,15 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs):
def _set_method_tools(self):
# Find tool methods (decorated with `@method_tool` from django_ai_assistant/tools.py):
members = inspect.getmembers(
self, predicate=lambda m: hasattr(m, "_is_tool") and m._is_tool
self,
predicate=lambda m: inspect.ismethod(m) and getattr(m, "_is_tool", False),
)
tool_methods = [m for _, m in members]

# Sort tool methods by the order they appear in the source code,
# since this can be meaningful:
tool_methods.sort(key=lambda m: inspect.getsourcelines(m)[1])

# Transform tool methods into tool objects:
tools = []
for method in tool_methods:
Expand Down
15 changes: 6 additions & 9 deletions example/movies/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_instructions(self):
# See: https://docs.djangoproject.com/en/5.0/topics/i18n/timezones/#default-time-zone-and-current-time-zone
# In a real application, you should use the user's timezone
current_date_str = timezone.now().date().isoformat()
user_backlog_str = self._get_movies_backlog()
user_backlog_str = self.get_movies_backlog()

return "\n".join(
[
Expand Down Expand Up @@ -115,7 +115,10 @@ def firecrawl_scrape_url(self, url: str) -> str:
)
return response["markdown"]

def _get_movies_backlog(self) -> str:
@method_tool
def get_movies_backlog(self) -> str:
"""Get what movies are on user's backlog."""

return (
"\n".join(
[
Expand All @@ -126,12 +129,6 @@ def _get_movies_backlog(self) -> str:
or "Empty"
)

@method_tool
def get_movies_backlog(self) -> str:
"""Get what movies are on user's backlog."""

return self._get_movies_backlog()

@method_tool
def add_movie_to_backlog(self, movie_name: str, imdb_url: str, imdb_rating: float) -> str:
"""Add a movie to user's backlog. Must pass the movie_name, imdb_url, and imdb_rating."""
Expand Down Expand Up @@ -165,4 +162,4 @@ def reorder_backlog(self, imdb_url_list: Sequence[str]) -> str:
"""Reorder movies in user's backlog."""

MovieBacklogItem.reorder_backlog(self._user, imdb_url_list)
return "Reordered movies in backlog. New backlog: \n" + self._get_movies_backlog()
return "Reordered movies in backlog. New backlog: \n" + self.get_movies_backlog()
1 change: 0 additions & 1 deletion example/rag/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
class DjangoDocsAssistant(AIAssistant):
id = "django_docs_assistant" # noqa: A003
name = "Django Docs Assistant"
description = "An assistant that answers questions related to Django web framework."
instructions = (
"You are an assistant for answering questions related to Django web framework. "
"Use the following pieces of retrieved context from Django's documentation to answer "
Expand Down
36 changes: 34 additions & 2 deletions tests/test_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
description = "A temperature assistant that provides temperature information."
instructions = "You are a temperature bot."
model = "gpt-4o"

Expand Down Expand Up @@ -105,7 +104,6 @@ async def _aget_relevant_documents(self, query: str) -> List[Document]:
class TourGuideAssistant(AIAssistant):
id = "tour_guide_assistant" # noqa: A003
name = "Tour Guide Assistant"
description = "A tour guide assistant that offers information about nearby attractions."
instructions = (
"You are a tour guide assistant offers information about nearby attractions. "
"The user is at a location and wants to know what to learn about nearby attractions. "
Expand Down Expand Up @@ -182,3 +180,37 @@ def test_AIAssistant_with_rag_invoke():
AIMessage(content=response_1["output"], id=messages_ids[3]),
]
)


@pytest.mark.django_db(transaction=True)
def test_AIAssistant_tool_order_same_as_declaration():
class FooAssistant(AIAssistant):
id = "foo_assistant" # noqa: A003
name = "Foo Assistant"
instructions = "You are a helpful assistant."
model = "gpt-4o"

@method_tool
def tool_d(self, foo: str, bar: float, baz: int, qux: str) -> str:
"""Tool D"""
return "DDD"

@method_tool
def tool_c(self, foo: str, bar: float, baz: int) -> str:
"""Tool C"""
return "CCC"

@method_tool
def tool_b(self, foo: str, bar: float) -> str:
"""Tool B"""
return "BBB"

@method_tool
def tool_a(self, foo: str) -> str:
"""Tool A"""
return "AAA"

assistant = FooAssistant()

assert hasattr(assistant, "_method_tools")
assert [t.name for t in assistant._method_tools] == ["tool_d", "tool_c", "tool_b", "tool_a"]

0 comments on commit a2722a9

Please sign in to comment.