Skip to content

Commit

Permalink
refactor: Refactor storage system (eosphoros-ai#937)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored and Hopshine committed Sep 10, 2024
1 parent 089f7b2 commit 650999a
Show file tree
Hide file tree
Showing 55 changed files with 3,788 additions and 688 deletions.
408 changes: 229 additions & 179 deletions assets/schema/knowledge_management.sql

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(self) -> None:
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20))

self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")

Expand Down
32 changes: 9 additions & 23 deletions dbgpt/agent/db/my_plugin_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model


class MyPluginEntity(Base):
class MyPluginEntity(Model):
__tablename__ = "my_plugin"
__table_args__ = {
"mysql_charset": "utf8mb4",
Expand Down Expand Up @@ -39,16 +33,8 @@ class MyPluginEntity(Base):


class MyPluginDao(BaseDao[MyPluginEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)

def add(self, engity: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
my_plugin = MyPluginEntity(
tenant=engity.tenant,
user_code=engity.user_code,
Expand All @@ -68,13 +54,13 @@ def add(self, engity: MyPluginEntity):
return id

def update(self, entity: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
updated = session.merge(entity)
session.commit()
return updated.id

def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
Expand All @@ -83,7 +69,7 @@ def get_by_user(self, user: str) -> list[MyPluginEntity]:
return result

def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
Expand All @@ -93,7 +79,7 @@ def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
return result

def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
Expand Down Expand Up @@ -122,7 +108,7 @@ def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEnti
return result, total_pages, all_count

def count(self, query: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
Expand All @@ -143,7 +129,7 @@ def count(self, query: MyPluginEntity):
return count

def delete(self, plugin_id: int):
session = self.get_session()
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = MyPluginEntity(id=plugin_id)
Expand Down
32 changes: 9 additions & 23 deletions dbgpt/agent/db/plugin_hub_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model

# TODO We should consider that the production environment does not have permission to execute the DDL
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")


class PluginHubEntity(Base):
class PluginHubEntity(Model):
__tablename__ = "plugin_hub"
__table_args__ = {
"mysql_charset": "utf8mb4",
Expand Down Expand Up @@ -43,16 +37,8 @@ class PluginHubEntity(Base):


class PluginHubDao(BaseDao[PluginHubEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)

def add(self, engity: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=engity.name,
Expand All @@ -71,7 +57,7 @@ def add(self, engity: PluginHubEntity):
return id

def update(self, entity: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
Expand All @@ -82,7 +68,7 @@ def update(self, entity: PluginHubEntity):
def list(
self, query: PluginHubEntity, page=1, page_size=20
) -> list[PluginHubEntity]:
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()

Expand Down Expand Up @@ -111,23 +97,23 @@ def list(
return result, total_pages, all_count

def get_by_storage_url(self, storage_url):
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
session.close()
return result

def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
session.close()
return result

def count(self, query: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(func.count(PluginHubEntity.id))
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
Expand All @@ -146,7 +132,7 @@ def count(self, query: PluginHubEntity):
return count

def delete(self, plugin_id: int):
session = self.get_session()
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)
Expand Down
40 changes: 15 additions & 25 deletions dbgpt/agent/hub/agent_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,12 @@ def install_plugin(self, plugin_name: str, user_name: str = None):
else:
my_plugin_entity.user_code = Default_User

with self.hub_dao.get_session() as session:
try:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
session.commit()
session.close()
except Exception as e:
logger.error("install merge roll back!" + str(e))
session.rollback()
with self.hub_dao.session() as session:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
except Exception as e:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
Expand All @@ -87,19 +81,15 @@ def uninstall_plugin(self, plugin_name, user):
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
session.commit()
except:
session.rollback()
with self.hub_dao.session() as session:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)

if plugin_entity is not None:
# delete package file if not use
Expand Down
Loading

0 comments on commit 650999a

Please sign in to comment.