diff --git a/pilot/server/knowledge/chunk_db.py b/pilot/server/knowledge/chunk_db.py index e5147a73e..205b77044 100644 --- a/pilot/server/knowledge/chunk_db.py +++ b/pilot/server/knowledge/chunk_db.py @@ -2,15 +2,13 @@ from typing import List from sqlalchemy import Column, String, DateTime, Integer, Text, func -from sqlalchemy.orm import declarative_base +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session from pilot.configs.config import Config -from pilot.connections.rdbms.base_dao import BaseDao CFG = Config() -Base = declarative_base() - class DocumentChunkEntity(Base): __tablename__ = "document_chunk" @@ -30,11 +28,11 @@ def __repr__(self): class DocumentChunkDao(BaseDao): def __init__(self): super().__init__( - database="knowledge_management", orm_base=Base, create_not_exist_table=True + database="dbgpt", orm_base=Base, db_engine=engine, session=session ) def create_documents_chunks(self, documents: List): - session = self.Session() + session = self.get_session() docs = [ DocumentChunkEntity( doc_name=document.doc_name, @@ -52,7 +50,7 @@ def create_documents_chunks(self, documents: List): session.close() def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20): - session = self.Session() + session = self.get_session() document_chunks = session.query(DocumentChunkEntity) if query.id is not None: document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) @@ -82,7 +80,7 @@ def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20): return result def get_document_chunks_count(self, query: DocumentChunkEntity): - session = self.Session() + session = self.get_session() document_chunks = session.query(func.count(DocumentChunkEntity.id)) if query.id is not None: document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) @@ -107,13 +105,13 @@ def get_document_chunks_count(self, query: DocumentChunkEntity): return count # def update_knowledge_document(self, document:KnowledgeDocumentEntity): - # session = self.Session() + # session = self.get_session() # updated_space = session.merge(document) # session.commit() # return updated_space.id def delete(self, document_id: int): - session = self.Session() + session = self.get_session() if document_id is None: raise Exception("document_id is None") query = DocumentChunkEntity(document_id=document_id) diff --git a/pilot/server/knowledge/document_db.py b/pilot/server/knowledge/document_db.py index 42b7ac8f0..5f7b47add 100644 --- a/pilot/server/knowledge/document_db.py +++ b/pilot/server/knowledge/document_db.py @@ -1,15 +1,13 @@ from datetime import datetime from sqlalchemy import Column, String, DateTime, Integer, Text, func -from sqlalchemy.orm import declarative_base +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session from pilot.configs.config import Config -from pilot.connections.rdbms.base_dao import BaseDao CFG = Config() -Base = declarative_base() - class KnowledgeDocumentEntity(Base): __tablename__ = "knowledge_document" @@ -33,11 +31,11 @@ def __repr__(self): class KnowledgeDocumentDao(BaseDao): def __init__(self): super().__init__( - database="knowledge_management", orm_base=Base, create_not_exist_table=True + database="dbgpt", orm_base=Base, db_engine=engine, session=session ) def create_knowledge_document(self, document: KnowledgeDocumentEntity): - session = self.Session() + session = self.get_session() knowledge_document = KnowledgeDocumentEntity( doc_name=document.doc_name, doc_type=document.doc_type, @@ -58,7 +56,7 @@ def create_knowledge_document(self, document: KnowledgeDocumentEntity): return doc_id def get_knowledge_documents(self, query, page=1, page_size=20): - session = self.Session() + session = self.get_session() knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: knowledge_documents = knowledge_documents.filter( @@ -92,7 +90,7 @@ def get_knowledge_documents(self, query, page=1, page_size=20): return result def get_documents(self, query): - session = self.Session() + session = self.get_session() knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: knowledge_documents = knowledge_documents.filter( @@ -123,7 +121,7 @@ def get_documents(self, query): return result def get_knowledge_documents_count(self, query): - session = self.Session() + session = self.get_session() knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id)) if query.id is not None: knowledge_documents = knowledge_documents.filter( @@ -150,14 +148,14 @@ def get_knowledge_documents_count(self, query): return count def update_knowledge_document(self, document: KnowledgeDocumentEntity): - session = self.Session() + session = self.get_session() updated_space = session.merge(document) session.commit() return updated_space.id # def delete(self, query: KnowledgeDocumentEntity): - session = self.Session() + session = self.get_session() knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: knowledge_documents = knowledge_documents.filter( diff --git a/pilot/server/knowledge/space_db.py b/pilot/server/knowledge/space_db.py index a47a5ffae..491fe303b 100644 --- a/pilot/server/knowledge/space_db.py +++ b/pilot/server/knowledge/space_db.py @@ -1,14 +1,13 @@ from datetime import datetime from sqlalchemy import Column, Integer, Text, String, DateTime -from sqlalchemy.ext.declarative import declarative_base +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session from pilot.configs.config import Config from pilot.server.knowledge.request.request import KnowledgeSpaceRequest -from pilot.connections.rdbms.base_dao import BaseDao CFG = Config() -Base = declarative_base() class KnowledgeSpaceEntity(Base): @@ -29,11 +28,11 @@ def __repr__(self): class KnowledgeSpaceDao(BaseDao): def __init__(self): super().__init__( - database="knowledge_management", orm_base=Base, create_not_exist_table=True + database="dbgpt", orm_base=Base, db_engine=engine, session=session ) def create_knowledge_space(self, space: KnowledgeSpaceRequest): - session = self.Session() + session = self.get_session() knowledge_space = KnowledgeSpaceEntity( name=space.name, vector_type=CFG.VECTOR_STORE_TYPE, @@ -47,7 +46,7 @@ def create_knowledge_space(self, space: KnowledgeSpaceRequest): session.close() def get_knowledge_space(self, query: KnowledgeSpaceEntity): - session = self.Session() + session = self.get_session() knowledge_spaces = session.query(KnowledgeSpaceEntity) if query.id is not None: knowledge_spaces = knowledge_spaces.filter( @@ -86,14 +85,14 @@ def get_knowledge_space(self, query: KnowledgeSpaceEntity): return result def update_knowledge_space(self, space: KnowledgeSpaceEntity): - session = self.Session() + session = self.get_session() session.merge(space) session.commit() session.close() return True def delete_knowledge_space(self, space: KnowledgeSpaceEntity): - session = self.Session() + session = self.get_session() if space: session.delete(space) session.commit() diff --git a/pilot/server/prompt/prompt_manage_db.py b/pilot/server/prompt/prompt_manage_db.py index 6a02e6b5c..56bbac20d 100644 --- a/pilot/server/prompt/prompt_manage_db.py +++ b/pilot/server/prompt/prompt_manage_db.py @@ -1,15 +1,14 @@ from datetime import datetime from sqlalchemy import Column, Integer, Text, String, DateTime -from sqlalchemy.ext.declarative import declarative_base +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session from pilot.configs.config import Config -from pilot.connections.rdbms.base_dao import BaseDao from pilot.server.prompt.request.request import PromptManageRequest CFG = Config() -Base = declarative_base() class PromptManageEntity(Base): @@ -31,11 +30,11 @@ def __repr__(self): class PromptManageDao(BaseDao): def __init__(self): super().__init__( - database="prompt_management", orm_base=Base, create_not_exist_table=True + database="dbgpt", orm_base=Base, db_engine=engine, session=session ) def create_prompt(self, prompt: PromptManageRequest): - session = self.Session() + session = self.get_session() prompt_manage = PromptManageEntity( chat_scene=prompt.chat_scene, sub_chat_scene=prompt.sub_chat_scene, @@ -51,7 +50,7 @@ def create_prompt(self, prompt: PromptManageRequest): session.close() def get_prompts(self, query: PromptManageEntity): - session = self.Session() + session = self.get_session() prompts = session.query(PromptManageEntity) if query.chat_scene is not None: prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) @@ -78,13 +77,13 @@ def get_prompts(self, query: PromptManageEntity): return result def update_prompt(self, prompt: PromptManageEntity): - session = self.Session() + session = self.get_session() session.merge(prompt) session.commit() session.close() def delete_prompt(self, prompt: PromptManageEntity): - session = self.Session() + session = self.get_session() if prompt: session.delete(prompt) session.commit()