Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure modules: separate use cases from API #87

Merged
merged 9 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion django_ai_assistant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
from importlib import metadata

from django_ai_assistant.helpers.assistants import ( # noqa
AIAssistant,
register_assistant,
)
from django_ai_assistant.langchain.tools import ( # noqa
BaseModel,
BaseTool,
Field,
StructuredTool,
Tool,
method_tool,
tool,
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added all the main public API here (in traditional API sense, not web API) he


__version__ = metadata.version(__package__)

version = __version__ = metadata.version(__package__)
package_name = __package__
2 changes: 1 addition & 1 deletion django_ai_assistant/admin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.contrib import admin

from .models import Message, Thread
from django_ai_assistant.models import Message, Thread


@admin.register(Thread)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ninja import Field, ModelSchema, Schema

from .models import Thread
from django_ai_assistant.models import Thread


class AssistantSchema(Schema):
Expand Down
53 changes: 28 additions & 25 deletions django_ai_assistant/views.py → django_ai_assistant/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,29 @@

from langchain_core.messages import message_to_dict
from ninja import NinjaAPI
from ninja.operation import Operation

from django_ai_assistant import __package__, __version__

from .exceptions import AIUserNotAllowedError
from .helpers import assistants
from .helpers.assistants import (
create_message,
get_assistants_info,
get_single_assistant_info,
get_single_thread,
get_thread_messages,
get_threads,
)
from .models import Message, Thread
from .schemas import (
from django_ai_assistant import package_name, version
from django_ai_assistant.api.schemas import (
AssistantSchema,
ThreadMessagesSchemaIn,
ThreadMessagesSchemaOut,
ThreadSchema,
ThreadSchemaIn,
)
from django_ai_assistant.exceptions import AIUserNotAllowedError
from django_ai_assistant.helpers import use_cases
from django_ai_assistant.models import Message, Thread


class API(NinjaAPI):
# Force "operationId" to be like "django_ai_assistant_delete_thread"
def get_openapi_operation_id(self, operation: Operation) -> str:
name = operation.view_func.__name__
return (package_name + "_" + name).replace(".", "_")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes us have better names at frontend/src/client/services.gen.ts

Like djangoAiAssistantListAssistants instead of djangoAiAssistantViewsListAssistants



api = NinjaAPI(title=__package__, version=__version__, urls_namespace="django_ai_assistant")
api = API(title=package_name, version=version, urls_namespace="django_ai_assistant")


@api.exception_handler(AIUserNotAllowedError)
Expand All @@ -41,42 +40,44 @@ def ai_user_not_allowed_handler(request, exc):

@api.get("assistants/", response=List[AssistantSchema], url_name="assistants_list")
def list_assistants(request):
return list(get_assistants_info(user=request.user, request=request))
return list(use_cases.get_assistants_info(user=request.user, request=request))


@api.get("assistants/{assistant_id}/", response=AssistantSchema, url_name="assistant_detail")
def get_assistant(request, assistant_id: str):
return get_single_assistant_info(assistant_id=assistant_id, user=request.user, request=request)
return use_cases.get_single_assistant_info(
assistant_id=assistant_id, user=request.user, request=request
)


@api.get("threads/", response=List[ThreadSchema], url_name="threads_list_create")
def list_threads(request):
return list(get_threads(user=request.user, request=request))
return list(use_cases.get_threads(user=request.user, request=request))


@api.post("threads/", response=ThreadSchema, url_name="threads_list_create")
def create_thread(request, payload: ThreadSchemaIn):
name = payload.name
return assistants.create_thread(name=name, user=request.user, request=request)
return use_cases.create_thread(name=name, user=request.user, request=request)


@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
def get_thread(request, thread_id: str):
thread = get_single_thread(thread_id=thread_id, user=request.user, request=request)
thread = use_cases.get_single_thread(thread_id=thread_id, user=request.user, request=request)
return thread


@api.patch("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
def update_thread(request, thread_id: str, payload: ThreadSchemaIn):
thread = get_object_or_404(Thread, id=thread_id)
name = payload.name
return assistants.update_thread(thread=thread, name=name, user=request.user, request=request)
return use_cases.update_thread(thread=thread, name=name, user=request.user, request=request)


@api.delete("threads/{thread_id}/", response={204: None}, url_name="thread_detail_update_delete")
def delete_thread(request, thread_id: str):
thread = get_object_or_404(Thread, id=thread_id)
assistants.delete_thread(thread=thread, user=request.user, request=request)
use_cases.delete_thread(thread=thread, user=request.user, request=request)
return 204, None


Expand All @@ -86,7 +87,9 @@ def delete_thread(request, thread_id: str):
url_name="messages_list_create",
)
def list_thread_messages(request, thread_id: str):
messages = get_thread_messages(thread_id=thread_id, user=request.user, request=request)
messages = use_cases.get_thread_messages(
thread_id=thread_id, user=request.user, request=request
)
return [message_to_dict(m)["data"] for m in messages]


Expand All @@ -99,7 +102,7 @@ def list_thread_messages(request, thread_id: str):
def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchemaIn):
thread = Thread.objects.get(id=thread_id)

create_message(
use_cases.create_message(
assistant_id=payload.assistant_id,
thread=thread,
user=request.user,
Expand All @@ -114,7 +117,7 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema
)
def delete_thread_message(request, thread_id: str, message_id: str):
message = get_object_or_404(Message, id=message_id, thread_id=thread_id)
assistants.delete_message(
use_cases.delete_message(
message=message,
user=request.user,
request=request,
Expand Down
172 changes: 5 additions & 167 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import re
from typing import Any, ClassVar, Sequence, cast

from django.http import HttpRequest

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages,
Expand All @@ -15,7 +13,6 @@
DEFAULT_DOCUMENT_SEPARATOR,
)
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
Expand All @@ -37,22 +34,11 @@
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI

from django_ai_assistant.ai.chat_message_histories import DjangoChatMessageHistory
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
AIAssistantNotDefinedError,
AIUserNotAllowedError,
)
from django_ai_assistant.models import Message, Thread
from django_ai_assistant.permissions import (
can_create_message,
can_create_thread,
can_delete_message,
can_delete_thread,
can_run_assistant,
)
from django_ai_assistant.tools import Tool
from django_ai_assistant.tools import tool as tool_decorator
from django_ai_assistant.langchain.tools import Tool
from django_ai_assistant.langchain.tools import tool as tool_decorator


class AIAssistant(abc.ABC): # noqa: F821
Expand Down Expand Up @@ -156,6 +142,9 @@ def get_prompt_template(self):
)

def get_message_history(self, thread_id: int | None):
# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere:
from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory

if thread_id is None:
return InMemoryChatMessageHistory()
return DjangoChatMessageHistory(thread_id)
Expand Down Expand Up @@ -309,154 +298,3 @@ def as_tool(self, description) -> BaseTool:
def register_assistant(cls: type[AIAssistant]):
ASSISTANT_CLS_REGISTRY[cls.id] = cls
return cls


Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to use_cases

def _get_assistant_cls(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
if assistant_id not in ASSISTANT_CLS_REGISTRY:
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id]
if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
request=request,
):
raise AIUserNotAllowedError("User is not allowed to use this assistant")
return assistant_cls


