From 9b3c4812719968d2e1e829a59058948f383c2ae3 Mon Sep 17 00:00:00 2001 From: Mini256 Date: Thu, 27 Feb 2025 12:23:06 +0800 Subject: [PATCH] fix: fix chat result model --- backend/app/rag/chat/chat_service.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/backend/app/rag/chat/chat_service.py b/backend/app/rag/chat/chat_service.py index e2a1b045..d8d2102a 100644 --- a/backend/app/rag/chat/chat_service.py +++ b/backend/app/rag/chat/chat_service.py @@ -2,6 +2,8 @@ import logging from typing import Generator, List, Optional +from uuid import UUID + from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy import text, delete @@ -22,16 +24,15 @@ ChatEngine, ) from app.models.recommend_question import RecommendQuestion -from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow +from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow, SourceDocument from app.rag.chat.stream_protocol import ChatEvent from app.rag.retrievers.knowledge_graph.schema import ( KnowledgeGraphRetrievalResult, - RetrievedEntity, StoredKnowledgeGraph, RetrievedSubGraph, ) from app.rag.knowledge_base.index_store import get_kb_tidb_graph_store -from app.repositories import chat_engine_repo, knowledge_base_repo +from app.repositories import knowledge_base_repo from app.rag.chat.config import ( ChatEngineConfig, @@ -50,11 +51,11 @@ class ChatResult(BaseModel): - chat_id: int + chat_id: UUID message_id: int - trace: str - sources: List[RetrievedEntity] content: str + trace: Optional[str] = None + sources: Optional[List[SourceDocument]] = [] def get_final_chat_result( @@ -68,13 +69,12 @@ def get_final_chat_result( if m.event_type == ChatEventType.MESSAGE_ANNOTATIONS_PART: if m.payload.state == ChatMessageSate.SOURCE_NODES: sources = m.payload.context - elif m.payload.state == ChatMessageSate.TRACE: - trace = m.payload.context elif m.event_type == ChatEventType.TEXT_PART: content += m.payload elif m.event_type == ChatEventType.DATA_PART: chat_id = m.payload.chat.id message_id = m.payload.assistant_message.id + trace = m.payload.assistant_message.trace_url elif m.event_type == ChatEventType.ERROR_PART: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR,