Skip to content

Commit

Permalink
Merge branch 'main' into feat/api-django-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
pamella committed Jun 19, 2024
2 parents ac16357 + 59ff7ed commit 7b698fd
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 55 deletions.
52 changes: 25 additions & 27 deletions frontend/src/hooks/useMessageList.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@ import {
ThreadMessagesSchemaOut,
} from "../client";


function hasNullThreadId(threadId: string | null): threadId is null {
function hasNullThreadId(
threadId: string | null,
operation: string | null = "fetch"
): threadId is null {
if (threadId == null) {
console.warn("threadId is null or undefined. Ignoring fetch operation.");
console.warn(
`threadId is null or undefined. Ignoring ${operation} operation.`
);
return true;
}
return false;
}


/**
* React hook to manage the list, create, and delete of Messages.
*
Expand All @@ -39,23 +42,22 @@ export function useMessageList({ threadId }: { threadId: string | null }) {
*
* @returns - A promise that resolves with the fetched list of messages.
*/
const fetchMessages = useCallback(
async (): Promise<ThreadMessagesSchemaOut[] | null> => {
if (hasNullThreadId(threadId)) return null;
const fetchMessages = useCallback(async (): Promise<
ThreadMessagesSchemaOut[] | null
> => {
if (hasNullThreadId(threadId)) return null;

try {
setLoadingFetchMessages(true);
const fetchedMessages = await djangoAiAssistantListThreadMessages({
threadId: threadId,
});
setMessages(fetchedMessages);
return fetchedMessages;
} finally {
setLoadingFetchMessages(false);
}
},
[threadId]
);
try {
setLoadingFetchMessages(true);
const fetchedMessages = await djangoAiAssistantListThreadMessages({
threadId: threadId,
});
setMessages(fetchedMessages);
return fetchedMessages;
} finally {
setLoadingFetchMessages(false);
}
}, [threadId]);

/**
* Creates a new message in a thread.
Expand All @@ -72,7 +74,7 @@ export function useMessageList({ threadId }: { threadId: string | null }) {
assistantId: string;
messageTextValue: string;
}): Promise<void> => {
if (hasNullThreadId(threadId)) return;
if (hasNullThreadId(threadId, "create")) return;

try {
setLoadingCreateMessage(true);
Expand Down Expand Up @@ -100,12 +102,8 @@ export function useMessageList({ threadId }: { threadId: string | null }) {
* @param messageId The ID of the message to delete.
*/
const deleteMessage = useCallback(
async ({
messageId,
}: {
messageId: string;
}): Promise<void> => {
if (hasNullThreadId(threadId)) return;
async ({ messageId }: { messageId: string }): Promise<void> => {
if (hasNullThreadId(threadId, "delete")) return;

try {
setLoadingDeleteMessage(true);
Expand Down
98 changes: 87 additions & 11 deletions frontend/tests/useMessageList.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,28 @@ describe("useMessageList", () => {
});

describe("fetchMessages", () => {
it("should not fetch messages if threadId is null", async () => {
const warnSpy = jest.spyOn(console, "warn").mockImplementation(() => {});

const { result } = renderHook(() => useMessageList({ threadId: null }));

expect(result.current.messages).toBeNull();
expect(result.current.loadingFetchMessages).toBe(false);

await act(async () => {
expect(await result.current.fetchMessages()).toBeNull();
});

expect(warnSpy).toHaveBeenCalledWith(
"threadId is null or undefined. Ignoring fetch operation."
);

expect(result.current.messages).toBeNull();
expect(result.current.loadingFetchMessages).toBe(false);

warnSpy.mockRestore();
});

it("should fetch messages and update state correctly", async () => {
(djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue(
mockMessages
Expand Down Expand Up @@ -86,6 +108,33 @@ describe("useMessageList", () => {
});

describe("createMessage", () => {
it("should not create message if threadId is null", async () => {
const warnSpy = jest.spyOn(console, "warn").mockImplementation(() => {});

const { result } = renderHook(() => useMessageList({ threadId: null }));

expect(result.current.messages).toBeNull();
expect(result.current.loadingCreateMessage).toBe(false);

await act(async () => {
expect(
await result.current.createMessage({
assistantId: "1",
messageTextValue: "Test message",
})
).toBeUndefined();
});

expect(warnSpy).toHaveBeenCalledWith(
"threadId is null or undefined. Ignoring create operation."
);

expect(result.current.messages).toBeNull();
expect(result.current.loadingCreateMessage).toBe(false);

warnSpy.mockRestore();
});

it("should create message and update state correctly", async () => {
const mockNewMessages = [
{
Expand All @@ -97,12 +146,13 @@ describe("useMessageList", () => {
content: "The current temperature in Recife is 30°C.",
},
];
(
djangoAiAssistantCreateThreadMessage as jest.Mock
).mockResolvedValue(null);
(djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue(
[...mockMessages, ...mockNewMessages]
(djangoAiAssistantCreateThreadMessage as jest.Mock).mockResolvedValue(
null
);
(djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue([
...mockMessages,
...mockNewMessages,
]);

const { result } = renderHook(() => useMessageList({ threadId: "1" }));

Expand All @@ -126,9 +176,9 @@ describe("useMessageList", () => {
});

it("should set loading to false if create fails", async () => {
(
djangoAiAssistantCreateThreadMessage as jest.Mock
).mockRejectedValue(new Error("Failed to create"));
(djangoAiAssistantCreateThreadMessage as jest.Mock).mockRejectedValue(
new Error("Failed to create")
);

const { result } = renderHook(() => useMessageList({ threadId: "1" }));

Expand All @@ -150,6 +200,32 @@ describe("useMessageList", () => {
});

describe("deleteMessage", () => {
it("should not delete message if threadId is null", async () => {
const warnSpy = jest.spyOn(console, "warn").mockImplementation(() => {});

const { result } = renderHook(() => useMessageList({ threadId: null }));

expect(result.current.messages).toBeNull();
expect(result.current.loadingCreateMessage).toBe(false);

await act(async () => {
expect(
await result.current.deleteMessage({
messageId: mockMessages[0].id,
})
).toBeUndefined();
});

expect(warnSpy).toHaveBeenCalledWith(
"threadId is null or undefined. Ignoring delete operation."
);

expect(result.current.messages).toBeNull();
expect(result.current.loadingDeleteMessage).toBe(false);

warnSpy.mockRestore();
});

it("should delete a message and update state correctly", async () => {
const deletedMessageId = mockMessages[0].id;
(djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue(
Expand Down Expand Up @@ -180,9 +256,9 @@ describe("useMessageList", () => {
(djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue(
mockMessages.filter((message) => message.id !== deletedMessageId)
);
(
djangoAiAssistantDeleteThreadMessage as jest.Mock
).mockRejectedValue(new Error("Failed to delete"));
(djangoAiAssistantDeleteThreadMessage as jest.Mock).mockRejectedValue(
new Error("Failed to delete")
);

const { result } = renderHook(() => useMessageList({ threadId: "1" }));

Expand Down
60 changes: 43 additions & 17 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from http import HTTPStatus

from django.contrib.auth.models import User
from django.urls import reverse

import pytest
from model_bakery import baker
Expand All @@ -19,7 +19,6 @@
class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
description = "A temperature assistant that provides temperature information."
instructions = "You are a temperature bot."
model = "gpt-4o"

Expand Down Expand Up @@ -53,7 +52,7 @@ def authenticated_client(client):

@pytest.mark.django_db()
def test_list_assistants_with_results(authenticated_client):
response = authenticated_client.get("/assistants/")
response = authenticated_client.get(reverse("django_ai_assistant:assistants_list"))

assert response.status_code == HTTPStatus.OK
assert response.json() == [{"id": "temperature_assistant", "name": "Temperature Assistant"}]
Expand All @@ -66,7 +65,11 @@ def test_does_not_list_assistants_if_unauthorized():

@pytest.mark.django_db()
def test_get_assistant_that_exists(authenticated_client):
response = authenticated_client.get("/assistants/temperature_assistant/")
response = authenticated_client.get(
reverse(
"django_ai_assistant:assistant_detail", kwargs={"assistant_id": "temperature_assistant"}
)
)

assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "temperature_assistant", "name": "Temperature Assistant"}
Expand All @@ -75,7 +78,11 @@ def test_get_assistant_that_exists(authenticated_client):
@pytest.mark.django_db()
def test_get_assistant_that_does_not_exist(authenticated_client):
with pytest.raises(AIAssistantNotDefinedError):
authenticated_client.get("/assistants/fake_assistant/")
authenticated_client.get(
reverse(
"django_ai_assistant:assistant_detail", kwargs={"assistant_id": "fake_assistant"}
)
)


def test_does_not_return_assistant_if_unauthorized():
Expand All @@ -90,7 +97,7 @@ def test_does_not_return_assistant_if_unauthorized():

@pytest.mark.django_db(transaction=True)
def test_list_threads_without_results(authenticated_client):
response = authenticated_client.get("/threads/")
response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create"))

assert response.status_code == HTTPStatus.OK
assert response.json() == []
Expand All @@ -99,22 +106,33 @@ def test_list_threads_without_results(authenticated_client):
@pytest.mark.django_db(transaction=True)
def test_list_threads_with_results(authenticated_client):
user = User.objects.first()
thread = baker.make(Thread, created_by=user)
response = authenticated_client.get("/threads/")
baker.make(Thread, created_by=user, _quantity=2)
response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create"))

assert response.status_code == HTTPStatus.OK
assert response.json()[0].get("id") == thread.id
assert len(response.json()) == 2


@pytest.mark.django_db(transaction=True)
def test_does_not_list_other_users_threads(authenticated_client):
baker.make(Thread)
response = authenticated_client.get("/threads/")
response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create"))

assert response.status_code == HTTPStatus.OK
assert response.json() == []


@pytest.mark.django_db(transaction=True)
def test_gets_specific_thread(authenticated_client):
thread = baker.make(Thread, created_by=User.objects.first())
response = authenticated_client.get(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id})
)

assert response.status_code == HTTPStatus.OK
assert response.json().get("id") == thread.id


def test_does_not_list_threads_if_unauthorized():
# TODO: Implement this test once permissions are in place
pass
Expand All @@ -126,7 +144,7 @@ def test_does_not_list_threads_if_unauthorized():
@pytest.mark.django_db(transaction=True)
def test_create_thread(authenticated_client):
response = authenticated_client.post(
"/threads/", data=json.dumps({}), content_type="application/json"
reverse("django_ai_assistant:threads_list_create"), data={}, content_type="application/json"
)

thread = Thread.objects.first()
Expand All @@ -147,24 +165,26 @@ def test_cannot_create_thread_if_unauthorized():
def test_update_thread(authenticated_client):
thread = baker.make(Thread, created_by=User.objects.first())
response = authenticated_client.patch(
f"/threads/{thread.id}/",
data=json.dumps({"name": "New name"}),
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}),
data={"name": "New name"},
content_type="application/json",
)

assert response.status_code == HTTPStatus.OK
assert Thread.objects.filter(id=thread.id).first().name == "New name"


@pytest.mark.django_db(transaction=True)
def test_cannot_update_other_users_threads(authenticated_client):
thread = baker.make(Thread)
response = authenticated_client.patch(
f"/threads/{thread.id}/",
data=json.dumps({"name": "New name"}),
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}),
data={"name": "New name"},
content_type="application/json",
)

assert response.status_code == HTTPStatus.FORBIDDEN
assert Thread.objects.filter(id=thread.id).first().name != "New name"


def test_cannot_update_thread_if_unauthorized():
Expand All @@ -178,17 +198,23 @@ def test_cannot_update_thread_if_unauthorized():
@pytest.mark.django_db(transaction=True)
def test_delete_thread(authenticated_client):
thread = baker.make(Thread, created_by=User.objects.first())
response = authenticated_client.delete(f"/threads/{thread.id}/")
response = authenticated_client.delete(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id})
)

assert response.status_code == HTTPStatus.NO_CONTENT
assert not Thread.objects.filter(id=thread.id).exists()


@pytest.mark.django_db(transaction=True)
def test_cannot_delete_other_users_threads(authenticated_client):
thread = baker.make(Thread)
response = authenticated_client.delete(f"/threads/{thread.id}/")
response = authenticated_client.delete(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id})
)

assert response.status_code == HTTPStatus.FORBIDDEN
assert Thread.objects.filter(id=thread.id).exists()


def test_cannot_delete_thread_if_unauthorized():
Expand Down

0 comments on commit 7b698fd

Please sign in to comment.