Skip to content

Commit

Permalink
fix: fix chat result model
Browse files Browse the repository at this point in the history
  • Loading branch information
Mini256 committed Feb 27, 2025
1 parent 29b7d6f commit 9b3c481
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions backend/app/rag/chat/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging

from typing import Generator, List, Optional
from uuid import UUID

from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy import text, delete
Expand All @@ -22,16 +24,15 @@
ChatEngine,
)
from app.models.recommend_question import RecommendQuestion
from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow
from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow, SourceDocument
from app.rag.chat.stream_protocol import ChatEvent
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrievalResult,
RetrievedEntity,
StoredKnowledgeGraph,
RetrievedSubGraph,
)
from app.rag.knowledge_base.index_store import get_kb_tidb_graph_store
from app.repositories import chat_engine_repo, knowledge_base_repo
from app.repositories import knowledge_base_repo

from app.rag.chat.config import (
ChatEngineConfig,
Expand All @@ -50,11 +51,11 @@


class ChatResult(BaseModel):
chat_id: int
chat_id: UUID
message_id: int
trace: str
sources: List[RetrievedEntity]
content: str
trace: Optional[str] = None
sources: Optional[List[SourceDocument]] = []


def get_final_chat_result(
Expand All @@ -68,13 +69,12 @@ def get_final_chat_result(
if m.event_type == ChatEventType.MESSAGE_ANNOTATIONS_PART:
if m.payload.state == ChatMessageSate.SOURCE_NODES:
sources = m.payload.context
elif m.payload.state == ChatMessageSate.TRACE:
trace = m.payload.context
elif m.event_type == ChatEventType.TEXT_PART:
content += m.payload
elif m.event_type == ChatEventType.DATA_PART:
chat_id = m.payload.chat.id
message_id = m.payload.assistant_message.id
trace = m.payload.assistant_message.trace_url
elif m.event_type == ChatEventType.ERROR_PART:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
Expand Down

0 comments on commit 9b3c481

Please sign in to comment.