def get_single_assistant_info(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

return {
"id": assistant_id,
"name": assistant_cls.name,
}


def get_assistants_info(
user: Any,
request: HttpRequest | None = None,
):
return [
_get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
for assistant_id in ASSISTANT_CLS_REGISTRY.keys()
]


def create_message(
assistant_id: str,
thread: Thread,
user: Any,
content: Any,
request: HttpRequest | None = None,
):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

if not can_create_message(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")

# TODO: Check if we can separate the message creation from the chain invoke
assistant = assistant_cls(user=user, request=request)
assistant_message = assistant.invoke(
{"input": content},
thread_id=thread.id,
)
return assistant_message


def create_thread(
name: str,
user: Any,
request: HttpRequest | None = None,
):
if not can_create_thread(user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create threads")

thread = Thread.objects.create(name=name, created_by=user)
return thread


def get_single_thread(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
):
return Thread.objects.filter(created_by=user).get(id=thread_id)


def get_threads(
user: Any,
request: HttpRequest | None = None,
):
return list(Thread.objects.filter(created_by=user))


def update_thread(
thread: Thread,
name: str,
user: Any,
request: HttpRequest | None = None,
):
if not can_delete_thread(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to update this thread")

thread.name = name
thread.save()
return thread


def delete_thread(
thread: Thread,
user: Any,
request: HttpRequest | None = None,
):
if not can_delete_thread(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this thread")

return thread.delete()


def get_thread_messages(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
) -> list[BaseMessage]:
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
if user != thread.created_by:
raise AIUserNotAllowedError("User is not allowed to view messages in this thread")

return DjangoChatMessageHistory(thread.id).get_messages()


def create_thread_message_as_user(
thread_id: str,
content: str,
user: Any,
request: HttpRequest | None = None,
):
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
if user != thread.created_by:
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")

DjangoChatMessageHistory(thread.id).add_messages([HumanMessage(content=content)])


def delete_message(
message: Message,
user: Any,
request: HttpRequest | None = None,
):
if not can_delete_message(message=message, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this message")

return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)])
Loading