Skip to content

Commit

Permalink
refactor: Refactor storage and new serve template (eosphoros-ai#947)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored and penghou.ho committed Jan 18, 2024
1 parent 42d73e5 commit 14f3172
Show file tree
Hide file tree
Showing 63 changed files with 1,889 additions and 236 deletions.
14 changes: 5 additions & 9 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,15 @@ WEB_SERVER_PORT=7860
#*******************************************************************#
#** LLM MODELS **#
#*******************************************************************#
# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
LLM_MODEL=vicuna-7b-v1.5
# LLM_MODEL=yi-6b-chat
# LLM_MODEL=vicuna-13b-v1.5
# LLM_MODEL=qwen-7b-chat-int4
# LLM_MODEL=llama-cpp
# LLM_MODEL, see dbgpt/configs/model_config.LLM_MODEL_CONFIG
LLM_MODEL=vicuna-13b-v1.5
## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL.
## Of course you can specify your model path according to LLM_MODEL_PATH
## In DB-GPT, the priority from high to low to read model path:
## 1. environment variable with key: {LLM_MODEL}_MODEL_PATH (Avoid multi-model conflicts)
## 2. environment variable with key: MODEL_PATH
## 3. environment variable with key: LLM_MODEL_PATH
## 4. the config in /pilot/configs/model_config.LLM_MODEL_CONFIG
## 4. the config in dbgpt/configs/model_config.LLM_MODEL_CONFIG
# LLM_MODEL_PATH=/app/models/vicuna-13b-v1.5
# LLM_PROMPT_TEMPLATE=vicuna_v1.1
MODEL_SERVER=http://127.0.0.1:8000
Expand All @@ -51,7 +47,7 @@ QUANTIZE_8bit=False # True
# PROXYLLM_BACKEND=

### You can configure parameters for a specific model with {model name}_{config key}=xxx
### See /pilot/model/parameter.py
### See dbgpt/model/parameter.py
## prompt template for current model
# llama_cpp_prompt_template=vicuna_v1.1
## llama-2-70b must be 8
Expand Down Expand Up @@ -92,7 +88,7 @@ KNOWLEDGE_SEARCH_REWRITE=False
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191

## Openai embedding model, See /pilot/model/parameter.py
## Openai embedding model, See dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_openai
# proxy_openai_proxy_server_url=https://api.openai.com/v1
# proxy_openai_proxy_api_key={your-openai-sk}
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ __pycache__/

message/

.env
.env*
.vscode
.idea
.chroma
Expand Down
10 changes: 3 additions & 7 deletions dbgpt/agent/db/my_plugin_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

class MyPluginEntity(Model):
__tablename__ = "my_plugin"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String(255), nullable=True, comment="user's tenant")
user_code = Column(String(255), nullable=False, comment="user code")
Expand All @@ -32,7 +28,7 @@ class MyPluginEntity(Model):
UniqueConstraint("user_code", "name", name="uk_name")


