-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restructure modules: separate use cases from API #87
Changes from all commits
22768b1
89930a9
5c7f9e1
ce8fe42
f59ce7a
eaa7d7a
a4e6f86
0f485da
dcb8dcb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,19 @@ | ||
from importlib import metadata | ||
|
||
from django_ai_assistant.helpers.assistants import ( # noqa | ||
AIAssistant, | ||
register_assistant, | ||
) | ||
from django_ai_assistant.langchain.tools import ( # noqa | ||
BaseModel, | ||
BaseTool, | ||
Field, | ||
StructuredTool, | ||
Tool, | ||
method_tool, | ||
tool, | ||
) | ||
|
||
__version__ = metadata.version(__package__) | ||
|
||
version = __version__ = metadata.version(__package__) | ||
package_name = __package__ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,30 +4,29 @@ | |
|
||
from langchain_core.messages import message_to_dict | ||
from ninja import NinjaAPI | ||
from ninja.operation import Operation | ||
|
||
from django_ai_assistant import __package__, __version__ | ||
|
||
from .exceptions import AIUserNotAllowedError | ||
from .helpers import assistants | ||
from .helpers.assistants import ( | ||
create_message, | ||
get_assistants_info, | ||
get_single_assistant_info, | ||
get_single_thread, | ||
get_thread_messages, | ||
get_threads, | ||
) | ||
from .models import Message, Thread | ||
from .schemas import ( | ||
from django_ai_assistant import package_name, version | ||
from django_ai_assistant.api.schemas import ( | ||
AssistantSchema, | ||
ThreadMessagesSchemaIn, | ||
ThreadMessagesSchemaOut, | ||
ThreadSchema, | ||
ThreadSchemaIn, | ||
) | ||
from django_ai_assistant.exceptions import AIUserNotAllowedError | ||
from django_ai_assistant.helpers import use_cases | ||
from django_ai_assistant.models import Message, Thread | ||
|
||
|
||
class API(NinjaAPI): | ||
# Force "operationId" to be like "django_ai_assistant_delete_thread" | ||
def get_openapi_operation_id(self, operation: Operation) -> str: | ||
name = operation.view_func.__name__ | ||
return (package_name + "_" + name).replace(".", "_") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes us have better names at frontend/src/client/services.gen.ts Like djangoAiAssistantListAssistants instead of djangoAiAssistantViewsListAssistants |
||
|
||
|
||
api = NinjaAPI(title=__package__, version=__version__, urls_namespace="django_ai_assistant") | ||
api = API(title=package_name, version=version, urls_namespace="django_ai_assistant") | ||
|
||
|
||
@api.exception_handler(AIUserNotAllowedError) | ||
|
@@ -41,42 +40,44 @@ def ai_user_not_allowed_handler(request, exc): | |
|
||
@api.get("assistants/", response=List[AssistantSchema], url_name="assistants_list") | ||
def list_assistants(request): | ||
return list(get_assistants_info(user=request.user, request=request)) | ||
return list(use_cases.get_assistants_info(user=request.user, request=request)) | ||
|
||
|
||
@api.get("assistants/{assistant_id}/", response=AssistantSchema, url_name="assistant_detail") | ||
def get_assistant(request, assistant_id: str): | ||
return get_single_assistant_info(assistant_id=assistant_id, user=request.user, request=request) | ||
return use_cases.get_single_assistant_info( | ||
assistant_id=assistant_id, user=request.user, request=request | ||
) | ||
|
||
|
||
@api.get("threads/", response=List[ThreadSchema], url_name="threads_list_create") | ||
def list_threads(request): | ||
return list(get_threads(user=request.user, request=request)) | ||
return list(use_cases.get_threads(user=request.user, request=request)) | ||
|
||
|
||
@api.post("threads/", response=ThreadSchema, url_name="threads_list_create") | ||
def create_thread(request, payload: ThreadSchemaIn): | ||
name = payload.name | ||
return assistants.create_thread(name=name, user=request.user, request=request) | ||
return use_cases.create_thread(name=name, user=request.user, request=request) | ||
|
||
|
||
@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete") | ||
def get_thread(request, thread_id: str): | ||
thread = get_single_thread(thread_id=thread_id, user=request.user, request=request) | ||
thread = use_cases.get_single_thread(thread_id=thread_id, user=request.user, request=request) | ||
return thread | ||
|
||
|
||
@api.patch("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete") | ||
def update_thread(request, thread_id: str, payload: ThreadSchemaIn): | ||
thread = get_object_or_404(Thread, id=thread_id) | ||
name = payload.name | ||
return assistants.update_thread(thread=thread, name=name, user=request.user, request=request) | ||
return use_cases.update_thread(thread=thread, name=name, user=request.user, request=request) | ||
|
||
|
||
@api.delete("threads/{thread_id}/", response={204: None}, url_name="thread_detail_update_delete") | ||
def delete_thread(request, thread_id: str): | ||
thread = get_object_or_404(Thread, id=thread_id) | ||
assistants.delete_thread(thread=thread, user=request.user, request=request) | ||
use_cases.delete_thread(thread=thread, user=request.user, request=request) | ||
return 204, None | ||
|
||
|
||
|
@@ -86,7 +87,9 @@ def delete_thread(request, thread_id: str): | |
url_name="messages_list_create", | ||
) | ||
def list_thread_messages(request, thread_id: str): | ||
messages = get_thread_messages(thread_id=thread_id, user=request.user, request=request) | ||
messages = use_cases.get_thread_messages( | ||
thread_id=thread_id, user=request.user, request=request | ||
) | ||
return [message_to_dict(m)["data"] for m in messages] | ||
|
||
|
||
|
@@ -99,7 +102,7 @@ def list_thread_messages(request, thread_id: str): | |
def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchemaIn): | ||
thread = Thread.objects.get(id=thread_id) | ||
|
||
create_message( | ||
use_cases.create_message( | ||
assistant_id=payload.assistant_id, | ||
thread=thread, | ||
user=request.user, | ||
|
@@ -114,7 +117,7 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema | |
) | ||
def delete_thread_message(request, thread_id: str, message_id: str): | ||
message = get_object_or_404(Message, id=message_id, thread_id=thread_id) | ||
assistants.delete_message( | ||
use_cases.delete_message( | ||
message=message, | ||
user=request.user, | ||
request=request, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,6 @@ | |
import re | ||
from typing import Any, ClassVar, Sequence, cast | ||
|
||
from django.http import HttpRequest | ||
|
||
from langchain.agents import AgentExecutor | ||
from langchain.agents.format_scratchpad.tools import ( | ||
format_to_tool_messages, | ||
|
@@ -15,7 +13,6 @@ | |
DEFAULT_DOCUMENT_SEPARATOR, | ||
) | ||
from langchain_core.chat_history import InMemoryChatMessageHistory | ||
from langchain_core.messages import BaseMessage, HumanMessage | ||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.prompts import ( | ||
ChatPromptTemplate, | ||
|
@@ -37,22 +34,11 @@ | |
from langchain_core.tools import BaseTool | ||
from langchain_openai import ChatOpenAI | ||
|
||
from django_ai_assistant.ai.chat_message_histories import DjangoChatMessageHistory | ||
from django_ai_assistant.exceptions import ( | ||
AIAssistantMisconfiguredError, | ||
AIAssistantNotDefinedError, | ||
AIUserNotAllowedError, | ||
) | ||
from django_ai_assistant.models import Message, Thread | ||
from django_ai_assistant.permissions import ( | ||
can_create_message, | ||
can_create_thread, | ||
can_delete_message, | ||
can_delete_thread, | ||
can_run_assistant, | ||
) | ||
from django_ai_assistant.tools import Tool | ||
from django_ai_assistant.tools import tool as tool_decorator | ||
from django_ai_assistant.langchain.tools import Tool | ||
from django_ai_assistant.langchain.tools import tool as tool_decorator | ||
|
||
|
||
class AIAssistant(abc.ABC): # noqa: F821 | ||
|
@@ -156,6 +142,9 @@ def get_prompt_template(self): | |
) | ||
|
||
def get_message_history(self, thread_id: int | None): | ||
# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere: | ||
from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory | ||
|
||
if thread_id is None: | ||
return InMemoryChatMessageHistory() | ||
return DjangoChatMessageHistory(thread_id) | ||
|
@@ -309,154 +298,3 @@ def as_tool(self, description) -> BaseTool: | |
def register_assistant(cls: type[AIAssistant]): | ||
ASSISTANT_CLS_REGISTRY[cls.id] = cls | ||
return cls | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved to use_cases |
||
def _get_assistant_cls( | ||
assistant_id: str, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
if assistant_id not in ASSISTANT_CLS_REGISTRY: | ||
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found") | ||
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id] | ||
if not can_run_assistant( | ||
assistant_cls=assistant_cls, | ||
user=user, | ||
request=request, | ||
): | ||
raise AIUserNotAllowedError("User is not allowed to use this assistant") | ||
return assistant_cls | ||
|
||
|
||
def get_single_assistant_info( | ||
assistant_id: str, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
assistant_cls = _get_assistant_cls(assistant_id, user, request) | ||
|
||
return { | ||
"id": assistant_id, | ||
"name": assistant_cls.name, | ||
} | ||
|
||
|
||
def get_assistants_info( | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
return [ | ||
_get_assistant_cls(assistant_id=assistant_id, user=user, request=request) | ||
for assistant_id in ASSISTANT_CLS_REGISTRY.keys() | ||
] | ||
|
||
|
||
def create_message( | ||
assistant_id: str, | ||
thread: Thread, | ||
user: Any, | ||
content: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
assistant_cls = _get_assistant_cls(assistant_id, user, request) | ||
|
||
if not can_create_message(thread=thread, user=user, request=request): | ||
raise AIUserNotAllowedError("User is not allowed to create messages in this thread") | ||
|
||
# TODO: Check if we can separate the message creation from the chain invoke | ||
assistant = assistant_cls(user=user, request=request) | ||
assistant_message = assistant.invoke( | ||
{"input": content}, | ||
thread_id=thread.id, | ||
) | ||
return assistant_message | ||
|
||
|
||
def create_thread( | ||
name: str, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
if not can_create_thread(user=user, request=request): | ||
raise AIUserNotAllowedError("User is not allowed to create threads") | ||
|
||
thread = Thread.objects.create(name=name, created_by=user) | ||
return thread | ||
|
||
|
||
def get_single_thread( | ||
thread_id: str, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
return Thread.objects.filter(created_by=user).get(id=thread_id) | ||
|
||
|
||
def get_threads( | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
return list(Thread.objects.filter(created_by=user)) | ||
|
||
|
||
def update_thread( | ||
thread: Thread, | ||
name: str, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
if not can_delete_thread(thread=thread, user=user, request=request): | ||
raise AIUserNotAllowedError("User is not allowed to update this thread") | ||
|
||
thread.name = name | ||
thread.save() | ||
return thread | ||
|
||
|
||
def delete_thread( | ||
thread: Thread, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
if not can_delete_thread(thread=thread, user=user, request=request): | ||
raise AIUserNotAllowedError("User is not allowed to delete this thread") | ||
|
||
return thread.delete() | ||
|
||
|
||
def get_thread_messages( | ||
thread_id: str, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
) -> list[BaseMessage]: | ||
# TODO: have more permissions for threads? View thread permission? | ||
thread = Thread.objects.get(id=thread_id) | ||
if user != thread.created_by: | ||
raise AIUserNotAllowedError("User is not allowed to view messages in this thread") | ||
|
||
return DjangoChatMessageHistory(thread.id).get_messages() | ||
|
||
|
||
def create_thread_message_as_user( | ||
thread_id: str, | ||
content: str, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
# TODO: have more permissions for threads? View thread permission? | ||
thread = Thread.objects.get(id=thread_id) | ||
if user != thread.created_by: | ||
raise AIUserNotAllowedError("User is not allowed to create messages in this thread") | ||
|
||
DjangoChatMessageHistory(thread.id).add_messages([HumanMessage(content=content)]) | ||
|
||
|
||
def delete_message( | ||
message: Message, | ||
user: Any, | ||
request: HttpRequest | None = None, | ||
): | ||
if not can_delete_message(message=message, user=user, request=request): | ||
raise AIUserNotAllowedError("User is not allowed to delete this message") | ||
|
||
return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added all the main public API here (in traditional API sense, not web API) he