From 984aba592861ba657a8ce20c6d0ebd617d31249c Mon Sep 17 00:00:00 2001
From: Pamella Bezerra <pamella@vinta.com.br>
Date: Wed, 19 Jun 2024 10:11:54 -0300
Subject: [PATCH] Implement AIAssistant.__init_subclass__, and adjust tests

---
 django_ai_assistant/__init__.py           |   1 +
 django_ai_assistant/helpers/assistants.py |  43 +++++++--
 django_ai_assistant/helpers/use_cases.py  |   8 +-
 tests/test_assistants.py                  | 111 ++++++++++++----------
 tests/test_views.py                       |  56 ++++++-----
 5 files changed, 134 insertions(+), 85 deletions(-)

diff --git a/django_ai_assistant/__init__.py b/django_ai_assistant/__init__.py
index 98f641e..1b44f39 100644
--- a/django_ai_assistant/__init__.py
+++ b/django_ai_assistant/__init__.py
@@ -2,6 +2,7 @@
 
 from django_ai_assistant.helpers.assistants import (  # noqa
     AIAssistant,
+    get_assistant_cls,
     get_assistant_cls_registry,
 )
 from django_ai_assistant.langchain.tools import (  # noqa
diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py
index 910d528..ff9e7a1 100644
--- a/django_ai_assistant/helpers/assistants.py
+++ b/django_ai_assistant/helpers/assistants.py
@@ -55,6 +55,8 @@ class AIAssistant(abc.ABC):  # noqa: F821
     _init_kwargs: dict[str, Any]
     _method_tools: Sequence[BaseTool]
 
+    _registry: ClassVar[dict[str, type["AIAssistant"]]] = {}
+
     def __init__(self, *, user=None, request=None, view=None, **kwargs):
         if not hasattr(self, "id"):
             raise AIAssistantMisconfiguredError(
@@ -80,6 +82,27 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs):
 
         self._set_method_tools()
 
+    def __init_subclass__(cls, **kwargs):
+        """
+        Called when a class is subclassed from AIAssistant.
+
+        This method is automatically invoked when a new subclass of AIAssistant
+        is created. It allows AIAssistant to perform additional setup or configuration
+        for the subclass, such as registering the subclass in a registry.
+
+        Args:
+            cls (type): The newly created subclass.
+            **kwargs: Additional keyword arguments passed during subclass creation.
+        """
+        super().__init_subclass__(**kwargs)
+        if hasattr(cls, "id") and cls.id is not None:
+            if not re.match(r"^[a-zA-Z0-9_-]+$", cls.id):
+                raise AIAssistantMisconfiguredError(
+                    f"Assistant id '{cls.id}' does not match the pattern '^[a-zA-Z0-9_-]+$'"
+                    f"at {cls.__name__}"
+                )
+            cls._registry[cls.id] = cls
+
     def _set_method_tools(self):
         # Find tool methods (decorated with `@method_tool` from django_ai_assistant/tools.py):
         members = inspect.getmembers(
@@ -114,12 +137,12 @@ def _set_method_tools(self):
         self._method_tools = tools
 
     @classmethod
-    def _get_assistant_cls_registry(cls: type["AIAssistant"]) -> dict[str, type["AIAssistant"]]:
-        registry: dict[str, type["AIAssistant"]] = {}
-        for subclass in cls.__subclasses__():
-            registry[subclass.id] = subclass
-            registry.update(subclass._get_assistant_cls_registry())
-        return registry
+    def get_registry(cls):
+        return cls._registry
+
+    @classmethod
+    def clear_registry(cls):
+        cls._registry.clear()
 
     def get_name(self):
         return self.name
@@ -306,4 +329,10 @@ def as_tool(self, description) -> BaseTool:
 
 
 def get_assistant_cls_registry() -> dict[str, type[AIAssistant]]:
-    return AIAssistant._get_assistant_cls_registry()
+    """Get the registry of AIAssistant classes."""
+    return AIAssistant.get_registry()
+
+
+def get_assistant_cls(assistant_id: str) -> type[AIAssistant]:
+    """Get the AIAssistant class for the given assistant ID."""
+    return AIAssistant.get_registry()[assistant_id]
diff --git a/django_ai_assistant/helpers/use_cases.py b/django_ai_assistant/helpers/use_cases.py
index 8522b46..f383791 100644
--- a/django_ai_assistant/helpers/use_cases.py
+++ b/django_ai_assistant/helpers/use_cases.py
@@ -8,7 +8,7 @@
     AIAssistantNotDefinedError,
     AIUserNotAllowedError,
 )
-from django_ai_assistant.helpers.assistants import get_assistant_cls_registry
+from django_ai_assistant.helpers import assistants
 from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory
 from django_ai_assistant.models import Message, Thread
 from django_ai_assistant.permissions import (
@@ -25,9 +25,9 @@ def get_assistant_cls(
     user: Any,
     request: HttpRequest | None = None,
 ):
-    if assistant_id not in get_assistant_cls_registry():
+    if assistant_id not in assistants.get_assistant_cls_registry():
         raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
-    assistant_cls = get_assistant_cls_registry()[assistant_id]
+    assistant_cls = assistants.get_assistant_cls(assistant_id)
     if not can_run_assistant(
         assistant_cls=assistant_cls,
         user=user,
@@ -56,7 +56,7 @@ def get_assistants_info(
 ):
     return [
         get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
-        for assistant_id in get_assistant_cls_registry().keys()
+        for assistant_id in assistants.get_assistant_cls_registry().keys()
     ]
 
 
diff --git a/tests/test_assistants.py b/tests/test_assistants.py
index d634525..d7d4283 100644
--- a/tests/test_assistants.py
+++ b/tests/test_assistants.py
@@ -5,33 +5,72 @@
 from langchain_core.messages import AIMessage, HumanMessage, messages_to_dict
 from langchain_core.retrievers import BaseRetriever
 
-from django_ai_assistant.helpers.assistants import AIAssistant
+from django_ai_assistant.helpers.assistants import AIAssistant, get_assistant_cls
 from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool
 from django_ai_assistant.models import Thread
 
 
-class TemperatureAssistant(AIAssistant):
-    id = "temperature_assistant"  # noqa: A003
-    name = "Temperature Assistant"
-    instructions = "You are a temperature bot."
-    model = "gpt-4o"
+@pytest.fixture(scope="module", autouse=True)
+def setup_assistants():
+    # Clear the registry before the tests in the module
+    AIAssistant.clear_registry()
 
-    def get_instructions(self):
-        return self.instructions + " Today is 2024-06-09."
+    # Define the assistant class inside the fixture to ensure registration
+    class TemperatureAssistant(AIAssistant):
+        id = "temperature_assistant"  # noqa: A003
+        name = "Temperature Assistant"
+        instructions = "You are a temperature bot."
+        model = "gpt-4o"
+
+        def get_instructions(self):
+            return self.instructions + " Today is 2024-06-09."
 
-    @method_tool
-    def fetch_current_temperature(self, location: str) -> str:
-        """Fetch the current temperature data for a location"""
-        return "32 degrees Celsius"
+        @method_tool
+        def fetch_current_temperature(self, location: str) -> str:
+            """Fetch the current temperature data for a location"""
+            return "32 degrees Celsius"
+
+        class FetchForecastTemperatureInput(BaseModel):
+            location: str
+            dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")
+
+        @method_tool(args_schema=FetchForecastTemperatureInput)
+        def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
+            """Fetch the forecast temperature data for a location"""
+            return "35 degrees Celsius"
+
+    class TourGuideAssistant(AIAssistant):
+        id = "tour_guide_assistant"  # noqa: A003
+        name = "Tour Guide Assistant"
+        instructions = (
+            "You are a tour guide assistant offers information about nearby attractions. "
+            "The user is at a location and wants to know what to learn about nearby attractions. "
+            "Use the following pieces of context to suggest nearby attractions to the user. "
+            "If there are no interesting attractions nearby, "
+            "tell the user there's nothing to see where they're at. "
+            "Use three sentences maximum and keep your suggestions concise."
+            "\n\n"
+            "---START OF CONTEXT---\n"
+            "{context}"
+            "---END OF CONTEXT---\n"
+        )
+        model = "gpt-4o"
+        has_rag = True
 
-    class FetchForecastTemperatureInput(BaseModel):
-        location: str
-        dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")
+        def get_retriever(self) -> BaseRetriever:
+            return SequentialRetriever(
+                sequential_responses=[
+                    [
+                        Document(page_content="Central Park"),
+                        Document(page_content="American Museum of Natural History"),
+                    ],
+                    [Document(page_content="Museum of Modern Art")],
+                ]
+            )
 
-    @method_tool(args_schema=FetchForecastTemperatureInput)
-    def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
-        """Fetch the forecast temperature data for a location"""
-        return "35 degrees Celsius"
+    yield
+    # Clear the registry after the tests in the module
+    AIAssistant.clear_registry()
 
 
 @pytest.mark.django_db(transaction=True)
@@ -39,7 +78,7 @@ def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
 def test_AIAssistant_invoke():
     thread = Thread.objects.create(name="Recife Temperature Chat")
 
-    assistant = TemperatureAssistant()
+    assistant = get_assistant_cls("temperature_assistant")()
     response_0 = assistant.invoke(
         {"input": "What is the temperature today in Recife?"},
         thread_id=thread.id,
@@ -101,42 +140,12 @@ async def _aget_relevant_documents(self, query: str) -> List[Document]:
         return self._get_relevant_documents(query)
 
 
-class TourGuideAssistant(AIAssistant):
-    id = "tour_guide_assistant"  # noqa: A003
-    name = "Tour Guide Assistant"
-    instructions = (
-        "You are a tour guide assistant offers information about nearby attractions. "
-        "The user is at a location and wants to know what to learn about nearby attractions. "
-        "Use the following pieces of context to suggest nearby attractions to the user. "
-        "If there are no interesting attractions nearby, "
-        "tell the user there's nothing to see where they're at. "
-        "Use three sentences maximum and keep your suggestions concise."
-        "\n\n"
-        "---START OF CONTEXT---\n"
-        "{context}"
-        "---END OF CONTEXT---\n"
-    )
-    model = "gpt-4o"
-    has_rag = True
-
-    def get_retriever(self) -> BaseRetriever:
-        return SequentialRetriever(
-            sequential_responses=[
-                [
-                    Document(page_content="Central Park"),
-                    Document(page_content="American Museum of Natural History"),
-                ],
-                [Document(page_content="Museum of Modern Art")],
-            ]
-        )
-
-
 @pytest.mark.django_db(transaction=True)
 @pytest.mark.vcr
 def test_AIAssistant_with_rag_invoke():
     thread = Thread.objects.create(name="Tour Guide Chat")
 
-    assistant = TourGuideAssistant()
+    assistant = get_assistant_cls("tour_guide_assistant")()
     response_0 = assistant.invoke(
         {"input": "I'm at Central Park W & 79st, New York, NY 10024, United States."},
         thread_id=thread.id,
diff --git a/tests/test_views.py b/tests/test_views.py
index 31a5dd8..7cec6c5 100644
--- a/tests/test_views.py
+++ b/tests/test_views.py
@@ -10,29 +10,39 @@
 # Set up
 
 
-class TemperatureAssistant(AIAssistant):
-    id = "temperature_assistant"  # noqa: A003
-    name = "Temperature Assistant"
-    description = "A temperature assistant that provides temperature information."
-    instructions = "You are a temperature bot."
-    model = "gpt-4o"
-
-    def get_instructions(self):
-        return self.instructions + " Today is 2024-06-09."
-
-    @method_tool
-    def fetch_current_temperature(self, location: str) -> str:
-        """Fetch the current temperature data for a location"""
-        return "32 degrees Celsius"
-
-    class FetchForecastTemperatureInput(BaseModel):
-        location: str
-        dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")
-
-    @method_tool(args_schema=FetchForecastTemperatureInput)
-    def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
-        """Fetch the forecast temperature data for a location"""
-        return "35 degrees Celsius"
+@pytest.fixture(scope="module", autouse=True)
+def setup_assistants():
+    # Clear the registry before the tests in the module
+    AIAssistant.clear_registry()
+
+    # Define the assistant class inside the fixture to ensure registration
+    class TemperatureAssistant(AIAssistant):
+        id = "temperature_assistant"  # noqa: A003
+        name = "Temperature Assistant"
+        description = "A temperature assistant that provides temperature information."
+        instructions = "You are a temperature bot."
+        model = "gpt-4o"
+
+        def get_instructions(self):
+            return self.instructions + " Today is 2024-06-09."
+
+        @method_tool
+        def fetch_current_temperature(self, location: str) -> str:
+            """Fetch the current temperature data for a location"""
+            return "32 degrees Celsius"
+
+        class FetchForecastTemperatureInput(BaseModel):
+            location: str
+            dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")
+
+        @method_tool(args_schema=FetchForecastTemperatureInput)
+        def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
+            """Fetch the forecast temperature data for a location"""
+            return "35 degrees Celsius"
+
+    yield
+    # Clear the registry after the tests in the module
+    AIAssistant.clear_registry()
 
 
 # Assistant Views