diff --git a/django_ai_assistant/api/views.py b/django_ai_assistant/api/views.py index 10eb0d6..876f27f 100644 --- a/django_ai_assistant/api/views.py +++ b/django_ai_assistant/api/views.py @@ -5,6 +5,7 @@ from langchain_core.messages import message_to_dict from ninja import NinjaAPI from ninja.operation import Operation +from ninja.security import django_auth from django_ai_assistant import package_name, version from django_ai_assistant.api.schemas import ( @@ -26,7 +27,14 @@ def get_openapi_operation_id(self, operation: Operation) -> str: return (package_name + "_" + name).replace(".", "_") -api = API(title=package_name, version=version, urls_namespace="django_ai_assistant") +api = API( + title=package_name, + version=version, + urls_namespace="django_ai_assistant", + # Add auth to all endpoints + auth=django_auth, + csrf=True, +) @api.exception_handler(AIUserNotAllowedError) diff --git a/example/assets/js/components/Chat/Chat.tsx b/example/assets/js/components/Chat/Chat.tsx index 60c30c7..2606c90 100644 --- a/example/assets/js/components/Chat/Chat.tsx +++ b/example/assets/js/components/Chat/Chat.tsx @@ -93,12 +93,7 @@ function ChatMessageList({ deleteMessage, }: { messages: ThreadMessagesSchemaOut[]; - deleteMessage: ({ - threadId, - messageId, - }: { - messageId: string; - }) => Promise; + deleteMessage: ({ messageId }: { messageId: string }) => Promise; }) { if (messages.length === 0) { return No messages.; diff --git a/frontend/openapi_schema.json b/frontend/openapi_schema.json index 75f5898..1adc83d 100644 --- a/frontend/openapi_schema.json +++ b/frontend/openapi_schema.json @@ -26,7 +26,12 @@ } } } - } + }, + "security": [ + { + "SessionAuth": [] + } + ] } }, "/assistants/{assistant_id}/": { @@ -55,7 +60,12 @@ } } } - } + }, + "security": [ + { + "SessionAuth": [] + } + ] } }, "/threads/": { @@ -78,7 +88,12 @@ } } } - } + }, + "security": [ + { + "SessionAuth": [] + } + ] }, "post": { "operationId": "django_ai_assistant_create_thread", @@ -105,7 +120,12 @@ } }, "required": true - } + }, + "security": [ + { + "SessionAuth": [] + } + ] } }, "/threads/{thread_id}/": { @@ -134,7 +154,12 @@ } } } - } + }, + "security": [ + { + "SessionAuth": [] + } + ] }, "patch": { "operationId": "django_ai_assistant_update_thread", @@ -171,7 +196,12 @@ } }, "required": true - } + }, + "security": [ + { + "SessionAuth": [] + } + ] }, "delete": { "operationId": "django_ai_assistant_delete_thread", @@ -191,7 +221,12 @@ "204": { "description": "No Content" } - } + }, + "security": [ + { + "SessionAuth": [] + } + ] } }, "/threads/{thread_id}/messages/": { @@ -224,7 +259,12 @@ } } } - } + }, + "security": [ + { + "SessionAuth": [] + } + ] }, "post": { "operationId": "django_ai_assistant_create_thread_message", @@ -254,7 +294,12 @@ } }, "required": true - } + }, + "security": [ + { + "SessionAuth": [] + } + ] } }, "/threads/{thread_id}/messages/{message_id}/": { @@ -285,7 +330,12 @@ "204": { "description": "No Content" } - } + }, + "security": [ + { + "SessionAuth": [] + } + ] } } }, @@ -414,6 +464,13 @@ "title": "ThreadMessagesSchemaIn", "type": "object" } + }, + "securitySchemes": { + "SessionAuth": { + "type": "apiKey", + "in": "cookie", + "name": "sessionid" + } } }, "servers": [] diff --git a/frontend/package-lock.json b/frontend/package-lock.json index c0d20ec..156a272 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -9,12 +9,14 @@ "version": "0.0.1", "license": "MIT", "dependencies": { - "axios": "^1.7.2" + "axios": "^1.7.2", + "cookie": "^0.6.0" }, "devDependencies": { "@hey-api/openapi-ts": "^0.46.3", "@testing-library/dom": "^10.1.0", "@testing-library/react": "^16.0.0", + "@types/cookie": "^0.6.0", "@types/jest": "^29.5.12", "@types/node": "^20.14.1", "@types/react": "^18.3.3", @@ -2813,6 +2815,12 @@ "@babel/types": "^7.20.7" } }, + "node_modules/@types/cookie": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/@types/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==", + "dev": true + }, "node_modules/@types/estree": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", @@ -4166,6 +4174,14 @@ "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==", "dev": true }, + "node_modules/cookie": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "engines": { + "node": ">= 0.6" + } + }, "node_modules/core-util-is": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.2.tgz", diff --git a/frontend/package.json b/frontend/package.json index 321b182..81d9036 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -42,7 +42,8 @@ "generate-client": "openapi-ts" }, "dependencies": { - "axios": "^1.7.2" + "axios": "^1.7.2", + "cookie": "^0.6.0" }, "peerDependencies": { "react": "^18.3.1", @@ -52,6 +53,7 @@ "@hey-api/openapi-ts": "^0.46.3", "@testing-library/dom": "^10.1.0", "@testing-library/react": "^16.0.0", + "@types/cookie": "^0.6.0", "@types/jest": "^29.5.12", "@types/node": "^20.14.1", "@types/react": "^18.3.3", diff --git a/frontend/src/config.ts b/frontend/src/config.ts index 0437273..a82dc96 100644 --- a/frontend/src/config.ts +++ b/frontend/src/config.ts @@ -1,9 +1,14 @@ +import cookie from "cookie"; + import { OpenAPI } from "./client"; +import { AxiosRequestConfig } from "axios"; /** * Configures the base URL for the AI Assistant API which is path associated with * the Django include. * + * Configures the Axios request to include the CSRF token if it exists. + * * @param baseURL Base URL of the AI Assistant API. * * @example @@ -11,4 +16,12 @@ import { OpenAPI } from "./client"; */ export function configAIAssistant({ baseURL }: { baseURL: string }) { OpenAPI.BASE = baseURL; + + OpenAPI.interceptors.request.use((request: AxiosRequestConfig) => { + const { csrftoken } = cookie.parse(document.cookie); + if (request.headers && csrftoken) { + request.headers["X-CSRFTOKEN"] = csrftoken; + } + return request; + }); } diff --git a/tests/test_views.py b/tests/test_views.py index 701739e..4bee354 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -50,8 +50,9 @@ def authenticated_client(client): # Assistant Views -def test_list_assistants_with_results(client): - response = client.get(reverse("django_ai_assistant:assistants_list")) +@pytest.mark.django_db() +def test_list_assistants_with_results(authenticated_client): + response = authenticated_client.get(reverse("django_ai_assistant:assistants_list")) assert response.status_code == HTTPStatus.OK assert response.json() == [{"id": "temperature_assistant", "name": "Temperature Assistant"}] @@ -62,8 +63,9 @@ def test_does_not_list_assistants_if_unauthorized(): pass -def test_get_assistant_that_exists(client): - response = client.get( +@pytest.mark.django_db() +def test_get_assistant_that_exists(authenticated_client): + response = authenticated_client.get( reverse( "django_ai_assistant:assistant_detail", kwargs={"assistant_id": "temperature_assistant"} ) @@ -73,9 +75,10 @@ def test_get_assistant_that_exists(client): assert response.json() == {"id": "temperature_assistant", "name": "Temperature Assistant"} -def test_get_assistant_that_does_not_exist(client): +@pytest.mark.django_db() +def test_get_assistant_that_does_not_exist(authenticated_client): with pytest.raises(AIAssistantNotDefinedError): - client.get( + authenticated_client.get( reverse( "django_ai_assistant:assistant_detail", kwargs={"assistant_id": "fake_assistant"} )