Skip to content

Commit

Permalink
Transforms cast_id into decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
amandasavluchinske committed Jun 21, 2024
1 parent 59db142 commit 0e55aa1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
17 changes: 9 additions & 8 deletions django_ai_assistant/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
ThreadSchemaIn,
)
from django_ai_assistant.conf import app_settings
from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import AIAssistantNotDefinedError, AIUserNotAllowedError
from django_ai_assistant.helpers import formatters, use_cases
from django_ai_assistant.helpers import use_cases
from django_ai_assistant.models import Message, Thread


Expand Down Expand Up @@ -85,8 +86,8 @@ def create_thread(request, payload: ThreadSchemaIn):


@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
@with_cast_id
def get_thread(request, thread_id: Any):
thread_id = formatters.format_id(id, Thread)
try:
thread = use_cases.get_single_thread(
thread_id=thread_id, user=request.user, request=request
Expand All @@ -97,16 +98,16 @@ def get_thread(request, thread_id: Any):


@api.patch("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
@with_cast_id
def update_thread(request, thread_id: Any, payload: ThreadSchemaIn):
thread_id = formatters.format_id(thread_id, Thread)
thread = get_object_or_404(Thread, id=thread_id)
name = payload.name
return use_cases.update_thread(thread=thread, name=name, user=request.user, request=request)


@api.delete("threads/{thread_id}/", response={204: None}, url_name="thread_detail_update_delete")
@with_cast_id
def delete_thread(request, thread_id: Any):
thread_id = formatters.format_id(thread_id, Thread)
thread = get_object_or_404(Thread, id=thread_id)
use_cases.delete_thread(thread=thread, user=request.user, request=request)
return 204, None
Expand All @@ -117,8 +118,9 @@ def delete_thread(request, thread_id: Any):
response=List[ThreadMessagesSchemaOut],
url_name="messages_list_create",
)
@with_cast_id
def list_thread_messages(request, thread_id: Any):
thread = get_object_or_404(Thread, id=formatters.format_id(thread_id, Thread))
thread = get_object_or_404(Thread, id=thread_id)
messages = use_cases.get_thread_messages(thread=thread, user=request.user, request=request)
return [message_to_dict(m)["data"] for m in messages]

Expand All @@ -129,8 +131,8 @@ def list_thread_messages(request, thread_id: Any):
response={201: None},
url_name="messages_list_create",
)
@with_cast_id
def create_thread_message(request, thread_id: Any, payload: ThreadMessagesSchemaIn):
thread_id = formatters.format_id(thread_id, Thread)
thread = Thread.objects.get(id=thread_id)

use_cases.create_message(
Expand All @@ -146,9 +148,8 @@ def create_thread_message(request, thread_id: Any, payload: ThreadMessagesSchema
@api.delete(
"threads/{thread_id}/messages/{message_id}/", response={204: None}, url_name="messages_delete"
)
@with_cast_id
def delete_thread_message(request, thread_id: Any, message_id: Any):
thread_id = formatters.format_id(thread_id, Message)
message_id = formatters.format_id(message_id, Message)
message = get_object_or_404(Message, id=message_id, thread_id=thread_id)
use_cases.delete_message(
message=message,
Expand Down
23 changes: 23 additions & 0 deletions django_ai_assistant/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from functools import wraps

from django_ai_assistant.helpers.formatters import cast_id
from django_ai_assistant.models import Message, Thread


def with_cast_id(func):
@wraps(func)
def wrapper(*args, **kwargs):
thread_id = kwargs.get("thread_id")
message_id = kwargs.get("message_id")

if thread_id:
thread_id = cast_id(thread_id, Thread)
kwargs["thread_id"] = thread_id

if message_id:
message_id = cast_id(message_id, Message)
kwargs["message_id"] = message_id

return func(*args, **kwargs)

return wrapper
2 changes: 1 addition & 1 deletion django_ai_assistant/helpers/formatters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid


def format_id(item_id, model):
def cast_id(item_id, model):
if isinstance(item_id, str) and "UUID" in model._meta.pk.get_internal_type():
return uuid.UUID(item_id)
return item_id

0 comments on commit 0e55aa1

Please sign in to comment.