Skip to content

Commit

Permalink
Implement AIAssistant.__init_subclass__, and adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pamella committed Jun 19, 2024
1 parent 9c944eb commit 984aba5
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 85 deletions.
1 change: 1 addition & 0 deletions django_ai_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django_ai_assistant.helpers.assistants import ( # noqa
AIAssistant,
get_assistant_cls,
get_assistant_cls_registry,
)
from django_ai_assistant.langchain.tools import ( # noqa
Expand Down
43 changes: 36 additions & 7 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class AIAssistant(abc.ABC): # noqa: F821
_init_kwargs: dict[str, Any]
_method_tools: Sequence[BaseTool]

_registry: ClassVar[dict[str, type["AIAssistant"]]] = {}

def __init__(self, *, user=None, request=None, view=None, **kwargs):
if not hasattr(self, "id"):
raise AIAssistantMisconfiguredError(
Expand All @@ -80,6 +82,27 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs):

self._set_method_tools()

def __init_subclass__(cls, **kwargs):
"""
Called when a class is subclassed from AIAssistant.
This method is automatically invoked when a new subclass of AIAssistant
is created. It allows AIAssistant to perform additional setup or configuration
for the subclass, such as registering the subclass in a registry.
Args:
cls (type): The newly created subclass.
**kwargs: Additional keyword arguments passed during subclass creation.
"""
super().__init_subclass__(**kwargs)
if hasattr(cls, "id") and cls.id is not None:
if not re.match(r"^[a-zA-Z0-9_-]+$", cls.id):
raise AIAssistantMisconfiguredError(
f"Assistant id '{cls.id}' does not match the pattern '^[a-zA-Z0-9_-]+$'"
f"at {cls.__name__}"
)
cls._registry[cls.id] = cls

def _set_method_tools(self):
# Find tool methods (decorated with `@method_tool` from django_ai_assistant/tools.py):
members = inspect.getmembers(
Expand Down Expand Up @@ -114,12 +137,12 @@ def _set_method_tools(self):
self._method_tools = tools

@classmethod
def _get_assistant_cls_registry(cls: type["AIAssistant"]) -> dict[str, type["AIAssistant"]]:
registry: dict[str, type["AIAssistant"]] = {}
for subclass in cls.__subclasses__():
registry[subclass.id] = subclass
registry.update(subclass._get_assistant_cls_registry())
return registry
def get_registry(cls):
return cls._registry

@classmethod
def clear_registry(cls):
cls._registry.clear()

def get_name(self):
return self.name
Expand Down Expand Up @@ -306,4 +329,10 @@ def as_tool(self, description) -> BaseTool:


def get_assistant_cls_registry() -> dict[str, type[AIAssistant]]:
return AIAssistant._get_assistant_cls_registry()
"""Get the registry of AIAssistant classes."""
return AIAssistant.get_registry()


def get_assistant_cls(assistant_id: str) -> type[AIAssistant]:
"""Get the AIAssistant class for the given assistant ID."""
return AIAssistant.get_registry()[assistant_id]
8 changes: 4 additions & 4 deletions django_ai_assistant/helpers/use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
AIAssistantNotDefinedError,
AIUserNotAllowedError,
)
from django_ai_assistant.helpers.assistants import get_assistant_cls_registry
from django_ai_assistant.helpers import assistants
from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory
from django_ai_assistant.models import Message, Thread
from django_ai_assistant.permissions import (
Expand All @@ -25,9 +25,9 @@ def get_assistant_cls(
user: Any,
request: HttpRequest | None = None,
):
if assistant_id not in get_assistant_cls_registry():
if assistant_id not in assistants.get_assistant_cls_registry():
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = get_assistant_cls_registry()[assistant_id]
assistant_cls = assistants.get_assistant_cls(assistant_id)
if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
Expand Down Expand Up @@ -56,7 +56,7 @@ def get_assistants_info(
):
return [
get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
for assistant_id in get_assistant_cls_registry().keys()
for assistant_id in assistants.get_assistant_cls_registry().keys()
]


Expand Down
111 changes: 60 additions & 51 deletions tests/test_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,80 @@
from langchain_core.messages import AIMessage, HumanMessage, messages_to_dict
from langchain_core.retrievers import BaseRetriever

from django_ai_assistant.helpers.assistants import AIAssistant
from django_ai_assistant.helpers.assistants import AIAssistant, get_assistant_cls
from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool
from django_ai_assistant.models import Thread


class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
instructions = "You are a temperature bot."
model = "gpt-4o"
@pytest.fixture(scope="module", autouse=True)
def setup_assistants():
# Clear the registry before the tests in the module
AIAssistant.clear_registry()

def get_instructions(self):
return self.instructions + " Today is 2024-06-09."
# Define the assistant class inside the fixture to ensure registration
class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
instructions = "You are a temperature bot."
model = "gpt-4o"

def get_instructions(self):
return self.instructions + " Today is 2024-06-09."

@method_tool
def fetch_current_temperature(self, location: str) -> str:
"""Fetch the current temperature data for a location"""
return "32 degrees Celsius"
@method_tool
def fetch_current_temperature(self, location: str) -> str:
"""Fetch the current temperature data for a location"""
return "32 degrees Celsius"

class FetchForecastTemperatureInput(BaseModel):
location: str
dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")

@method_tool(args_schema=FetchForecastTemperatureInput)
def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
"""Fetch the forecast temperature data for a location"""
return "35 degrees Celsius"

class TourGuideAssistant(AIAssistant):
id = "tour_guide_assistant" # noqa: A003
name = "Tour Guide Assistant"
instructions = (
"You are a tour guide assistant offers information about nearby attractions. "
"The user is at a location and wants to know what to learn about nearby attractions. "
"Use the following pieces of context to suggest nearby attractions to the user. "
"If there are no interesting attractions nearby, "
"tell the user there's nothing to see where they're at. "
"Use three sentences maximum and keep your suggestions concise."
"\n\n"
"---START OF CONTEXT---\n"
"{context}"
"---END OF CONTEXT---\n"
)
model = "gpt-4o"
has_rag = True

class FetchForecastTemperatureInput(BaseModel):
location: str
dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")
def get_retriever(self) -> BaseRetriever:
return SequentialRetriever(
sequential_responses=[
[
Document(page_content="Central Park"),
Document(page_content="American Museum of Natural History"),
],
[Document(page_content="Museum of Modern Art")],
]
)

@method_tool(args_schema=FetchForecastTemperatureInput)
def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
"""Fetch the forecast temperature data for a location"""
return "35 degrees Celsius"
yield
# Clear the registry after the tests in the module
AIAssistant.clear_registry()


@pytest.mark.django_db(transaction=True)
@pytest.mark.vcr
def test_AIAssistant_invoke():
thread = Thread.objects.create(name="Recife Temperature Chat")

assistant = TemperatureAssistant()
assistant = get_assistant_cls("temperature_assistant")()
response_0 = assistant.invoke(
{"input": "What is the temperature today in Recife?"},
thread_id=thread.id,
Expand Down Expand Up @@ -101,42 +140,12 @@ async def _aget_relevant_documents(self, query: str) -> List[Document]:
return self._get_relevant_documents(query)


class TourGuideAssistant(AIAssistant):
id = "tour_guide_assistant" # noqa: A003
name = "Tour Guide Assistant"
instructions = (
"You are a tour guide assistant offers information about nearby attractions. "
"The user is at a location and wants to know what to learn about nearby attractions. "
"Use the following pieces of context to suggest nearby attractions to the user. "
"If there are no interesting attractions nearby, "
"tell the user there's nothing to see where they're at. "
"Use three sentences maximum and keep your suggestions concise."
"\n\n"
"---START OF CONTEXT---\n"
"{context}"
"---END OF CONTEXT---\n"
)
model = "gpt-4o"
has_rag = True

def get_retriever(self) -> BaseRetriever:
return SequentialRetriever(
sequential_responses=[
[
Document(page_content="Central Park"),
Document(page_content="American Museum of Natural History"),
],
[Document(page_content="Museum of Modern Art")],
]
)


@pytest.mark.django_db(transaction=True)
@pytest.mark.vcr
def test_AIAssistant_with_rag_invoke():
thread = Thread.objects.create(name="Tour Guide Chat")

assistant = TourGuideAssistant()
assistant = get_assistant_cls("tour_guide_assistant")()
response_0 = assistant.invoke(
{"input": "I'm at Central Park W & 79st, New York, NY 10024, United States."},
thread_id=thread.id,
Expand Down
56 changes: 33 additions & 23 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,39 @@
# Set up


class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
description = "A temperature assistant that provides temperature information."
instructions = "You are a temperature bot."
model = "gpt-4o"

def get_instructions(self):
return self.instructions + " Today is 2024-06-09."

@method_tool
def fetch_current_temperature(self, location: str) -> str:
"""Fetch the current temperature data for a location"""
return "32 degrees Celsius"

class FetchForecastTemperatureInput(BaseModel):
location: str
dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")

@method_tool(args_schema=FetchForecastTemperatureInput)
def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
"""Fetch the forecast temperature data for a location"""
return "35 degrees Celsius"
@pytest.fixture(scope="module", autouse=True)
def setup_assistants():
# Clear the registry before the tests in the module
AIAssistant.clear_registry()

# Define the assistant class inside the fixture to ensure registration
class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
description = "A temperature assistant that provides temperature information."
instructions = "You are a temperature bot."
model = "gpt-4o"

def get_instructions(self):
return self.instructions + " Today is 2024-06-09."

@method_tool
def fetch_current_temperature(self, location: str) -> str:
"""Fetch the current temperature data for a location"""
return "32 degrees Celsius"

class FetchForecastTemperatureInput(BaseModel):
location: str
dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")

@method_tool(args_schema=FetchForecastTemperatureInput)
def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
"""Fetch the forecast temperature data for a location"""
return "35 degrees Celsius"

yield
# Clear the registry after the tests in the module
AIAssistant.clear_registry()


# Assistant Views
Expand Down

0 comments on commit 984aba5

Please sign in to comment.