From 984aba592861ba657a8ce20c6d0ebd617d31249c Mon Sep 17 00:00:00 2001 From: Pamella Bezerra <pamella@vinta.com.br> Date: Wed, 19 Jun 2024 10:11:54 -0300 Subject: [PATCH] Implement AIAssistant.__init_subclass__, and adjust tests --- django_ai_assistant/__init__.py | 1 + django_ai_assistant/helpers/assistants.py | 43 +++++++-- django_ai_assistant/helpers/use_cases.py | 8 +- tests/test_assistants.py | 111 ++++++++++++---------- tests/test_views.py | 56 ++++++----- 5 files changed, 134 insertions(+), 85 deletions(-) diff --git a/django_ai_assistant/__init__.py b/django_ai_assistant/__init__.py index 98f641e..1b44f39 100644 --- a/django_ai_assistant/__init__.py +++ b/django_ai_assistant/__init__.py @@ -2,6 +2,7 @@ from django_ai_assistant.helpers.assistants import ( # noqa AIAssistant, + get_assistant_cls, get_assistant_cls_registry, ) from django_ai_assistant.langchain.tools import ( # noqa diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index 910d528..ff9e7a1 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -55,6 +55,8 @@ class AIAssistant(abc.ABC): # noqa: F821 _init_kwargs: dict[str, Any] _method_tools: Sequence[BaseTool] + _registry: ClassVar[dict[str, type["AIAssistant"]]] = {} + def __init__(self, *, user=None, request=None, view=None, **kwargs): if not hasattr(self, "id"): raise AIAssistantMisconfiguredError( @@ -80,6 +82,27 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs): self._set_method_tools() + def __init_subclass__(cls, **kwargs): + """ + Called when a class is subclassed from AIAssistant. + + This method is automatically invoked when a new subclass of AIAssistant + is created. It allows AIAssistant to perform additional setup or configuration + for the subclass, such as registering the subclass in a registry. + + Args: + cls (type): The newly created subclass. + **kwargs: Additional keyword arguments passed during subclass creation. + """ + super().__init_subclass__(**kwargs) + if hasattr(cls, "id") and cls.id is not None: + if not re.match(r"^[a-zA-Z0-9_-]+$", cls.id): + raise AIAssistantMisconfiguredError( + f"Assistant id '{cls.id}' does not match the pattern '^[a-zA-Z0-9_-]+$'" + f"at {cls.__name__}" + ) + cls._registry[cls.id] = cls + def _set_method_tools(self): # Find tool methods (decorated with `@method_tool` from django_ai_assistant/tools.py): members = inspect.getmembers( @@ -114,12 +137,12 @@ def _set_method_tools(self): self._method_tools = tools @classmethod - def _get_assistant_cls_registry(cls: type["AIAssistant"]) -> dict[str, type["AIAssistant"]]: - registry: dict[str, type["AIAssistant"]] = {} - for subclass in cls.__subclasses__(): - registry[subclass.id] = subclass - registry.update(subclass._get_assistant_cls_registry()) - return registry + def get_registry(cls): + return cls._registry + + @classmethod + def clear_registry(cls): + cls._registry.clear() def get_name(self): return self.name @@ -306,4 +329,10 @@ def as_tool(self, description) -> BaseTool: def get_assistant_cls_registry() -> dict[str, type[AIAssistant]]: - return AIAssistant._get_assistant_cls_registry() + """Get the registry of AIAssistant classes.""" + return AIAssistant.get_registry() + + +def get_assistant_cls(assistant_id: str) -> type[AIAssistant]: + """Get the AIAssistant class for the given assistant ID.""" + return AIAssistant.get_registry()[assistant_id] diff --git a/django_ai_assistant/helpers/use_cases.py b/django_ai_assistant/helpers/use_cases.py index 8522b46..f383791 100644 --- a/django_ai_assistant/helpers/use_cases.py +++ b/django_ai_assistant/helpers/use_cases.py @@ -8,7 +8,7 @@ AIAssistantNotDefinedError, AIUserNotAllowedError, ) -from django_ai_assistant.helpers.assistants import get_assistant_cls_registry +from django_ai_assistant.helpers import assistants from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory from django_ai_assistant.models import Message, Thread from django_ai_assistant.permissions import ( @@ -25,9 +25,9 @@ def get_assistant_cls( user: Any, request: HttpRequest | None = None, ): - if assistant_id not in get_assistant_cls_registry(): + if assistant_id not in assistants.get_assistant_cls_registry(): raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found") - assistant_cls = get_assistant_cls_registry()[assistant_id] + assistant_cls = assistants.get_assistant_cls(assistant_id) if not can_run_assistant( assistant_cls=assistant_cls, user=user, @@ -56,7 +56,7 @@ def get_assistants_info( ): return [ get_assistant_cls(assistant_id=assistant_id, user=user, request=request) - for assistant_id in get_assistant_cls_registry().keys() + for assistant_id in assistants.get_assistant_cls_registry().keys() ] diff --git a/tests/test_assistants.py b/tests/test_assistants.py index d634525..d7d4283 100644 --- a/tests/test_assistants.py +++ b/tests/test_assistants.py @@ -5,33 +5,72 @@ from langchain_core.messages import AIMessage, HumanMessage, messages_to_dict from langchain_core.retrievers import BaseRetriever -from django_ai_assistant.helpers.assistants import AIAssistant +from django_ai_assistant.helpers.assistants import AIAssistant, get_assistant_cls from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool from django_ai_assistant.models import Thread -class TemperatureAssistant(AIAssistant): - id = "temperature_assistant" # noqa: A003 - name = "Temperature Assistant" - instructions = "You are a temperature bot." - model = "gpt-4o" +@pytest.fixture(scope="module", autouse=True) +def setup_assistants(): + # Clear the registry before the tests in the module + AIAssistant.clear_registry() - def get_instructions(self): - return self.instructions + " Today is 2024-06-09." + # Define the assistant class inside the fixture to ensure registration + class TemperatureAssistant(AIAssistant): + id = "temperature_assistant" # noqa: A003 + name = "Temperature Assistant" + instructions = "You are a temperature bot." + model = "gpt-4o" + + def get_instructions(self): + return self.instructions + " Today is 2024-06-09." - @method_tool - def fetch_current_temperature(self, location: str) -> str: - """Fetch the current temperature data for a location""" - return "32 degrees Celsius" + @method_tool + def fetch_current_temperature(self, location: str) -> str: + """Fetch the current temperature data for a location""" + return "32 degrees Celsius" + + class FetchForecastTemperatureInput(BaseModel): + location: str + dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'") + + @method_tool(args_schema=FetchForecastTemperatureInput) + def fetch_forecast_temperature(self, location: str, dt_str: str) -> str: + """Fetch the forecast temperature data for a location""" + return "35 degrees Celsius" + + class TourGuideAssistant(AIAssistant): + id = "tour_guide_assistant" # noqa: A003 + name = "Tour Guide Assistant" + instructions = ( + "You are a tour guide assistant offers information about nearby attractions. " + "The user is at a location and wants to know what to learn about nearby attractions. " + "Use the following pieces of context to suggest nearby attractions to the user. " + "If there are no interesting attractions nearby, " + "tell the user there's nothing to see where they're at. " + "Use three sentences maximum and keep your suggestions concise." + "\n\n" + "---START OF CONTEXT---\n" + "{context}" + "---END OF CONTEXT---\n" + ) + model = "gpt-4o" + has_rag = True - class FetchForecastTemperatureInput(BaseModel): - location: str - dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'") + def get_retriever(self) -> BaseRetriever: + return SequentialRetriever( + sequential_responses=[ + [ + Document(page_content="Central Park"), + Document(page_content="American Museum of Natural History"), + ], + [Document(page_content="Museum of Modern Art")], + ] + ) - @method_tool(args_schema=FetchForecastTemperatureInput) - def fetch_forecast_temperature(self, location: str, dt_str: str) -> str: - """Fetch the forecast temperature data for a location""" - return "35 degrees Celsius" + yield + # Clear the registry after the tests in the module + AIAssistant.clear_registry() @pytest.mark.django_db(transaction=True) @@ -39,7 +78,7 @@ def fetch_forecast_temperature(self, location: str, dt_str: str) -> str: def test_AIAssistant_invoke(): thread = Thread.objects.create(name="Recife Temperature Chat") - assistant = TemperatureAssistant() + assistant = get_assistant_cls("temperature_assistant")() response_0 = assistant.invoke( {"input": "What is the temperature today in Recife?"}, thread_id=thread.id, @@ -101,42 +140,12 @@ async def _aget_relevant_documents(self, query: str) -> List[Document]: return self._get_relevant_documents(query) -class TourGuideAssistant(AIAssistant): - id = "tour_guide_assistant" # noqa: A003 - name = "Tour Guide Assistant" - instructions = ( - "You are a tour guide assistant offers information about nearby attractions. " - "The user is at a location and wants to know what to learn about nearby attractions. " - "Use the following pieces of context to suggest nearby attractions to the user. " - "If there are no interesting attractions nearby, " - "tell the user there's nothing to see where they're at. " - "Use three sentences maximum and keep your suggestions concise." - "\n\n" - "---START OF CONTEXT---\n" - "{context}" - "---END OF CONTEXT---\n" - ) - model = "gpt-4o" - has_rag = True - - def get_retriever(self) -> BaseRetriever: - return SequentialRetriever( - sequential_responses=[ - [ - Document(page_content="Central Park"), - Document(page_content="American Museum of Natural History"), - ], - [Document(page_content="Museum of Modern Art")], - ] - ) - - @pytest.mark.django_db(transaction=True) @pytest.mark.vcr def test_AIAssistant_with_rag_invoke(): thread = Thread.objects.create(name="Tour Guide Chat") - assistant = TourGuideAssistant() + assistant = get_assistant_cls("tour_guide_assistant")() response_0 = assistant.invoke( {"input": "I'm at Central Park W & 79st, New York, NY 10024, United States."}, thread_id=thread.id, diff --git a/tests/test_views.py b/tests/test_views.py index 31a5dd8..7cec6c5 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -10,29 +10,39 @@ # Set up -class TemperatureAssistant(AIAssistant): - id = "temperature_assistant" # noqa: A003 - name = "Temperature Assistant" - description = "A temperature assistant that provides temperature information." - instructions = "You are a temperature bot." - model = "gpt-4o" - - def get_instructions(self): - return self.instructions + " Today is 2024-06-09." - - @method_tool - def fetch_current_temperature(self, location: str) -> str: - """Fetch the current temperature data for a location""" - return "32 degrees Celsius" - - class FetchForecastTemperatureInput(BaseModel): - location: str - dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'") - - @method_tool(args_schema=FetchForecastTemperatureInput) - def fetch_forecast_temperature(self, location: str, dt_str: str) -> str: - """Fetch the forecast temperature data for a location""" - return "35 degrees Celsius" +@pytest.fixture(scope="module", autouse=True) +def setup_assistants(): + # Clear the registry before the tests in the module + AIAssistant.clear_registry() + + # Define the assistant class inside the fixture to ensure registration + class TemperatureAssistant(AIAssistant): + id = "temperature_assistant" # noqa: A003 + name = "Temperature Assistant" + description = "A temperature assistant that provides temperature information." + instructions = "You are a temperature bot." + model = "gpt-4o" + + def get_instructions(self): + return self.instructions + " Today is 2024-06-09." + + @method_tool + def fetch_current_temperature(self, location: str) -> str: + """Fetch the current temperature data for a location""" + return "32 degrees Celsius" + + class FetchForecastTemperatureInput(BaseModel): + location: str + dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'") + + @method_tool(args_schema=FetchForecastTemperatureInput) + def fetch_forecast_temperature(self, location: str, dt_str: str) -> str: + """Fetch the forecast temperature data for a location""" + return "35 degrees Celsius" + + yield + # Clear the registry after the tests in the module + AIAssistant.clear_registry() # Assistant Views