Skip to content

Commit

Permalink
feat:switch BaseDao from meta_data
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt committed Oct 19, 2023
1 parent 9e5a7be commit 9b662c0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 37 deletions.
18 changes: 8 additions & 10 deletions pilot/server/knowledge/chunk_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
from typing import List

from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base

from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao

CFG = Config()

Base = declarative_base()


class DocumentChunkEntity(Base):
__tablename__ = "document_chunk"
Expand All @@ -30,11 +28,11 @@ def __repr__(self):
class DocumentChunkDao(BaseDao):
def __init__(self):
super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)

def create_documents_chunks(self, documents: List):
session = self.Session()
session = self.get_session()
docs = [
DocumentChunkEntity(
doc_name=document.doc_name,
Expand All @@ -52,7 +50,7 @@ def create_documents_chunks(self, documents: List):
session.close()

def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
session = self.Session()
session = self.get_session()
document_chunks = session.query(DocumentChunkEntity)
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
Expand Down Expand Up @@ -82,7 +80,7 @@ def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
return result

def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.Session()
session = self.get_session()
document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
Expand All @@ -107,13 +105,13 @@ def get_document_chunks_count(self, query: DocumentChunkEntity):
return count

# def update_knowledge_document(self, document:KnowledgeDocumentEntity):
# session = self.Session()
# session = self.get_session()
# updated_space = session.merge(document)
# session.commit()
# return updated_space.id

def delete(self, document_id: int):
session = self.Session()
session = self.get_session()
if document_id is None:
raise Exception("document_id is None")
query = DocumentChunkEntity(document_id=document_id)
Expand Down
20 changes: 9 additions & 11 deletions pilot/server/knowledge/document_db.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from datetime import datetime

from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base

from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao

CFG = Config()

Base = declarative_base()


class KnowledgeDocumentEntity(Base):
__tablename__ = "knowledge_document"
Expand All @@ -33,11 +31,11 @@ def __repr__(self):
class KnowledgeDocumentDao(BaseDao):
def __init__(self):
super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)

def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
session = self.get_session()
knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name,
doc_type=document.doc_type,
Expand All @@ -58,7 +56,7 @@ def create_knowledge_document(self, document: KnowledgeDocumentEntity):
return doc_id

def get_knowledge_documents(self, query, page=1, page_size=20):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
Expand Down Expand Up @@ -92,7 +90,7 @@ def get_knowledge_documents(self, query, page=1, page_size=20):
return result

def get_documents(self, query):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
Expand Down Expand Up @@ -123,7 +121,7 @@ def get_documents(self, query):
return result

def get_knowledge_documents_count(self, query):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
Expand All @@ -150,14 +148,14 @@ def get_knowledge_documents_count(self, query):
return count

def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
session = self.get_session()
updated_space = session.merge(document)
session.commit()
return updated_space.id

#
def delete(self, query: KnowledgeDocumentEntity):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
Expand Down
15 changes: 7 additions & 8 deletions pilot/server/knowledge/space_db.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from datetime import datetime

from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base

from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.connections.rdbms.base_dao import BaseDao

CFG = Config()
Base = declarative_base()


class KnowledgeSpaceEntity(Base):
Expand All @@ -29,11 +28,11 @@ def __repr__(self):
class KnowledgeSpaceDao(BaseDao):
def __init__(self):
super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)

def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session()
session = self.get_session()
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
vector_type=CFG.VECTOR_STORE_TYPE,
Expand All @@ -47,7 +46,7 @@ def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session.close()

def get_knowledge_space(self, query: KnowledgeSpaceEntity):
session = self.Session()
session = self.get_session()
knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None:
knowledge_spaces = knowledge_spaces.filter(
Expand Down Expand Up @@ -86,14 +85,14 @@ def get_knowledge_space(self, query: KnowledgeSpaceEntity):
return result

def update_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session()
session = self.get_session()
session.merge(space)
session.commit()
session.close()
return True

def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session()
session = self.get_session()
if space:
session.delete(space)
session.commit()
Expand Down
15 changes: 7 additions & 8 deletions pilot/server/prompt/prompt_manage_db.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from datetime import datetime

from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base

from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao

from pilot.server.prompt.request.request import PromptManageRequest

CFG = Config()
Base = declarative_base()


class PromptManageEntity(Base):
Expand All @@ -31,11 +30,11 @@ def __repr__(self):
class PromptManageDao(BaseDao):
def __init__(self):
super().__init__(
database="prompt_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)

def create_prompt(self, prompt: PromptManageRequest):
session = self.Session()
session = self.get_session()
prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene,
Expand All @@ -51,7 +50,7 @@ def create_prompt(self, prompt: PromptManageRequest):
session.close()

def get_prompts(self, query: PromptManageEntity):
session = self.Session()
session = self.get_session()
prompts = session.query(PromptManageEntity)
if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
Expand All @@ -78,13 +77,13 @@ def get_prompts(self, query: PromptManageEntity):
return result

def update_prompt(self, prompt: PromptManageEntity):
session = self.Session()
session = self.get_session()
session.merge(prompt)
session.commit()
session.close()

def delete_prompt(self, prompt: PromptManageEntity):
session = self.Session()
session = self.get_session()
if prompt:
session.delete(prompt)
session.commit()
Expand Down

0 comments on commit 9b662c0

Please sign in to comment.