class MyPluginDao(BaseDao[MyPluginEntity]):
class MyPluginDao(BaseDao):
def add(self, engity: MyPluginEntity):
session = self.get_raw_session()
my_plugin = MyPluginEntity(
Expand All @@ -53,7 +49,7 @@ def add(self, engity: MyPluginEntity):
session.close()
return id

def update(self, entity: MyPluginEntity):
def raw_update(self, entity: MyPluginEntity):
session = self.get_raw_session()
updated = session.merge(entity)
session.commit()
Expand Down Expand Up @@ -128,7 +124,7 @@ def count(self, query: MyPluginEntity):
session.close()
return count

def delete(self, plugin_id: int):
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
Expand Down
10 changes: 3 additions & 7 deletions dbgpt/agent/db/plugin_hub_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@

class PluginHubEntity(Model):
__tablename__ = "plugin_hub"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
Expand All @@ -36,7 +32,7 @@ class PluginHubEntity(Model):
Index("idx_q_type", "type")


class PluginHubDao(BaseDao[PluginHubEntity]):
class PluginHubDao(BaseDao):
def add(self, engity: PluginHubEntity):
session = self.get_raw_session()
timezone = pytz.timezone("Asia/Shanghai")
Expand All @@ -56,7 +52,7 @@ def add(self, engity: PluginHubEntity):
session.close()
return id

def update(self, entity: PluginHubEntity):
def raw_update(self, entity: PluginHubEntity):
session = self.get_raw_session()
try:
updated = session.merge(entity)
Expand Down Expand Up @@ -131,7 +127,7 @@ def count(self, query: PluginHubEntity):
session.close()
return count

def delete(self, plugin_id: int):
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/agent/hub/agent_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def refresh_hub_from_git(
plugin_hub_info.name = git_plugin._name
plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.update(plugin_hub_info)
self.hub_dao.raw_update(plugin_hub_info)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")

Expand Down Expand Up @@ -194,7 +194,7 @@ async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User)
my_plugin_entiy.user_name = user
my_plugin_entiy.tenant = ""
my_plugin_entiy.file_name = doc_file.filename
self.my_plugin_dao.update(my_plugin_entiy)
self.my_plugin_dao.raw_update(my_plugin_entiy)

def reload_my_plugins(self):
logger.info(f"load_plugins start!")
Expand Down
23 changes: 18 additions & 5 deletions dbgpt/app/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,24 @@ def migrate(alembic_ini_path: str, script_location: str, message: str):

@migration.command()
@add_migration_options
def upgrade(alembic_ini_path: str, script_location: str):
@click.option(
"--sql-output",
type=str,
default=None,
help="Generate SQL script for migration instead of applying it. ex: --sql-output=upgrade.sql",
)
def upgrade(alembic_ini_path: str, script_location: str, sql_output: str):
"""Upgrade database to target version"""
from dbgpt.util._db_migration_utils import upgrade_database
from dbgpt.util._db_migration_utils import (
upgrade_database,
generate_sql_for_upgrade,
)

alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
upgrade_database(alembic_cfg, db_manager.engine)
if sql_output:
generate_sql_for_upgrade(alembic_cfg, db_manager.engine, output_file=sql_output)
else:
upgrade_database(alembic_cfg, db_manager.engine)


@migration.command()
Expand Down Expand Up @@ -199,6 +211,7 @@ def clean(
def list(alembic_ini_path: str, script_location: str):
"""List all versions in the migration history, marking the current one"""
from alembic.script import ScriptDirectory

from alembic.runtime.migration import MigrationContext

alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
Expand Down Expand Up @@ -259,8 +272,8 @@ def _get_migration_config(
from dbgpt.storage.metadata.db_manager import db as db_manager
from dbgpt.util._db_migration_utils import create_alembic_config

# Must import dbgpt_server for initialize db metadata
from dbgpt.app.dbgpt_server import initialize_app as _
# Import all models to make sure they are registered with SQLAlchemy.
from dbgpt.app.initialization.db_model_initialization import _MODELS
from dbgpt.app.base import _initialize_db

# initialize db
Expand Down
42 changes: 34 additions & 8 deletions dbgpt/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dbgpt.component import SystemApp
from dbgpt.util.parameter_utils import BaseParameters

from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
Expand Down Expand Up @@ -92,10 +91,27 @@ def _initialize_db_storage(param: "WebServerParameters"):
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
default_meta_data_path = _initialize_db(
try_to_create_db=not param.disable_alembic_upgrade
)
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)


def _migration_db_storage(param: "WebServerParameters"):
"""Migration the db storage."""
# Import all models to make sure they are registered with SQLAlchemy.
from dbgpt.app.initialization.db_model_initialization import _MODELS

from dbgpt.configs.model_config import PILOT_PATH

default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
if not param.disable_alembic_upgrade:
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
from dbgpt.storage.metadata.db_manager import db

# try to create all tables
try:
db.create_all()
except Exception as e:
logger.warning(f"Create all tables stored in this metadata error: {str(e)}")
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)


def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
Expand All @@ -112,7 +128,13 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
os.makedirs(default_meta_data_path, exist_ok=True)
if CFG.LOCAL_DB_TYPE == "mysql":
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
db_url = (
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:"
f"{urlquote(CFG.LOCAL_DB_PASSWORD)}@"
f"{CFG.LOCAL_DB_HOST}:"
f"{str(CFG.LOCAL_DB_PORT)}/"
f"{db_name}?charset=utf8mb4&collation=utf8mb4_unicode_ci"
)
# Try to create database, if failed, will raise exception
_create_mysql_database(db_name, db_url, try_to_create_db)
else:
Expand All @@ -125,7 +147,7 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
"pool_recycle": 3600,
"pool_pre_ping": True,
}
initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
initialize_db(db_url, db_name, engine_args)
return default_meta_data_path


Expand Down Expand Up @@ -161,7 +183,11 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F
no_db_name_url = db_url.rsplit("/", 1)[0]
engine_no_db = create_engine(no_db_name_url)
with engine_no_db.connect() as conn:
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
conn.execute(
DDL(
f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
)
)
logger.info(f"Database {db_name} successfully created")
except SQLAlchemyError as e:
logger.error(f"Failed to create database {db_name}: {e}")
Expand Down
102 changes: 6 additions & 96 deletions dbgpt/app/component_configs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Type

from dbgpt.component import ComponentType, SystemApp
from dbgpt.component import SystemApp
from dbgpt._private.config import Config
from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR
from dbgpt.util.executor_utils import DefaultExecutorFactory
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.app.base import WebServerParameters

if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings

logger = logging.getLogger(__name__)

Expand All @@ -24,7 +20,10 @@ def initialize_components(
embedding_model_name: str,
embedding_model_path: str,
):
# Lazy import to avoid high time cost
from dbgpt.model.cluster.controller.controller import controller
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.serve_initialization import register_serve_apps

# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
Expand All @@ -45,97 +44,8 @@ def initialize_components(
_initialize_model_cache(system_app)
# NOTE: cannot disable experimental features
_initialize_awel(system_app)


def _initialize_embedding_model(
param: WebServerParameters,
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
):
if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory")
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
else:
logger.info(f"Register local LocalEmbeddingFactory")
system_app.register(
LocalEmbeddingFactory,
default_model_name=embedding_model_name,
default_model_path=embedding_model_path,
)


class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
self._default_model_name = model_name
self.kwargs = kwargs
self.system_app = system_app

def init_app(self, system_app):
self.system_app = system_app

def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings

if embedding_cls:
raise NotImplementedError
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, worker_manager)


class LocalEmbeddingFactory(EmbeddingFactory):
def __init__(
self,
system_app,
default_model_name: str = None,
default_model_path: str = None,
**kwargs: Any,
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()

def init_app(self, system_app):
pass

def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
if embedding_cls:
raise NotImplementedError
return self._model

def _load_model(self) -> "Embeddings":
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
from dbgpt.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)

param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters
)
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
model_name=self._default_model_name,
model_path=self._default_model_path,
param_cls=param_cls,
**self._kwargs,
)
logger.info(model_params)
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
# Register serve apps
register_serve_apps(system_app)


def _initialize_model_cache(system_app: SystemApp):
Expand Down
Loading

0 comments on commit 14f3172

Please sign in to comment.