Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds tests for use_cases.py #102

Merged
merged 7 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions django_ai_assistant/helpers/use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)


def get_cls(
def get_assistant_cls(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
Expand All @@ -43,7 +43,7 @@ def get_single_assistant_info(
user: Any,
request: HttpRequest | None = None,
):
assistant_cls = get_cls(assistant_id, user, request)
assistant_cls = get_assistant_cls(assistant_id, user, request)

return {
"id": assistant_id,
Expand All @@ -56,7 +56,7 @@ def get_assistants_info(
request: HttpRequest | None = None,
):
return [
get_cls(assistant_id=assistant_id, user=user, request=request)
get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
for assistant_id in AIAssistant.get_cls_registry().keys()
]

Expand All @@ -68,7 +68,7 @@ def create_message(
content: Any,
request: HttpRequest | None = None,
):
assistant_cls = get_cls(assistant_id, user, request)
assistant_cls = get_assistant_cls(assistant_id, user, request)

if not can_create_message(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")
Expand Down
Empty file added tests/test_helpers/__init__.py
Empty file.
186 changes: 186 additions & 0 deletions tests/test_helpers/cassettes/test_use_cases/test_create_message.yaml

Large diffs are not rendered by default.

File renamed without changes.
344 changes: 344 additions & 0 deletions tests/test_helpers/test_use_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,344 @@
from django.contrib.auth.models import User

import pytest
from model_bakery import baker

from django_ai_assistant.exceptions import (
AIAssistantNotDefinedError,
AIUserNotAllowedError,
)
from django_ai_assistant.helpers import use_cases
from django_ai_assistant.helpers.assistants import AIAssistant
from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool
from django_ai_assistant.models import Message, Thread


# Set up


@pytest.fixture(scope="module", autouse=True)
def setup_assistants():
# Clear the registry before the tests in the module
AIAssistant.clear_cls_registry()

# Define the assistant class inside the fixture to ensure registration
class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
instructions = "You are a temperature bot."
model = "gpt-4o"

def get_instructions(self):
return self.instructions + " Today is 2024-06-09."

@method_tool
def fetch_current_temperature(self, location: str) -> str:
"""Fetch the current temperature data for a location"""
return "32 degrees Celsius"

class FetchForecastTemperatureInput(BaseModel):
location: str
dt_str: str = Field(description="Date in the format 'YYYY-MM-DD'")

@method_tool(args_schema=FetchForecastTemperatureInput)
def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
"""Fetch the forecast temperature data for a location"""
return "35 degrees Celsius"

yield
# Clear the registry after the tests in the module
AIAssistant.clear_cls_registry()


def fake_permission_func(**kwargs):
return False


@pytest.fixture()
def use_fake_permissions(settings):
settings.AI_ASSISTANT_CAN_RUN_ASSISTANT = (
"tests.test_helpers.test_use_cases.fake_permission_func"
)
settings.AI_ASSISTANT_CAN_CREATE_THREAD_FN = (
"tests.test_helpers.test_use_cases.fake_permission_func"
)


# Assistant tests


def test_get_assistant_cls_returns_assistant_cls():
assistant_id = "temperature_assistant"
user = User()

assistant_cls = use_cases.get_assistant_cls(assistant_id, user)

assert assistant_cls.id == assistant_id


def test_get_assistant_cls_raises_error_when_assistant_not_defined():
assistant_id = "not_defined"
user = User()

with pytest.raises(AIAssistantNotDefinedError) as exc_info:
use_cases.get_assistant_cls(assistant_id, user)

assert str(exc_info.value) == "Assistant with id=not_defined not found"


def test_get_assistant_cls_raises_error_when_user_not_allowed(use_fake_permissions):
assistant_id = "temperature_assistant"
user = User()

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.get_assistant_cls(assistant_id, user)

assert str(exc_info.value) == "User is not allowed to use this assistant"


def test_get_single_assistant_info_returns_info():
assistant_id = "temperature_assistant"
user = User()

info = use_cases.get_single_assistant_info(assistant_id, user)

assert info["id"] == "temperature_assistant"
assert info["name"] == "Temperature Assistant"


def test_get_single_assistant_info_raises_exception_when_assistant_not_defined():
assistant_id = "not_defined"
user = User()

with pytest.raises(AIAssistantNotDefinedError) as exc_info:
use_cases.get_single_assistant_info(assistant_id, user)

assert str(exc_info.value) == "Assistant with id=not_defined not found"


def test_get_single_assistant_info_raises_exception_when_user_not_allowed(use_fake_permissions):
assistant_id = "temperature_assistant"
user = User()

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.get_single_assistant_info(assistant_id, user)

assert str(exc_info.value) == "User is not allowed to use this assistant"


def test_get_assistants_info_returns_info():
user = User()

info = use_cases.get_assistants_info(user)

assert info[0].id == "temperature_assistant"
assert info[0].name == "Temperature Assistant"
assert len(info) == 1


# Message tests


@pytest.mark.django_db(transaction=True)
@pytest.mark.vcr
def test_create_message():
user = baker.make(User)
thread = baker.make(Thread, created_by=user)
response = use_cases.create_message(
"temperature_assistant",
thread,
user,
"Hello, will I have to use my umbrella in Lisbon tomorrow?",
)

assert response == {
"input": "Hello, will I have to use my umbrella in Lisbon tomorrow?",
"history": [],
"output": "The forecast for Lisbon tomorrow is 35°C, which is quite warm and unlikely to involve rain. You probably won't need an umbrella.",
}


@pytest.mark.django_db(transaction=True)
def test_create_message_raises_exception_when_user_not_allowed():
user = baker.make(User)
thread = baker.make(Thread)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.create_message(
"temperature_assistant",
thread,
user,
"Hello, will I have to use my umbrella in Lisbon tomorrow?",
)

assert str(exc_info.value) == "User is not allowed to create messages in this thread"


# Thread tests


@pytest.mark.django_db(transaction=True)
def test_create_thread():
user = baker.make(User)
response = use_cases.create_thread("My thread", user)

assert response.name == "My thread"
assert response.created_by == user


@pytest.mark.django_db(transaction=True)
def test_create_thread_raises_exception_when_user_not_allowed(use_fake_permissions):
user = baker.make(User)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.create_thread("My thread", user)

assert str(exc_info.value) == "User is not allowed to create threads"


@pytest.mark.django_db(transaction=True)
def test_get_single_thread():
user = baker.make(User)
thread = baker.make(Thread, created_by=user)
response = use_cases.get_single_thread(thread.id, user)

assert response == thread


@pytest.mark.django_db(transaction=True)
def test_get_single_thread_raises_exception_when_user_not_allowed():
user = baker.make(User)
thread = baker.make(Thread)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.get_single_thread(thread.id, user)

assert str(exc_info.value) == "User is not allowed to view this thread"


@pytest.mark.django_db(transaction=True)
def test_get_threads():
user = baker.make(User)
baker.make(Thread, created_by=user, _quantity=3)
response = use_cases.get_threads(user)

assert len(response) == 3


@pytest.mark.django_db(transaction=True)
def test_get_threads_does_not_list_other_users_threads():
user = baker.make(User)
baker.make(Thread, _quantity=3)
response = use_cases.get_threads(user)

assert len(response) == 0


@pytest.mark.django_db(transaction=True)
def test_update_thread():
user = baker.make(User)
thread = baker.make(Thread, created_by=user)
response = use_cases.update_thread(thread, "My updated thread", user)

assert response.name == "My updated thread"


@pytest.mark.django_db(transaction=True)
def test_update_thread_raises_exception_when_user_not_allowed():
user = baker.make(User)
thread = baker.make(Thread)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.update_thread(thread, "My updated thread", user)

assert str(exc_info.value) == "User is not allowed to update this thread"


@pytest.mark.django_db(transaction=True)
def test_delete_thread():
user = baker.make(User)
thread = baker.make(Thread, created_by=user)
use_cases.delete_thread(thread, user)

assert not Thread.objects.filter(id=thread.id).exists()


@pytest.mark.django_db(transaction=True)
def test_delete_thread_raises_exception_when_user_not_allowed():
user = baker.make(User)
thread = baker.make(Thread)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.delete_thread(thread, user)

assert str(exc_info.value) == "User is not allowed to delete this thread"


# Thread message tests


@pytest.mark.django_db(transaction=True)
def test_get_thread_messages():
user = baker.make(User)
thread = baker.make(Thread, created_by=user)
baker.make(
Message, message={"type": "human", "data": {"content": "hi"}}, thread=thread, _quantity=3
)
response = use_cases.get_thread_messages(thread.id, user)

assert len(response) == 3


@pytest.mark.django_db(transaction=True)
def test_get_thread_messages_raises_exception_when_user_not_allowed():
user = baker.make(User)
thread = baker.make(Thread)
baker.make(
Message, message={"type": "human", "data": {"content": "hi"}}, thread=thread, _quantity=3
)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.get_thread_messages(thread.id, user)

assert str(exc_info.value) == "User is not allowed to view messages in this thread"


@pytest.mark.django_db(transaction=True)
def test_create_thread_message_as_user():
user = baker.make(User)
thread = baker.make(Thread, created_by=user)
use_cases.create_thread_message_as_user(thread.id, "Hello, how are you?", user)

assert Message.objects.filter(thread=thread).count() == 1


@pytest.mark.django_db(transaction=True)
def test_create_thread_message_as_user_raises_exception_when_user_not_allowed():
user = baker.make(User)
thread = baker.make(Thread)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.create_thread_message_as_user(thread.id, "Hello, how are you?", user)

assert str(exc_info.value) == "User is not allowed to create messages in this thread"


@pytest.mark.django_db(transaction=True)
def test_delete_message():
user = baker.make(User)
thread = baker.make(Thread, created_by=user)
message = baker.make(Message, thread=thread)
use_cases.delete_message(message, user)

assert not Message.objects.filter(id=message.id).exists()


@pytest.mark.django_db(transaction=True)
def test_delete_message_raises_exception_when_user_not_allowed():
user = baker.make(User)
message = baker.make(Message)

with pytest.raises(AIUserNotAllowedError) as exc_info:
use_cases.delete_message(message, user)

assert str(exc_info.value) == "User is not allowed to delete this message"