From 14f3172f28dec705c0b29a3133be9385d00867f5 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 18 Dec 2023 19:30:40 +0800 Subject: [PATCH] refactor: Refactor storage and new serve template (#947) --- .env.template | 14 +- .gitignore | 2 +- dbgpt/agent/db/my_plugin_db.py | 10 +- dbgpt/agent/db/plugin_hub_db.py | 10 +- dbgpt/agent/hub/agent_hub.py | 4 +- dbgpt/app/_cli.py | 23 +- dbgpt/app/base.py | 42 +++- dbgpt/app/component_configs.py | 102 +-------- dbgpt/app/dbgpt_server.py | 58 +++-- dbgpt/app/initialization/__init__.py | 0 .../initialization/db_model_initialization.py | 29 +++ .../app/initialization/embedding_component.py | 103 +++++++++ .../initialization/serve_initialization.py | 9 + dbgpt/app/knowledge/chunk_db.py | 6 +- dbgpt/app/knowledge/document_db.py | 6 +- dbgpt/app/knowledge/service.py | 8 +- dbgpt/app/knowledge/space_db.py | 4 - .../openapi/api_v1/feedback/feed_back_db.py | 4 - dbgpt/app/prompt/prompt_manage_db.py | 4 - dbgpt/cli/cli_scripts.py | 14 ++ dbgpt/component.py | 13 +- dbgpt/core/interface/message.py | 10 +- dbgpt/datasource/manages/connect_config_db.py | 3 +- dbgpt/model/cluster/worker/default_worker.py | 20 +- dbgpt/model/cluster/worker/remote_manager.py | 4 +- dbgpt/serve/core/__init__.py | 5 + dbgpt/serve/core/config.py | 19 ++ dbgpt/serve/core/schemas.py | 38 ++++ dbgpt/serve/core/service.py | 33 +++ dbgpt/serve/prompt/__init__.py | 2 + dbgpt/serve/prompt/api/__init__.py | 2 + dbgpt/serve/prompt/api/endpoints.py | 114 ++++++++++ dbgpt/serve/prompt/api/schemas.py | 73 +++++++ dbgpt/serve/prompt/config.py | 19 ++ dbgpt/serve/prompt/dependencies.py | 1 + dbgpt/serve/prompt/models/__init__.py | 2 + dbgpt/serve/prompt/models/models.py | 95 +++++++++ dbgpt/serve/prompt/serve.py | 36 ++++ dbgpt/serve/prompt/service/__init__.py | 0 dbgpt/serve/prompt/service/service.py | 117 +++++++++++ dbgpt/serve/utils/__init__.py | 0 .../default_serve_template/__init__.py | 2 + .../default_serve_template/api/__init__.py | 2 + .../default_serve_template/api/endpoints.py | 98 +++++++++ .../default_serve_template/api/schemas.py | 14 ++ .../default_serve_template/config.py | 19 ++ .../default_serve_template/dependencies.py | 1 + .../default_serve_template/models/__init__.py | 2 + .../default_serve_template/models/models.py | 68 ++++++ .../default_serve_template/serve.py | 38 ++++ .../service/__init__.py | 0 .../default_serve_template/service/service.py | 116 ++++++++++ dbgpt/serve/utils/cli.py | 83 ++++++++ dbgpt/storage/chat_history/chat_history_db.py | 22 +- .../store_type/meta_db_history.py | 6 +- dbgpt/storage/metadata/_base_dao.py | 198 +++++++++++++++++- dbgpt/storage/metadata/db_storage.py | 6 +- dbgpt/storage/metadata/tests/test_base_dao.py | 152 ++++++++++++++ dbgpt/util/__init__.py | 14 +- dbgpt/util/_db_migration_utils.py | 177 +++++++++++++++- dbgpt/util/config_utils.py | 32 +++ dbgpt/util/utils.py | 12 -- pilot/meta_data/alembic/env.py | 5 +- 63 files changed, 1889 insertions(+), 236 deletions(-) create mode 100644 dbgpt/app/initialization/__init__.py create mode 100644 dbgpt/app/initialization/db_model_initialization.py create mode 100644 dbgpt/app/initialization/embedding_component.py create mode 100644 dbgpt/app/initialization/serve_initialization.py create mode 100644 dbgpt/serve/core/__init__.py create mode 100644 dbgpt/serve/core/config.py create mode 100644 dbgpt/serve/core/schemas.py create mode 100644 dbgpt/serve/core/service.py create mode 100644 dbgpt/serve/prompt/__init__.py create mode 100644 dbgpt/serve/prompt/api/__init__.py create mode 100644 
dbgpt/serve/prompt/api/endpoints.py create mode 100644 dbgpt/serve/prompt/api/schemas.py create mode 100644 dbgpt/serve/prompt/config.py create mode 100644 dbgpt/serve/prompt/dependencies.py create mode 100644 dbgpt/serve/prompt/models/__init__.py create mode 100644 dbgpt/serve/prompt/models/models.py create mode 100644 dbgpt/serve/prompt/serve.py create mode 100644 dbgpt/serve/prompt/service/__init__.py create mode 100644 dbgpt/serve/prompt/service/service.py create mode 100644 dbgpt/serve/utils/__init__.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/__init__.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/api/__init__.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/config.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/dependencies.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/models/__init__.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/models/models.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/serve.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/service/__init__.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/service/service.py create mode 100644 dbgpt/serve/utils/cli.py create mode 100644 dbgpt/storage/metadata/tests/test_base_dao.py create mode 100644 dbgpt/util/config_utils.py diff --git a/.env.template b/.env.template index c9cd2a97c..b49102dd5 100644 --- a/.env.template +++ b/.env.template @@ -21,19 +21,15 @@ WEB_SERVER_PORT=7860 #*******************************************************************# #** LLM MODELS **# #*******************************************************************# -# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG -LLM_MODEL=vicuna-7b-v1.5 -# LLM_MODEL=yi-6b-chat -# LLM_MODEL=vicuna-13b-v1.5 -# LLM_MODEL=qwen-7b-chat-int4 -# LLM_MODEL=llama-cpp +# LLM_MODEL, see dbgpt/configs/model_config.LLM_MODEL_CONFIG +LLM_MODEL=vicuna-13b-v1.5 ## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL. ## Of course you can specify your model path according to LLM_MODEL_PATH ## In DB-GPT, the priority from high to low to read model path: ## 1. environment variable with key: {LLM_MODEL}_MODEL_PATH (Avoid multi-model conflicts) ## 2. environment variable with key: MODEL_PATH ## 3. environment variable with key: LLM_MODEL_PATH -## 4. the config in /pilot/configs/model_config.LLM_MODEL_CONFIG +## 4. 
the config in dbgpt/configs/model_config.LLM_MODEL_CONFIG # LLM_MODEL_PATH=/app/models/vicuna-13b-v1.5 # LLM_PROMPT_TEMPLATE=vicuna_v1.1 MODEL_SERVER=http://127.0.0.1:8000 @@ -51,7 +47,7 @@ QUANTIZE_8bit=False # True # PROXYLLM_BACKEND= ### You can configure parameters for a specific model with {model name}_{config key}=xxx -### See /pilot/model/parameter.py +### See dbgpt/model/parameter.py ## prompt template for current model # llama_cpp_prompt_template=vicuna_v1.1 ## llama-2-70b must be 8 @@ -92,7 +88,7 @@ KNOWLEDGE_SEARCH_REWRITE=False # EMBEDDING_TOKENIZER=all-MiniLM-L6-v2 # EMBEDDING_TOKEN_LIMIT=8191 -## Openai embedding model, See /pilot/model/parameter.py +## Openai embedding model, See dbgpt/model/parameter.py # EMBEDDING_MODEL=proxy_openai # proxy_openai_proxy_server_url=https://api.openai.com/v1 # proxy_openai_proxy_api_key={your-openai-sk} diff --git a/.gitignore b/.gitignore index 275f3758a..50c82741b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ __pycache__/ message/ -.env +.env* .vscode .idea .chroma diff --git a/dbgpt/agent/db/my_plugin_db.py b/dbgpt/agent/db/my_plugin_db.py index ed5ab176e..0927f4283 100644 --- a/dbgpt/agent/db/my_plugin_db.py +++ b/dbgpt/agent/db/my_plugin_db.py @@ -7,10 +7,6 @@ class MyPluginEntity(Model): __tablename__ = "my_plugin" - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } id = Column(Integer, primary_key=True, comment="autoincrement id") tenant = Column(String(255), nullable=True, comment="user's tenant") user_code = Column(String(255), nullable=False, comment="user code") @@ -32,7 +28,7 @@ class MyPluginEntity(Model): UniqueConstraint("user_code", "name", name="uk_name") -class MyPluginDao(BaseDao[MyPluginEntity]): +class MyPluginDao(BaseDao): def add(self, engity: MyPluginEntity): session = self.get_raw_session() my_plugin = MyPluginEntity( @@ -53,7 +49,7 @@ def add(self, engity: MyPluginEntity): session.close() return id - def update(self, entity: MyPluginEntity): + def raw_update(self, entity: MyPluginEntity): session = self.get_raw_session() updated = session.merge(entity) session.commit() @@ -128,7 +124,7 @@ def count(self, query: MyPluginEntity): session.close() return count - def delete(self, plugin_id: int): + def raw_delete(self, plugin_id: int): session = self.get_raw_session() if plugin_id is None: raise Exception("plugin_id is None") diff --git a/dbgpt/agent/db/plugin_hub_db.py b/dbgpt/agent/db/plugin_hub_db.py index d374d284d..e1bcfffcf 100644 --- a/dbgpt/agent/db/plugin_hub_db.py +++ b/dbgpt/agent/db/plugin_hub_db.py @@ -11,10 +11,6 @@ class PluginHubEntity(Model): __tablename__ = "plugin_hub" - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } id = Column( Integer, primary_key=True, autoincrement=True, comment="autoincrement id" ) @@ -36,7 +32,7 @@ class PluginHubEntity(Model): Index("idx_q_type", "type") -class PluginHubDao(BaseDao[PluginHubEntity]): +class PluginHubDao(BaseDao): def add(self, engity: PluginHubEntity): session = self.get_raw_session() timezone = pytz.timezone("Asia/Shanghai") @@ -56,7 +52,7 @@ def add(self, engity: PluginHubEntity): session.close() return id - def update(self, entity: PluginHubEntity): + def raw_update(self, entity: PluginHubEntity): session = self.get_raw_session() try: updated = session.merge(entity) @@ -131,7 +127,7 @@ def count(self, query: PluginHubEntity): session.close() return count - def delete(self, plugin_id: int): + def raw_delete(self, plugin_id: int): session = 
self.get_raw_session() if plugin_id is None: raise Exception("plugin_id is None") diff --git a/dbgpt/agent/hub/agent_hub.py b/dbgpt/agent/hub/agent_hub.py index f95b49305..0359d81fd 100644 --- a/dbgpt/agent/hub/agent_hub.py +++ b/dbgpt/agent/hub/agent_hub.py @@ -159,7 +159,7 @@ def refresh_hub_from_git( plugin_hub_info.name = git_plugin._name plugin_hub_info.version = git_plugin._version plugin_hub_info.description = git_plugin._description - self.hub_dao.update(plugin_hub_info) + self.hub_dao.raw_update(plugin_hub_info) except Exception as e: raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}") @@ -194,7 +194,7 @@ async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User) my_plugin_entiy.user_name = user my_plugin_entiy.tenant = "" my_plugin_entiy.file_name = doc_file.filename - self.my_plugin_dao.update(my_plugin_entiy) + self.my_plugin_dao.raw_update(my_plugin_entiy) def reload_my_plugins(self): logger.info(f"load_plugins start!") diff --git a/dbgpt/app/_cli.py b/dbgpt/app/_cli.py index 02dc471b7..f48e57111 100644 --- a/dbgpt/app/_cli.py +++ b/dbgpt/app/_cli.py @@ -108,12 +108,24 @@ def migrate(alembic_ini_path: str, script_location: str, message: str): @migration.command() @add_migration_options -def upgrade(alembic_ini_path: str, script_location: str): +@click.option( + "--sql-output", + type=str, + default=None, + help="Generate SQL script for migration instead of applying it. ex: --sql-output=upgrade.sql", +) +def upgrade(alembic_ini_path: str, script_location: str, sql_output: str): """Upgrade database to target version""" - from dbgpt.util._db_migration_utils import upgrade_database + from dbgpt.util._db_migration_utils import ( + upgrade_database, + generate_sql_for_upgrade, + ) alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) - upgrade_database(alembic_cfg, db_manager.engine) + if sql_output: + generate_sql_for_upgrade(alembic_cfg, db_manager.engine, output_file=sql_output) + else: + upgrade_database(alembic_cfg, db_manager.engine) @migration.command() @@ -199,6 +211,7 @@ def clean( def list(alembic_ini_path: str, script_location: str): """List all versions in the migration history, marking the current one""" from alembic.script import ScriptDirectory + from alembic.runtime.migration import MigrationContext alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) @@ -259,8 +272,8 @@ def _get_migration_config( from dbgpt.storage.metadata.db_manager import db as db_manager from dbgpt.util._db_migration_utils import create_alembic_config - # Must import dbgpt_server for initialize db metadata - from dbgpt.app.dbgpt_server import initialize_app as _ + # Import all models to make sure they are registered with SQLAlchemy. + from dbgpt.app.initialization.db_model_initialization import _MODELS from dbgpt.app.base import _initialize_db # initialize db diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py index a750c9bab..ac2618bf0 100644 --- a/dbgpt/app/base.py +++ b/dbgpt/app/base.py @@ -10,7 +10,6 @@ from dbgpt.component import SystemApp from dbgpt.util.parameter_utils import BaseParameters -from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) @@ -92,10 +91,27 @@ def _initialize_db_storage(param: "WebServerParameters"): Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`. 
""" - default_meta_data_path = _initialize_db( - try_to_create_db=not param.disable_alembic_upgrade - ) - _ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade) + _initialize_db(try_to_create_db=not param.disable_alembic_upgrade) + + +def _migration_db_storage(param: "WebServerParameters"): + """Migration the db storage.""" + # Import all models to make sure they are registered with SQLAlchemy. + from dbgpt.app.initialization.db_model_initialization import _MODELS + + from dbgpt.configs.model_config import PILOT_PATH + + default_meta_data_path = os.path.join(PILOT_PATH, "meta_data") + if not param.disable_alembic_upgrade: + from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade + from dbgpt.storage.metadata.db_manager import db + + # try to create all tables + try: + db.create_all() + except Exception as e: + logger.warning(f"Create all tables stored in this metadata error: {str(e)}") + _ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade) def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: @@ -112,7 +128,13 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: default_meta_data_path = os.path.join(PILOT_PATH, "meta_data") os.makedirs(default_meta_data_path, exist_ok=True) if CFG.LOCAL_DB_TYPE == "mysql": - db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}" + db_url = ( + f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:" + f"{urlquote(CFG.LOCAL_DB_PASSWORD)}@" + f"{CFG.LOCAL_DB_HOST}:" + f"{str(CFG.LOCAL_DB_PORT)}/" + f"{db_name}?charset=utf8mb4&collation=utf8mb4_unicode_ci" + ) # Try to create database, if failed, will raise exception _create_mysql_database(db_name, db_url, try_to_create_db) else: @@ -125,7 +147,7 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: "pool_recycle": 3600, "pool_pre_ping": True, } - initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db) + initialize_db(db_url, db_name, engine_args) return default_meta_data_path @@ -161,7 +183,11 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F no_db_name_url = db_url.rsplit("/", 1)[0] engine_no_db = create_engine(no_db_name_url) with engine_no_db.connect() as conn: - conn.execute(DDL(f"CREATE DATABASE {db_name}")) + conn.execute( + DDL( + f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" + ) + ) logger.info(f"Database {db_name} successfully created") except SQLAlchemyError as e: logger.error(f"Failed to create database {db_name}: {e}") diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 01197e5dc..77cbfcc3a 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -1,17 +1,13 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Type -from dbgpt.component import ComponentType, SystemApp +from dbgpt.component import SystemApp from dbgpt._private.config import Config from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR from dbgpt.util.executor_utils import DefaultExecutorFactory -from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory from dbgpt.app.base import WebServerParameters -if TYPE_CHECKING: - from langchain.embeddings.base import Embeddings logger = logging.getLogger(__name__) @@ -24,7 +20,10 @@ def initialize_components( embedding_model_name: str, embedding_model_path: str, ): + # Lazy import to avoid high time cost 
from dbgpt.model.cluster.controller.controller import controller + from dbgpt.app.initialization.embedding_component import _initialize_embedding_model + from dbgpt.app.initialization.serve_initialization import register_serve_apps # Register global default executor factory first system_app.register(DefaultExecutorFactory) @@ -45,97 +44,8 @@ def initialize_components( _initialize_model_cache(system_app) # NOTE: cannot disable experimental features _initialize_awel(system_app) - - -def _initialize_embedding_model( - param: WebServerParameters, - system_app: SystemApp, - embedding_model_name: str, - embedding_model_path: str, -): - if param.remote_embedding: - logger.info("Register remote RemoteEmbeddingFactory") - system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name) - else: - logger.info(f"Register local LocalEmbeddingFactory") - system_app.register( - LocalEmbeddingFactory, - default_model_name=embedding_model_name, - default_model_path=embedding_model_path, - ) - - -class RemoteEmbeddingFactory(EmbeddingFactory): - def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None: - super().__init__(system_app=system_app) - self._default_model_name = model_name - self.kwargs = kwargs - self.system_app = system_app - - def init_app(self, system_app): - self.system_app = system_app - - def create( - self, model_name: str = None, embedding_cls: Type = None - ) -> "Embeddings": - from dbgpt.model.cluster import WorkerManagerFactory - from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings - - if embedding_cls: - raise NotImplementedError - worker_manager = self.system_app.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - # Ignore model_name args - return RemoteEmbeddings(self._default_model_name, worker_manager) - - -class LocalEmbeddingFactory(EmbeddingFactory): - def __init__( - self, - system_app, - default_model_name: str = None, - default_model_path: str = None, - **kwargs: Any, - ) -> None: - super().__init__(system_app=system_app) - self._default_model_name = default_model_name - self._default_model_path = default_model_path - self._kwargs = kwargs - self._model = self._load_model() - - def init_app(self, system_app): - pass - - def create( - self, model_name: str = None, embedding_cls: Type = None - ) -> "Embeddings": - if embedding_cls: - raise NotImplementedError - return self._model - - def _load_model(self) -> "Embeddings": - from dbgpt.model.cluster.embedding.loader import EmbeddingLoader - from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params - from dbgpt.model.parameter import ( - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, - BaseEmbeddingModelParameters, - EmbeddingModelParameters, - ) - - param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( - self._default_model_name, EmbeddingModelParameters - ) - model_params: BaseEmbeddingModelParameters = _parse_embedding_params( - model_name=self._default_model_name, - model_path=self._default_model_path, - param_cls=param_cls, - **self._kwargs, - ) - logger.info(model_params) - loader = EmbeddingLoader() - # Ignore model_name args - return loader.load(self._default_model_name, model_params) + # Register serve apps + register_serve_apps(system_app) def _initialize_model_cache(system_app: SystemApp): diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index 6a3322e5d..9813303b4 100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -16,29 +16,22 @@ from dbgpt.app.base import ( 
server_init, + _migration_db_storage, WebServerParameters, _create_model_start_listener, ) + +# initialize_components import time cost about 0.1s from dbgpt.app.component_configs import initialize_components +# fastapi import time cost about 0.05s from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, applications from fastapi.openapi.docs import get_swagger_ui_html from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from dbgpt.app.knowledge.api import router as knowledge_router -# NOTE: needed for web -from dbgpt.app.prompt.api import router as prompt_router -from dbgpt.app.llm_manage.api import router as llm_manage_api -from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1 from dbgpt.app.openapi.base import validation_exception_handler -from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 -from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1 -from dbgpt.agent.commands.disply_type.show_chart_gen import ( - static_message_img_path, -) -from dbgpt.model.cluster import initialize_worker_manager_in_client from dbgpt.util.utils import ( setup_logging, _get_logging_level, @@ -79,17 +72,32 @@ def swagger_monkey_patch(*args, **kwargs): allow_headers=["*"], ) -app.include_router(api_v1, prefix="/api", tags=["Chat"]) -app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"]) -app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"]) -app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"]) -app.include_router(knowledge_router, tags=["Knowledge"]) -# NOTE: needed for web -app.include_router(prompt_router, tags=["Prompt"]) +def mount_routers(app: FastAPI): + """Lazy import to avoid high time cost""" + from dbgpt.app.knowledge.api import router as knowledge_router + + from dbgpt.app.llm_manage.api import router as llm_manage_api + + from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1 + from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import ( + router as api_editor_route_v1, + ) + from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1 + + app.include_router(api_v1, prefix="/api", tags=["Chat"]) + app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"]) + app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"]) + app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"]) + + app.include_router(knowledge_router, tags=["Knowledge"]) -def mount_static_files(app): +def mount_static_files(app: FastAPI): + from dbgpt.agent.commands.disply_type.show_chart_gen import ( + static_message_img_path, + ) + os.makedirs(static_message_img_path, exist_ok=True) app.mount( "/images", @@ -124,14 +132,15 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None): if not param: param = _get_webserver_params(args) + # import after param is initialized, accelerate --help speed + from dbgpt.model.cluster import initialize_worker_manager_in_client + if not param.log_level: param.log_level = _get_logging_level() setup_logging( "dbgpt", logging_level=param.log_level, logger_filename=param.log_file ) - # Before start - system_app.before_start() model_name = param.model_name or CFG.LLM_MODEL param.model_name = model_name print(param) @@ -140,9 +149,16 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None): embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] server_init(param, system_app) + mount_routers(app) model_start_listener = 
_create_model_start_listener(system_app) initialize_components(param, system_app, embedding_model_name, embedding_model_path) + # Before start, after initialize_components + # TODO: register initialize_worker_manager_in_client as a component in system_app + system_app.before_start() + # Migrate db storage; your db models must be imported before this + _migration_db_storage(param) + model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name) if not param.light: print("Model Unified Deployment Mode!") diff --git a/dbgpt/app/initialization/__init__.py b/dbgpt/app/initialization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py new file mode 100644 index 000000000..9095da3e0 --- /dev/null +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -0,0 +1,29 @@ +"""Import all models to make sure they are registered with SQLAlchemy. +""" +from dbgpt.agent.db.my_plugin_db import MyPluginEntity +from dbgpt.agent.db.plugin_hub_db import PluginHubEntity +from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity +from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity +from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity +from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity + +# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity +from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity +from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity +from dbgpt.storage.chat_history.chat_history_db import ( + ChatHistoryEntity, + ChatHistoryMessageEntity, +) + +_MODELS = [ + PluginHubEntity, + MyPluginEntity, + PromptManageEntity, + KnowledgeSpaceEntity, + KnowledgeDocumentEntity, + DocumentChunkEntity, + ChatFeedBackEntity, + ConnectConfigEntity, + ChatHistoryEntity, + ChatHistoryMessageEntity, +] diff --git a/dbgpt/app/initialization/embedding_component.py b/dbgpt/app/initialization/embedding_component.py new file mode 100644 index 000000000..6d8aa2b6e --- /dev/null +++ b/dbgpt/app/initialization/embedding_component.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import logging +from typing import Any, Type, TYPE_CHECKING +from dbgpt.component import ComponentType, SystemApp +from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory + +if TYPE_CHECKING: + from langchain.embeddings.base import Embeddings + from dbgpt.app.base import WebServerParameters + +logger = logging.getLogger(__name__) + + +def _initialize_embedding_model( + param: "WebServerParameters", + system_app: SystemApp, + embedding_model_name: str, + embedding_model_path: str, +): + if param.remote_embedding: + logger.info("Register remote RemoteEmbeddingFactory") + system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name) + else: + logger.info(f"Register local LocalEmbeddingFactory") + system_app.register( + LocalEmbeddingFactory, + default_model_name=embedding_model_name, + default_model_path=embedding_model_path, + ) + + +class RemoteEmbeddingFactory(EmbeddingFactory): + def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None: + super().__init__(system_app=system_app) + self._default_model_name = model_name + self.kwargs = kwargs + self.system_app = system_app + + def init_app(self, system_app): + self.system_app = system_app + + def create( + self, model_name: str = None, embedding_cls: Type = None + ) -> "Embeddings": + from 
dbgpt.model.cluster import WorkerManagerFactory + from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings + + if embedding_cls: + raise NotImplementedError + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + # Ignore model_name args + return RemoteEmbeddings(self._default_model_name, worker_manager) + + +class LocalEmbeddingFactory(EmbeddingFactory): + def __init__( + self, + system_app, + default_model_name: str = None, + default_model_path: str = None, + **kwargs: Any, + ) -> None: + super().__init__(system_app=system_app) + self._default_model_name = default_model_name + self._default_model_path = default_model_path + self._kwargs = kwargs + self._model = self._load_model() + + def init_app(self, system_app): + pass + + def create( + self, model_name: str = None, embedding_cls: Type = None + ) -> "Embeddings": + if embedding_cls: + raise NotImplementedError + return self._model + + def _load_model(self) -> "Embeddings": + from dbgpt.model.cluster.embedding.loader import EmbeddingLoader + from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params + from dbgpt.model.parameter import ( + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, + BaseEmbeddingModelParameters, + EmbeddingModelParameters, + ) + + param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( + self._default_model_name, EmbeddingModelParameters + ) + model_params: BaseEmbeddingModelParameters = _parse_embedding_params( + model_name=self._default_model_name, + model_path=self._default_model_path, + param_cls=param_cls, + **self._kwargs, + ) + logger.info(model_params) + loader = EmbeddingLoader() + # Ignore model_name args + return loader.load(self._default_model_name, model_params) diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py new file mode 100644 index 000000000..5f132cc2d --- /dev/null +++ b/dbgpt/app/initialization/serve_initialization.py @@ -0,0 +1,9 @@ +from dbgpt.component import SystemApp + + +def register_serve_apps(system_app: SystemApp): + """Register serve apps""" + from dbgpt.serve.prompt.serve import Serve as PromptServe + + # Replace old prompt serve + system_app.register(PromptServe, api_prefix="/prompt") diff --git a/dbgpt/app/knowledge/chunk_db.py b/dbgpt/app/knowledge/chunk_db.py index 11dde75b8..e8b6137ef 100644 --- a/dbgpt/app/knowledge/chunk_db.py +++ b/dbgpt/app/knowledge/chunk_db.py @@ -11,10 +11,6 @@ class DocumentChunkEntity(Model): __tablename__ = "document_chunk" - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } id = Column(Integer, primary_key=True) document_id = Column(Integer) doc_name = Column(String(100)) @@ -112,7 +108,7 @@ def get_document_chunks_count(self, query: DocumentChunkEntity): session.close() return count - def delete(self, document_id: int): + def raw_delete(self, document_id: int): session = self.get_raw_session() if document_id is None: raise Exception("document_id is None") diff --git a/dbgpt/app/knowledge/document_db.py b/dbgpt/app/knowledge/document_db.py index 4808e7744..983bb001c 100644 --- a/dbgpt/app/knowledge/document_db.py +++ b/dbgpt/app/knowledge/document_db.py @@ -10,10 +10,6 @@ class KnowledgeDocumentEntity(Model): __tablename__ = "knowledge_document" - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } id = Column(Integer, primary_key=True) doc_name = Column(String(100)) doc_type = Column(String(100)) @@ 
-180,7 +176,7 @@ def update_knowledge_document(self, document: KnowledgeDocumentEntity): return updated_space.id # - def delete(self, query: KnowledgeDocumentEntity): + def raw_delete(self, query: KnowledgeDocumentEntity): session = self.get_raw_session() knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py index 8898a364a..0f45c8d40 100644 --- a/dbgpt/app/knowledge/service.py +++ b/dbgpt/app/knowledge/service.py @@ -367,9 +367,9 @@ def delete_space(self, space_name: str): # delete chunks documents = knowledge_document_dao.get_documents(document_query) for document in documents: - document_chunk_dao.delete(document.id) + document_chunk_dao.raw_delete(document.id) # delete documents - knowledge_document_dao.delete(document_query) + knowledge_document_dao.raw_delete(document_query) # delete space return knowledge_space_dao.delete_knowledge_space(space) @@ -395,9 +395,9 @@ def delete_document(self, space_name: str, doc_name: str): # delete vector by ids vector_client.delete_by_ids(vector_ids) # delete chunks - document_chunk_dao.delete(documents[0].id) + document_chunk_dao.raw_delete(documents[0].id) # delete document - return knowledge_document_dao.delete(document_query) + return knowledge_document_dao.raw_delete(document_query) def get_document_chunks(self, request: ChunkQueryRequest): """get document chunks diff --git a/dbgpt/app/knowledge/space_db.py b/dbgpt/app/knowledge/space_db.py index 8dbd904ab..4c958c613 100644 --- a/dbgpt/app/knowledge/space_db.py +++ b/dbgpt/app/knowledge/space_db.py @@ -11,10 +11,6 @@ class KnowledgeSpaceEntity(Model): __tablename__ = "knowledge_space" - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } id = Column(Integer, primary_key=True) name = Column(String(100)) vector_type = Column(String(100)) diff --git a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py index 999434924..434830625 100644 --- a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py +++ b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py @@ -9,10 +9,6 @@ class ChatFeedBackEntity(Model): __tablename__ = "chat_feed_back" - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } id = Column(Integer, primary_key=True) conv_uid = Column(String(128)) conv_index = Column(Integer) diff --git a/dbgpt/app/prompt/prompt_manage_db.py b/dbgpt/app/prompt/prompt_manage_db.py index 0529ad4a7..6ba281668 100644 --- a/dbgpt/app/prompt/prompt_manage_db.py +++ b/dbgpt/app/prompt/prompt_manage_db.py @@ -13,10 +13,6 @@ class PromptManageEntity(Model): __tablename__ = "prompt_manage" - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } id = Column(Integer, primary_key=True) chat_scene = Column(String(100)) sub_chat_scene = Column(String(100)) diff --git a/dbgpt/cli/cli_scripts.py b/dbgpt/cli/cli_scripts.py index f0ac7bd03..fcd8edc51 100644 --- a/dbgpt/cli/cli_scripts.py +++ b/dbgpt/cli/cli_scripts.py @@ -57,6 +57,12 @@ def db(): pass +@click.group() +def new(): + """Create a new template""" + pass + + stop_all_func_list = [] @@ -71,6 +77,7 @@ def stop_all(): cli.add_command(stop) cli.add_command(install) cli.add_command(db) +cli.add_command(new) add_command_alias(stop_all, name="all", parent_group=stop) try: @@ -130,6 +137,13 @@ def stop_all(): except ImportError as e: logging.warning(f"Integrating dbgpt trace command line tool failed: {e}") 
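The guarded imports in this hunk all follow one idiom: each optional command-line tool is registered behind a try/except so that a missing extra never breaks the base CLI. A minimal standalone sketch of the pattern, using a hypothetical plugin module name:

import logging

import click


@click.group()
def new():
    """Create a new template."""
    pass


try:
    # Hypothetical optional extra; only registered when it is installed.
    from my_optional_plugin.cli import serve

    new.add_command(serve, name="serve")
except ImportError as e:
    logging.warning(f"Integrating optional serve command line tool failed: {e}")
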
+try: + from dbgpt.serve.utils.cli import serve + + add_command_alias(serve, name="serve", parent_group=new) +except ImportError as e: + logging.warning(f"Integrating dbgpt serve command line tool failed: {e}") + def main(): return cli() diff --git a/dbgpt/component.py b/dbgpt/component.py index 040732dbf..84ea9afde 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -7,6 +7,7 @@ import logging import asyncio from dbgpt.util.annotations import PublicAPI +from dbgpt.util import AppConfig # Checking for type hints during runtime if TYPE_CHECKING: @@ -87,17 +88,27 @@ def init_app(self, system_app: SystemApp): class SystemApp(LifeCycle): """Main System Application class that manages the lifecycle and registration of components.""" - def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None: + def __init__( + self, + asgi_app: Optional["FastAPI"] = None, + app_config: Optional[AppConfig] = None, + ) -> None: self.components: Dict[ str, BaseComponent ] = {} # Dictionary to store registered components. self._asgi_app = asgi_app + self._app_config = app_config or AppConfig() @property def app(self) -> Optional["FastAPI"]: """Returns the internal ASGI app.""" return self._asgi_app + @property + def config(self) -> AppConfig: + """Returns the internal AppConfig.""" + return self._app_config + def register(self, component: Type[BaseComponent], *args, **kwargs) -> T: """Register a new component by its type. diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 36bba00f4..bd06f0dc7 100644 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -655,6 +655,11 @@ def __init__( super().__init__(chat_mode, user_name, sys_code, summary, **kwargs) self.conv_uid = conv_uid self._message_ids = message_ids + # Record the index of the last message saved to the storage; + # next time, save messages starting from index _has_stored_message_index + 1 + self._has_stored_message_index = ( + len(kwargs["messages"]) - 1 if "messages" in kwargs else -1 + ) self.save_message_independent = save_message_independent self._id = ConversationIdentifier(conv_uid) if conv_storage is None: @@ -695,7 +700,9 @@ def save_to_storage(self) -> None: self._message_ids = [ message.identifier.str_identifier for message in message_list ] - self.message_storage.save_list(message_list) + messages_to_save = message_list[self._has_stored_message_index + 1 :] + self._has_stored_message_index = len(message_list) - 1 + self.message_storage.save_list(messages_to_save) # Save conversation self.conv_storage.save_or_update(self) @@ -729,6 +736,7 @@ def load_from_storage( messages = [message.to_message() for message in message_list] conversation.messages = messages self._message_ids = message_ids + self._has_stored_message_index = len(messages) - 1 self.from_conversation(conversation) diff --git a/dbgpt/datasource/manages/connect_config_db.py b/dbgpt/datasource/manages/connect_config_db.py index 436f8f21f..228be4cdc 100644 --- a/dbgpt/datasource/manages/connect_config_db.py +++ b/dbgpt/datasource/manages/connect_config_db.py @@ -25,11 +25,10 @@ class ConnectConfigEntity(Model): __table_args__ = ( UniqueConstraint("db_name", name="uk_db"), Index("idx_q_db_type", "db_type"), - {"mysql_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, ) -class ConnectConfigDao(BaseDao[ConnectConfigEntity]): +class ConnectConfigDao(BaseDao): """db connect config dao""" def update(self, entity: ConnectConfigEntity): diff --git a/dbgpt/model/cluster/worker/default_worker.py 
b/dbgpt/model/cluster/worker/default_worker.py index 6d8e17ad4..885b6e9eb 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -19,12 +19,7 @@ logger = logging.getLogger(__name__) _torch_imported = False -try: - import torch - - _torch_imported = True -except ImportError: - pass +torch = None class DefaultModelWorker(ModelWorker): @@ -95,6 +90,8 @@ def parse_parameters(self, command_args: List[str] = None) -> ModelParameters: def start( self, model_params: ModelParameters = None, command_args: List[str] = None ) -> None: + # Lazy load torch + _try_import_torch() if not model_params: model_params = self.parse_parameters(command_args) self._model_params = model_params @@ -436,3 +433,14 @@ def _new_metrics_from_model_output( ].available_memory_gb return metrics + + +def _try_import_torch(): + global torch + global _torch_imported + try: + import torch + + _torch_imported = True + except ImportError: + pass diff --git a/dbgpt/model/cluster/worker/remote_manager.py b/dbgpt/model/cluster/worker/remote_manager.py index 4f6b675ad..5561c026d 100644 --- a/dbgpt/model/cluster/worker/remote_manager.py +++ b/dbgpt/model/cluster/worker/remote_manager.py @@ -1,7 +1,6 @@ import asyncio from typing import Any, Callable -import httpx from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel from dbgpt.model.cluster.base import * from dbgpt.model.cluster.registry import ModelRegistry @@ -34,6 +33,9 @@ async def _fetch_from_worker( success_handler: Callable = None, error_handler: Callable = None, ) -> Any: + # Lazy import to avoid high time cost + import httpx + url = worker_run_data.worker.worker_addr + endpoint headers = {**worker_run_data.worker.headers, **(additional_headers or {})} timeout = worker_run_data.worker.timeout diff --git a/dbgpt/serve/core/__init__.py b/dbgpt/serve/core/__init__.py new file mode 100644 index 000000000..36a1900e9 --- /dev/null +++ b/dbgpt/serve/core/__init__.py @@ -0,0 +1,5 @@ +from dbgpt.serve.core.schemas import Result +from dbgpt.serve.core.config import BaseServeConfig +from dbgpt.serve.core.service import BaseService + +__all__ = ["Result", "BaseServeConfig", "BaseService"] diff --git a/dbgpt/serve/core/config.py b/dbgpt/serve/core/config.py new file mode 100644 index 000000000..0793fc9a5 --- /dev/null +++ b/dbgpt/serve/core/config.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from dbgpt.component import AppConfig +from dbgpt.util import BaseParameters + + +@dataclass +class BaseServeConfig(BaseParameters): + """Base configuration class for serve""" + + @classmethod + def from_app_config(cls, config: AppConfig, config_prefix: str): + """Create a configuration object from the application config + + Args: + config (AppConfig): Application configuration + config_prefix (str): Configuration prefix + """ + config_dict = config.get_all_by_prefix(config_prefix) + return cls(**config_dict) diff --git a/dbgpt/serve/core/schemas.py b/dbgpt/serve/core/schemas.py new file mode 100644 index 000000000..b97a936d5 --- /dev/null +++ b/dbgpt/serve/core/schemas.py @@ -0,0 +1,38 @@ +from typing import TypeVar, Generic, Any, Optional + +from dbgpt._private.pydantic import BaseModel, Field + +T = TypeVar("T") + + +class Result(BaseModel, Generic[T]): + """Common result entity class""" + + success: bool = Field( + ..., description="Whether it is successful, True: success, False: failure" + ) + err_code: str | None = Field(None, description="Error code") + err_msg: str | None = Field(None, description="Error 
message") + data: T | None = Field(None, description="Return data") + + @staticmethod + def succ(data: T) -> "Result[T]": + """Build a successful result entity + + Args: + data (T): Return data + + Returns: + Result[T]: Result entity + """ + return Result(success=True, err_code=None, err_msg=None, data=data) + + @staticmethod + def failed(msg: str, err_code: Optional[str] = "E000X") -> "Result[Any]": + """Build a failed result entity + + Args: + msg (str): Error message + err_code (Optional[str], optional): Error code. Defaults to "E000X". + """ + return Result(success=False, err_code=err_code, err_msg=msg, data=None) diff --git a/dbgpt/serve/core/service.py b/dbgpt/serve/core/service.py new file mode 100644 index 000000000..401b93517 --- /dev/null +++ b/dbgpt/serve/core/service.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod +from typing import Generic +from dbgpt.component import BaseComponent +from dbgpt.storage.metadata._base_dao import BaseDao, T, REQ, RES +from dbgpt.serve.core.config import BaseServeConfig + + +class BaseService(BaseComponent, Generic[T, REQ, RES], ABC): + name = "dbgpt_serve_base_service" + + def __init__(self, system_app): + super().__init__(system_app) + + @property + @abstractmethod + def dao(self) -> BaseDao[T, REQ, RES]: + """Returns the internal DAO.""" + + @property + @abstractmethod + def config(self) -> BaseServeConfig: + """Returns the internal ServeConfig.""" + + def create(self, request: REQ) -> RES: + """Create a new entity + + Args: + request (REQ): The request + + Returns: + RES: The response + """ + return self.dao.create(request) diff --git a/dbgpt/serve/prompt/__init__.py b/dbgpt/serve/prompt/__init__.py new file mode 100644 index 000000000..430e822f7 --- /dev/null +++ b/dbgpt/serve/prompt/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve prompt` diff --git a/dbgpt/serve/prompt/api/__init__.py b/dbgpt/serve/prompt/api/__init__.py new file mode 100644 index 000000000..430e822f7 --- /dev/null +++ b/dbgpt/serve/prompt/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve prompt` diff --git a/dbgpt/serve/prompt/api/endpoints.py b/dbgpt/serve/prompt/api/endpoints.py new file mode 100644 index 000000000..493ac2f85 --- /dev/null +++ b/dbgpt/serve/prompt/api/endpoints.py @@ -0,0 +1,114 @@ +from typing import Optional, List +from fastapi import APIRouter, Depends, Query + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result +from dbgpt.util import PaginationResult +from .schemas import ServeRequest, ServerResponse +from ..service.service import Service +from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME + +router = APIRouter() + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +# TODO: Compatible with old API, will be modified in the future +@router.post("/add", response_model=Result[ServerResponse]) +async def create( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Create a new Prompt entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return 
Result.succ(service.create(request)) + + +@router.post("/update", response_model=Result[ServerResponse]) +async def update( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Update a Prompt entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.update(request)) + + +@router.post("/delete", response_model=Result[None]) +async def delete( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[None]: + """Delete a Prompt entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.delete(request)) + + +@router.post("/list", response_model=Result[List[ServerResponse]]) +async def query( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[List[ServerResponse]]: + """Query Prompt entities + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + List[ServerResponse]: The response + """ + return Result.succ(service.get_list(request)) + + +@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]]) +async def query_page( + request: ServeRequest, + page: Optional[int] = Query(default=1, description="current page"), + page_size: Optional[int] = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query Prompt entities + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get_list_by_page(request, page, page_size)) + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/prompt/api/schemas.py b/dbgpt/serve/prompt/api/schemas.py new file mode 100644 index 000000000..a6131dea5 --- /dev/null +++ b/dbgpt/serve/prompt/api/schemas.py @@ -0,0 +1,73 @@ +# Define your Pydantic schemas here +from typing import Optional +from dbgpt._private.pydantic import BaseModel, Field + + +class ServeRequest(BaseModel): + """Prompt request model""" + + chat_scene: Optional[str] = None + """ + The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa. + """ + + sub_chat_scene: Optional[str] = None + """ + The sub chat scene. + """ + + prompt_type: Optional[str] = None + """ + The prompt type, either common or private. + """ + + content: Optional[str] = None + """ + The prompt content. + """ + + user_name: Optional[str] = None + """ + The user name. + """ + + sys_code: Optional[str] = None + """ + System code + """ + + prompt_name: Optional[str] = None + """ + The prompt name. 
+ """ + + +class ServerResponse(BaseModel): + """Prompt response model""" + + id: int = None + """chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa""" + + chat_scene: str = None + + """sub_chat_scene: sub chat scene""" + sub_chat_scene: str = None + + """prompt_type: common or private""" + prompt_type: str = None + + """content: prompt content""" + content: str = None + + """user_name: user name""" + user_name: str = None + + sys_code: Optional[str] = None + """ + System code + """ + + """prompt_name: prompt name""" + prompt_name: str = None + gmt_created: str = None + gmt_modified: str = None diff --git a/dbgpt/serve/prompt/config.py b/dbgpt/serve/prompt/config.py new file mode 100644 index 000000000..6d033e3eb --- /dev/null +++ b/dbgpt/serve/prompt/config.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from dbgpt.serve.core import BaseServeConfig + + +APP_NAME = "prompt" +SERVE_APP_NAME = "dbgpt_serve_prompt" +SERVE_APP_NAME_HUMP = "dbgpt_serve_Prompt" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.prompt." +SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_prompt" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here diff --git a/dbgpt/serve/prompt/dependencies.py b/dbgpt/serve/prompt/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/prompt/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/prompt/models/__init__.py b/dbgpt/serve/prompt/models/__init__.py new file mode 100644 index 000000000..430e822f7 --- /dev/null +++ b/dbgpt/serve/prompt/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve prompt` diff --git a/dbgpt/serve/prompt/models/models.py b/dbgpt/serve/prompt/models/models.py new file mode 100644 index 000000000..812ff8fae --- /dev/null +++ b/dbgpt/serve/prompt/models/models.py @@ -0,0 +1,95 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" +from typing import Union, Any, Dict +from datetime import datetime +from sqlalchemy import Column, Integer, String, Index, Text, DateTime, UniqueConstraint +from dbgpt.storage.metadata import Model, BaseDao, db +from ..api.schemas import ServeRequest, ServerResponse +from ..config import ServeConfig, SERVER_APP_TABLE_NAME + + +class ServeEntity(Model): + __tablename__ = "prompt_manage" + __table_args__ = ( + UniqueConstraint("prompt_name", "sys_code", name="uk_prompt_name_sys_code"), + ) + id = Column(Integer, primary_key=True, comment="Auto increment id") + + chat_scene = Column(String(100)) + sub_chat_scene = Column(String(100)) + prompt_type = Column(String(100)) + prompt_name = Column(String(512)) + content = Column(Text) + user_name = Column(String(128)) + sys_code = Column(String(128), index=True, nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + def __repr__(self): + return f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class ServeDao(BaseDao[ServeEntity, 
ServeRequest, ServerResponse]): + """The DAO class for Prompt""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = request.dict() if isinstance(request, ServeRequest) else request + entity = ServeEntity(**request_dict) + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + return ServeRequest( + chat_scene=entity.chat_scene, + sub_chat_scene=entity.sub_chat_scene, + prompt_type=entity.prompt_type, + prompt_name=entity.prompt_name, + content=entity.content, + user_name=entity.user_name, + sys_code=entity.sys_code, + ) + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + # TODO implement your own logic here, transfer the entity to a response + gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S") + gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S") + return ServerResponse( + id=entity.id, + chat_scene=entity.chat_scene, + sub_chat_scene=entity.sub_chat_scene, + prompt_type=entity.prompt_type, + prompt_name=entity.prompt_name, + content=entity.content, + user_name=entity.user_name, + sys_code=entity.sys_code, + gmt_created=gmt_created_str, + gmt_modified=gmt_modified_str, + ) diff --git a/dbgpt/serve/prompt/serve.py b/dbgpt/serve/prompt/serve.py new file mode 100644 index 000000000..9a19adf7b --- /dev/null +++ b/dbgpt/serve/prompt/serve.py @@ -0,0 +1,36 @@ +from typing import List, Optional +from dbgpt.component import BaseComponent, SystemApp + +from .api.endpoints import router, init_endpoints +from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME + + +class Serve(BaseComponent): + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", + tags: Optional[List[str]] = None, + ): + if tags is None: + tags = [SERVE_APP_NAME_HUMP] + self._system_app = None + self._api_prefix = api_prefix + self._tags = tags + + def init_app(self, system_app: SystemApp): + self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._tags + ) + init_endpoints(self._system_app) + + def before_start(self): + """Called before the start of the application. + + You can do some initialization here. 
+ """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity diff --git a/dbgpt/serve/prompt/service/__init__.py b/dbgpt/serve/prompt/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/prompt/service/service.py b/dbgpt/serve/prompt/service/service.py new file mode 100644 index 000000000..7d1273161 --- /dev/null +++ b/dbgpt/serve/prompt/service/service.py @@ -0,0 +1,117 @@ +from typing import Optional, List +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.serve.core import BaseService +from ..models.models import ServeDao, ServeEntity +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for Prompt""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = None + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def update(self, request: ServeRequest) -> ServerResponse: + """Update a Prompt entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # Build the query request from the request + query_request = { + "prompt_name": request.prompt_name, + "sys_code": request.sys_code, + } + return self.dao.update(query_request, update_request=request) + + def get(self, request: ServeRequest) -> Optional[ServerResponse]: + """Get a Prompt entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_one(query_request) + + def delete(self, request: ServeRequest) -> None: + """Delete a Prompt entity + + Args: + request (ServeRequest): The request + """ + + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + "prompt_name": request.prompt_name, + "sys_code": request.sys_code, + } + self.dao.delete(query_request) + + def get_list(self, request: ServeRequest) -> List[ServerResponse]: + """Get a list of Prompt entities + + Args: + request (ServeRequest): The request + + Returns: + List[ServerResponse]: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_list(query_request) + + def get_list_by_page( + self, request: ServeRequest, page: int, page_size: int + ) -> PaginationResult[ServerResponse]: + """Get a list of Prompt entities by page + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + 
+ Returns: + List[ServerResponse]: The response + """ + query_request = request + return self.dao.get_list_page(query_request, page, page_size) diff --git a/dbgpt/serve/utils/__init__.py b/dbgpt/serve/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/__init__.py b/dbgpt/serve/utils/_template_files/default_serve_template/__init__.py new file mode 100644 index 000000000..321e6a94a --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve {__template_app_name__}` diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/api/__init__.py b/dbgpt/serve/utils/_template_files/default_serve_template/api/__init__.py new file mode 100644 index 000000000..321e6a94a --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve {__template_app_name__}` diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py b/dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py new file mode 100644 index 000000000..fc55669cb --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py @@ -0,0 +1,98 @@ +from typing import Optional, List +from fastapi import APIRouter, Depends, Query + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result +from dbgpt.util import PaginationResult +from .schemas import ServeRequest, ServerResponse +from ..service.service import Service +from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME + +router = APIRouter() + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.post("/", response_model=Result[ServerResponse]) +async def create( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Create a new {__template_app_name__hump__} entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.create(request)) + + +@router.put("/", response_model=Result[ServerResponse]) +async def update( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Update a {__template_app_name__hump__} entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.update(request)) + + +@router.post("/query", response_model=Result[ServerResponse]) +async def query( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Query {__template_app_name__hump__} entities + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get(request)) + + +@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]]) +async def query_page( + request: ServeRequest, + page: Optional[int] = Query(default=1, 
description="current page"), + page_size: Optional[int] = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query {__template_app_name__hump__} entities + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get_list_by_page(request, page, page_size)) + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py b/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py new file mode 100644 index 000000000..d123aa159 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py @@ -0,0 +1,14 @@ +# Define your Pydantic schemas here +from dbgpt._private.pydantic import BaseModel, Field + + +class ServeRequest(BaseModel): + """{__template_app_name__hump__} request model""" + + # TODO define your own fields here + + +class ServerResponse(BaseModel): + """{__template_app_name__hump__} response model""" + + # TODO define your own fields here diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/config.py b/dbgpt/serve/utils/_template_files/default_serve_template/config.py new file mode 100644 index 000000000..25a60a020 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/config.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from dbgpt.serve.core import BaseServeConfig + + +APP_NAME = "{__template_app_name__all_lower__}" +SERVE_APP_NAME = "dbgpt_serve_{__template_app_name__all_lower__}" +SERVE_APP_NAME_HUMP = "dbgpt_serve_{__template_app_name__hump__}" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.{__template_app_name__all_lower__}." 
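+# How this prefix is consumed (a sketch; the exact fields depend on your ServeConfig):
+# for a serve app named "my_app" (hypothetical), a config entry such as
+#     system_app.config.set("dbgpt.serve.my_app.some_option", "value")
+# is picked up by ServeConfig.from_app_config(system_app.config, SERVE_CONFIG_KEY_PREFIX),
+# which reads every key under this prefix via AppConfig.get_all_by_prefix.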
+SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_{__template_app_name__all_lower__}" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/dependencies.py b/dbgpt/serve/utils/_template_files/default_serve_template/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/models/__init__.py b/dbgpt/serve/utils/_template_files/default_serve_template/models/__init__.py new file mode 100644 index 000000000..321e6a94a --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve {__template_app_name__}` diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/models/models.py b/dbgpt/serve/utils/_template_files/default_serve_template/models/models.py new file mode 100644 index 000000000..516635df5 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/models/models.py @@ -0,0 +1,68 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" +from typing import Union, Any, Dict +from datetime import datetime +from sqlalchemy import Column, Integer, String, Index, Text, DateTime +from dbgpt.storage.metadata import Model, BaseDao, db +from ..api.schemas import ServeRequest, ServerResponse +from ..config import ServeConfig, SERVER_APP_TABLE_NAME + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + id = Column(Integer, primary_key=True, comment="Auto increment id") + + # TODO: define your own fields here + + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + def __repr__(self): + return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for {__template_app_name__hump__}""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = request.dict() if isinstance(request, ServeRequest) else request + entity = ServeEntity(**request_dict) + # TODO implement your own logic here, transfer the request_dict to an entity + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + # TODO implement your own logic here, transfer the entity to a request + return ServeRequest() + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + # TODO implement your own logic here, transfer the entity to a response + return ServerResponse() diff --git 
a/dbgpt/serve/utils/_template_files/default_serve_template/serve.py b/dbgpt/serve/utils/_template_files/default_serve_template/serve.py new file mode 100644 index 000000000..ef99ce936 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/serve.py @@ -0,0 +1,38 @@ +from typing import List, Optional +from dbgpt.component import BaseComponent, SystemApp + +from .api.endpoints import router, init_endpoints +from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME + + +class Serve(BaseComponent): + """Serve component for DB-GPT""" + + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", + tags: Optional[List[str]] = None, + ): + if tags is None: + tags = [SERVE_APP_NAME_HUMP] + self._system_app = None + self._api_prefix = api_prefix + self._tags = tags + + def init_app(self, system_app: SystemApp): + self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._tags + ) + init_endpoints(self._system_app) + + def before_start(self): + """Called before the start of the application. + + You can do some initialization here. + """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/service/__init__.py b/dbgpt/serve/utils/_template_files/default_serve_template/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/service/service.py b/dbgpt/serve/utils/_template_files/default_serve_template/service/service.py new file mode 100644 index 000000000..89f06b2a2 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/service/service.py @@ -0,0 +1,116 @@ +from typing import Optional, List +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.serve.core import BaseService +from ..models.models import ServeDao, ServeEntity +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for {__template_app_name__hump__}""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = None + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def update(self, request: ServeRequest) -> ServerResponse: + """Update a {__template_app_name__hump__} entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + # 
"id": request.id + } + return self.dao.update(query_request, update_request=request) + + def get(self, request: ServeRequest) -> Optional[ServerResponse]: + """Get a {__template_app_name__hump__} entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_one(query_request) + + def delete(self, request: ServeRequest) -> None: + """Delete a {__template_app_name__hump__} entity + + Args: + request (ServeRequest): The request + """ + + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + # "id": request.id + } + self.dao.delete(query_request) + + def get_list(self, request: ServeRequest) -> List[ServerResponse]: + """Get a list of {__template_app_name__hump__} entities + + Args: + request (ServeRequest): The request + + Returns: + List[ServerResponse]: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_list(query_request) + + def get_list_by_page( + self, request: ServeRequest, page: int, page_size: int + ) -> PaginationResult[ServerResponse]: + """Get a list of {__template_app_name__hump__} entities by page + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + + Returns: + List[ServerResponse]: The response + """ + query_request = request + return self.dao.get_list_page(query_request, page, page_size) diff --git a/dbgpt/serve/utils/cli.py b/dbgpt/serve/utils/cli.py new file mode 100644 index 000000000..eb4ca359f --- /dev/null +++ b/dbgpt/serve/utils/cli.py @@ -0,0 +1,83 @@ +import os +import click + + +@click.command(name="serve") +@click.option( + "-n", + "--name", + required=True, + type=str, + show_default=True, + help="The name of the serve module to create", +) +@click.option( + "-t", + "--template", + required=False, + type=str, + default="default_serve_template", + show_default=True, + help="The template to use to create the serve module", +) +def serve(name: str, template: str): + """Create a serve module structure with the given name.""" + from dbgpt.configs.model_config import ROOT_PATH + + base_path = os.path.join(ROOT_PATH, "dbgpt", "serve", name) + template_path = os.path.join( + ROOT_PATH, "dbgpt", "serve", "utils", "_template_files", template + ) + if not os.path.exists(template_path): + raise ValueError(f"Template '{template}' not found") + if os.path.exists(base_path): + # TODO: backup the old serve module + click.confirm( + f"Serve module '{name}' already exists in {base_path}, do you want to overwrite it?", + abort=True, + ) + import shutil + + shutil.rmtree(base_path) + + copy_template_files(template_path, base_path, name) + click.echo(f"Serve application '{name}' created successfully in {base_path}") + + +def replace_template_variables(content: str, app_name: str): + """Replace the template variables in the given content with the given app name.""" + template_values = { + "{__template_app_name__}": app_name, + "{__template_app_name__all_lower__}": app_name.lower(), + "{__template_app_name__hump__}": "".join( + part.capitalize() for part in app_name.split("_") + ), + } + + for key in sorted(template_values, key=len, reverse=True): + content = content.replace(key, template_values[key]) + + return content + + +def copy_template_files(src_dir: str, dst_dir: str, app_name: str): + for root, 
dirs, files in os.walk(src_dir): + relative_path = os.path.relpath(root, src_dir) + if relative_path == ".": + relative_path = "" + + target_dir = os.path.join(dst_dir, relative_path) + os.makedirs(target_dir, exist_ok=True) + + for file in files: + try: + with open(os.path.join(root, file), "r") as f: + content = f.read() + + content = replace_template_variables(content, app_name) + + with open(os.path.join(target_dir, file), "w") as f: + f.write(content) + except Exception as e: + click.echo(f"Error copying file {file} from {src_dir} to {dst_dir}") + raise e diff --git a/dbgpt/storage/chat_history/chat_history_db.py b/dbgpt/storage/chat_history/chat_history_db.py index 05002f189..e304d7aa0 100644 --- a/dbgpt/storage/chat_history/chat_history_db.py +++ b/dbgpt/storage/chat_history/chat_history_db.py @@ -8,16 +8,14 @@ class ChatHistoryEntity(Model): __tablename__ = "chat_history" + __table_args__ = (UniqueConstraint("conv_uid", name="uk_conv_uid"),) id = Column( Integer, primary_key=True, autoincrement=True, comment="autoincrement id" ) - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } conv_uid = Column( String(255), - unique=True, + # Changing this from False to True makes the alembic migration fail, so we enforce uniqueness with an explicit UniqueConstraint instead + unique=False, nullable=False, comment="Conversation record unique id", ) @@ -41,13 +39,12 @@ class ChatHistoryMessageEntity(Model): __tablename__ = "chat_history_message" + __table_args__ = ( + UniqueConstraint("conv_uid", "index", name="uk_conversation_message"), + ) id = Column( Integer, primary_key=True, autoincrement=True, comment="autoincrement id" ) - __table_args__ = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - } conv_uid = Column( String(255), unique=False, @@ -61,10 +58,9 @@ class ChatHistoryMessageEntity(Model): ) gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") - UniqueConstraint("conv_uid", "index", name="uk_conversation_message") -class ChatHistoryDao(BaseDao[ChatHistoryEntity]): +class ChatHistoryDao(BaseDao): def list_last_20( self, user_name: Optional[str] = None, sys_code: Optional[str] = None ): @@ -81,7 +77,7 @@ def list_last_20( session.close() return result - def update(self, entity: ChatHistoryEntity): + def raw_update(self, entity: ChatHistoryEntity): session = self.get_raw_session() try: updated = session.merge(entity) @@ -101,7 +97,7 @@ def update_message_by_uid(self, message: str, conv_uid: str): finally: session.close() - def delete(self, conv_uid: int): + def raw_delete(self, conv_uid: int): if conv_uid is None: raise Exception("conv_uid is None") with self.session() as session: diff --git a/dbgpt/storage/chat_history/store_type/meta_db_history.py b/dbgpt/storage/chat_history/store_type/meta_db_history.py index d4280e146..48dfdd2fe 100644 --- a/dbgpt/storage/chat_history/store_type/meta_db_history.py +++ b/dbgpt/storage/chat_history/store_type/meta_db_history.py @@ -37,7 +37,7 @@ def create(self, chat_mode, summary: str, user_name: str) -> None: chat_history.summary = summary chat_history.user_name = user_name - self.chat_history_dao.update(chat_history) + self.chat_history_dao.raw_update(chat_history) except Exception as e: logger.error("init create conversation log error!"
+ str(e)) @@ -65,7 +65,7 @@ def append(self, once_message: OnceConversation) -> None: # Avoid (pymysql.err.DataError) (1406, "Data too long for column 'messages') #chat_history.messages = json.dumps(conversations, ensure_ascii=False) - self.chat_history_dao.update(chat_history) + self.chat_history_dao.raw_update(chat_history) def update(self, messages: List[OnceConversation]) -> None: self.chat_history_dao.update_message_by_uid( @@ -73,7 +73,7 @@ def update(self, messages: List[OnceConversation]) -> None: ) def delete(self) -> bool: - self.chat_history_dao.delete(self.chat_seesion_id) + self.chat_history_dao.raw_delete(self.chat_seesion_id) def conv_info(self, conv_uid: str = None) -> None: logger.info("conv_info:{}", conv_uid) diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 93ff289b4..78ef0c4aa 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -1,18 +1,27 @@ from contextlib import contextmanager -from typing import TypeVar, Generic, Any, Optional +from typing import TypeVar, Generic, Any, Optional, Dict, Union, List from sqlalchemy.orm.session import Session +from dbgpt.util.pagination_utils import PaginationResult +# The entity type T = TypeVar("T") +# The request schema type +REQ = TypeVar("REQ") +# The response schema type +RES = TypeVar("RES") -from .db_manager import db, DatabaseManager +from .db_manager import db, DatabaseManager, BaseQuery -class BaseDao(Generic[T]): +QUERY_SPEC = Union[REQ, Dict[str, Any]] + + +class BaseDao(Generic[T, REQ, RES]): """The base class for all DAOs. Examples: .. code-block:: python - class UserDao(BaseDao[User]): + class UserDao(BaseDao): def get_user_by_name(self, name: str) -> User: with self.session() as session: return session.query(User).filter(User.name == name).first() @@ -70,3 +79,184 @@ def session(self) -> Session: """ with self._db_manager.session() as session: yield session + + def from_request(self, request: QUERY_SPEC) -> T: + """Convert a request schema object to an entity object. + + Args: + request (REQ): The request schema object or dict for query. + + Returns: + T: The entity object. + """ + raise NotImplementedError + + def to_request(self, entity: T) -> REQ: + """Convert an entity object to a request schema object. + + Args: + entity (T): The entity object. + + Returns: + REQ: The request schema object. + """ + raise NotImplementedError + + def from_response(self, response: RES) -> T: + """Convert a response schema object to an entity object. + + Args: + response (RES): The response schema object. + + Returns: + T: The entity object. + """ + raise NotImplementedError + + def to_response(self, entity: T) -> RES: + """Convert an entity object to a response schema object. + + Args: + entity (T): The entity object. + + Returns: + RES: The response schema object. + """ + raise NotImplementedError + + def create(self, request: REQ) -> RES: + """Create an entity object. + + Args: + request (REQ): The request schema object. + + Returns: + RES: The response schema object. + """ + entry = self.from_request(request) + with self.session() as session: + session.add(entry) + return self.get_one(self.to_request(entry)) + + def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES: + """Update an entity object. + + Args: + query_request (REQ): The request schema object or dict for query. + update_request (REQ): The request schema object for update. + Returns: + RES: The response schema object. 
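+
+        Examples:
+            .. code-block:: python
+
+                # A sketch with a hypothetical UserDao: find the row matching the
+                # query dict, then apply the fields of the update request to it.
+                dao.update({"name": "foo"}, update_request=UserRequest(name="foo", age=30))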
+ """ + with self.session() as session: + query = self._create_query_object(session, query_request) + entry = query.first() + if entry is None: + raise Exception("Invalid request") + for key, value in update_request.dict().items(): + setattr(entry, key, value) + session.merge(entry) + return self.get_one(self.to_request(entry)) + + def delete(self, query_request: QUERY_SPEC) -> None: + """Delete an entity object. + + Args: + query_request (REQ): The request schema object or dict for query. + """ + with self.session() as session: + result_list = self._get_entity_list(session, query_request) + if len(result_list) != 1: + raise ValueError( + f"Delete request should return one result, but got {len(result_list)}" + ) + session.delete(result_list[0]) + + def get_one(self, query_request: QUERY_SPEC) -> Optional[RES]: + """Get an entity object. + + Args: + query_request (REQ): The request schema object or dict for query. + + Returns: + Optional[RES]: The response schema object. + """ + with self.session() as session: + query = self._create_query_object(session, query_request) + result = query.first() + if result is None: + return None + return self.to_response(result) + + def get_list(self, query_request: QUERY_SPEC) -> List[RES]: + """Get a list of entity objects. + + Args: + query_request (REQ): The request schema object or dict for query. + Returns: + List[RES]: The response schema object. + """ + with self.session() as session: + result_list = self._get_entity_list(session, query_request) + return [self.to_response(item) for item in result_list] + + def _get_entity_list(self, session: Session, query_request: QUERY_SPEC) -> List[T]: + """Get a list of entity objects. + + Args: + session (Session): The session object. + query_request (REQ): The request schema object or dict for query. + Returns: + List[RES]: The response schema object. + """ + query = self._create_query_object(session, query_request) + result_list = query.all() + return result_list + + def get_list_page( + self, query_request: QUERY_SPEC, page: int, page_size: int + ) -> PaginationResult[RES]: + """Get a page of entity objects. + + Args: + query_request (REQ): The request schema object or dict for query. + page (int): The page number. + page_size (int): The page size. + + Returns: + PaginationResult: The pagination result. + """ + with self.session() as session: + query = self._create_query_object(session, query_request) + total_count = query.count() + items = query.offset((page - 1) * page_size).limit(page_size) + items = [self.to_response(item) for item in items] + total_pages = (total_count + page_size - 1) // page_size + + return PaginationResult( + items=items, + total_count=total_count, + total_pages=total_pages, + page=page, + page_size=page_size, + ) + + def _create_query_object( + self, session: Session, query_request: QUERY_SPEC + ) -> BaseQuery: + """Create a query object. + + Args: + session (Session): The session object. + query_request (QUERY_SPEC): The request schema object or dict for query. + Returns: + BaseQuery: The query object. 
+ """ + model_cls = type(self.from_request(query_request)) + query = session.query(model_cls) + query_dict = ( + query_request if isinstance(query_request, dict) else query_request.dict() + ) + for key, value in query_dict.items(): + if value is not None: + query = query.filter(getattr(model_cls, key) == value) + return query diff --git a/dbgpt/storage/metadata/db_storage.py b/dbgpt/storage/metadata/db_storage.py index d85a1578d..2cc1a5118 100644 --- a/dbgpt/storage/metadata/db_storage.py +++ b/dbgpt/storage/metadata/db_storage.py @@ -103,7 +103,8 @@ def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: with self.session() as session: query = session.query(self._model_class) for key, value in spec.conditions.items(): - query = query.filter(getattr(self._model_class, key) == value) + if value is not None: + query = query.filter(getattr(self._model_class, key) == value) if spec.limit is not None: query = query.limit(spec.limit) if spec.offset is not None: @@ -124,5 +125,6 @@ def count(self, spec: QuerySpec, cls: Type[T]) -> int: with self.session() as session: query = session.query(self._model_class) for key, value in spec.conditions.items(): - query = query.filter(getattr(self._model_class, key) == value) + if value is not None: + query = query.filter(getattr(self._model_class, key) == value) return query.count() diff --git a/dbgpt/storage/metadata/tests/test_base_dao.py b/dbgpt/storage/metadata/tests/test_base_dao.py new file mode 100644 index 000000000..e7563d935 --- /dev/null +++ b/dbgpt/storage/metadata/tests/test_base_dao.py @@ -0,0 +1,152 @@ +from typing import Type, Optional, Union, Dict, Any +import pytest +from sqlalchemy import Column, Integer, String +from dbgpt._private.pydantic import BaseModel as PydanticBaseModel, Field +from dbgpt.storage.metadata.db_manager import ( + DatabaseManager, + PaginationResult, + create_model, + BaseModel, +) + +from .._base_dao import BaseDao + + +class UserRequest(PydanticBaseModel): + name: str = Field(..., description="User name") + age: Optional[int] = Field(default=-1, description="User age") + password: Optional[str] = Field(default="", description="User password") + + +class UserResponse(PydanticBaseModel): + id: int = Field(..., description="User id") + name: str = Field(..., description="User name") + age: Optional[int] = Field(default=-1, description="User age") + + +@pytest.fixture +def db(): + db = DatabaseManager() + db.init_db("sqlite:///:memory:") + return db + + +@pytest.fixture +def Model(db): + return create_model(db) + + +@pytest.fixture +def User(Model): + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50), unique=True) + age = Column(Integer) + password = Column(String(50)) + + return User + + +@pytest.fixture +def user_req(): + return UserRequest(name="Edward Snowden", age=30, password="123456") + + +@pytest.fixture +def user_dao(db, User): + class UserDao(BaseDao[User, UserRequest, UserResponse]): + def from_request(self, request: Union[UserRequest, Dict[str, Any]]) -> User: + if isinstance(request, UserRequest): + return User(**request.dict()) + else: + return User(**request) + + def to_request(self, entity: User) -> UserRequest: + return UserRequest( + name=entity.name, age=entity.age, password=entity.password + ) + + def from_response(self, response: UserResponse) -> User: + return User(**response.dict()) + + def to_response(self, entity: User): + return UserResponse(id=entity.id, name=entity.name, age=entity.age) + + db.create_all() + return UserDao(db) + 
+ +def test_create_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req): + user_dao.create(user_req) + with db.session() as session: + user = session.query(User).first() + assert user.name == user_req.name + assert user.age == user_req.age + assert user.password == user_req.password + + +def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req): + # Create a user + created_user_response = user_dao.create(user_req) + + # Update the user + updated_req = UserRequest(name=user_req.name, age=35, password="newpassword") + updated_user = user_dao.update( + query_request={"name": user_req.name}, update_request=updated_req + ) + assert updated_user.id == created_user_response.id + assert updated_user.age == 35 + + # Verify that the user is updated in the database + with db.session() as session: + user = session.query(User).get(created_user_response.id) + assert user.age == 35 + + +def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req): + # Create a user + created_user_response = user_dao.create(user_req) + + # Query the user + fetched_user = user_dao.get_one({"name": user_req.name}) + assert fetched_user.id == created_user_response.id + assert fetched_user.name == user_req.name + assert fetched_user.age == user_req.age + + +def test_get_list_user(db: DatabaseManager, User: Type[BaseModel], user_dao): + for i in range(20): + user_dao.create( + UserRequest( + name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg" + ) + ) + # Query the user + fetched_user = user_dao.get_list({"password": "123456"}) + assert len(fetched_user) == 10 + + +def test_get_list_page_user(db: DatabaseManager, User: Type[BaseModel], user_dao): + for i in range(20): + user_dao.create( + UserRequest( + name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg" + ) + ) + page_result: PaginationResult = user_dao.get_list_page( + {"password": "123456"}, page=1, page_size=3 + ) + assert page_result.total_count == 10 + assert page_result.total_pages == 4 + assert len(page_result.items) == 3 + assert page_result.items[0].name == "User 0" + + # Test query next page + page_result: PaginationResult = user_dao.get_list_page( + {"password": "123456"}, page=2, page_size=3 + ) + assert page_result.total_count == 10 + assert page_result.total_pages == 4 + assert len(page_result.items) == 3 + assert page_result.items[0].name == "User 6" diff --git a/dbgpt/util/__init__.py b/dbgpt/util/__init__.py index 83798b7ad..74945e5ae 100644 --- a/dbgpt/util/__init__.py +++ b/dbgpt/util/__init__.py @@ -1,5 +1,17 @@ from .utils import ( get_gpu_memory, - server_error_msg, get_or_create_event_loop, ) +from .pagination_utils import PaginationResult +from .parameter_utils import BaseParameters, ParameterDescription, EnvArgumentParser +from .config_utils import AppConfig + +__ALL__ = [ + "get_gpu_memory", + "get_or_create_event_loop", + "PaginationResult", + "BaseParameters", + "ParameterDescription", + "EnvArgumentParser", + "AppConfig", +] diff --git a/dbgpt/util/_db_migration_utils.py b/dbgpt/util/_db_migration_utils.py index 2d0212467..13734960d 100644 --- a/dbgpt/util/_db_migration_utils.py +++ b/dbgpt/util/_db_migration_utils.py @@ -51,19 +51,50 @@ def create_alembic_config( def create_migration_script( - alembic_cfg: AlembicConfig, engine: Engine, message: str = "New migration" -) -> None: + alembic_cfg: AlembicConfig, + engine: Engine, + message: str = "New migration", + create_new_revision_if_noting_to_update: Optional[bool] = True, +) -> str: """Create 
migration script. Args: alembic_cfg: alembic config engine: sqlalchemy engine message: migration message - + create_new_revision_if_noting_to_update: Whether to create a new revision even if there is + nothing to update. Pass False to skip generating an empty revision. Defaults to True. + Returns: + The path of the generated migration script, or None if no new revision was created. """ + from alembic.script import ScriptDirectory + from alembic.runtime.migration import MigrationContext + + # Check if the database is up-to-date + script_dir = ScriptDirectory.from_config(alembic_cfg) with engine.connect() as connection: - alembic_cfg.attributes["connection"] = connection - command.revision(alembic_cfg, message, autogenerate=True) + context = MigrationContext.configure(connection=connection) + current_rev = context.get_current_revision() + head_rev = script_dir.get_current_head() + + logger.info( + f"alembic migration current revision: {current_rev}, latest revision: {head_rev}" + ) + should_create_revision = ( + (current_rev is None and head_rev is None) + or current_rev != head_rev + or create_new_revision_if_noting_to_update + ) + if should_create_revision: + with engine.connect() as connection: + alembic_cfg.attributes["connection"] = connection + revision = command.revision(alembic_cfg, message=message, autogenerate=True) + # Return the path of the generated migration script + return revision.path + elif current_rev == head_rev: + logger.info("No migration script to generate, database is up-to-date") + # If no new revision is created, return None + return None def upgrade_database( @@ -82,6 +113,37 @@ def upgrade_database( command.upgrade(alembic_cfg, target_version) + +def generate_sql_for_upgrade( + alembic_cfg: AlembicConfig, + engine: Engine, + target_version: Optional[str] = "head", + output_file: Optional[str] = "migration.sql", +) -> None: + """Generate SQL for upgrading the database to the target version. + + Args: + alembic_cfg: alembic config + engine: sqlalchemy engine + target_version: target version, default is head (latest version) + output_file: file to write the SQL script + + TODO: Can't generate SQL for most of the operations. + """ + import contextlib + import io + + with engine.connect() as connection, contextlib.redirect_stdout( + io.StringIO() + ) as stdout: + alembic_cfg.attributes["connection"] = connection + # Generating SQL instead of applying changes + command.upgrade(alembic_cfg, target_version, sql=True) + + # Write the generated SQL to a file + with open(output_file, "w", encoding="utf-8") as file: + file.write(stdout.getvalue()) + + def downgrade_database( alembic_cfg: AlembicConfig, engine: Engine, revision: str = "-1" ): @@ -160,9 +222,94 @@ def clean_alembic_migration(alembic_cfg: AlembicConfig, engine: Engine) -> None: rm -rf pilot/meta_data/alembic/versions/* rm -rf pilot/meta_data/alembic/dbgpt.db ``` + +If your database is shared and you run DB-GPT in multiple instances, +you should make sure that the migration scripts are the same in all instances. +In this case, we strongly recommend disabling the migration feature by setting `--disable_alembic_upgrade`, +and using the `dbgpt db migration` command to manage migration scripts. """ +def _check_database_migration_status(alembic_cfg: AlembicConfig, engine: Engine): + """Check if the database is at the latest migration revision.
+ + If your database is shared and you run DB-GPT in multiple instances, + you should make sure that the migration scripts are the same in all instances. + In this case, we strongly recommend disabling the migration feature by setting `disable_alembic_upgrade` to True, + and using the `dbgpt db migration` command to manage migration scripts. + + Args: + alembic_cfg: Alembic configuration object. + engine: SQLAlchemy engine instance. + Raises: + Exception: If the database is not at the latest revision. + """ + from alembic.script import ScriptDirectory + from alembic.runtime.migration import MigrationContext + + script = ScriptDirectory.from_config(alembic_cfg) + + def get_current_revision(engine): + with engine.connect() as connection: + context = MigrationContext.configure(connection=connection) + return context.get_current_revision() + + current_rev = get_current_revision(engine) + head_rev = script.get_current_head() + + script_info_msg = "Migration versions and their file paths:" + script_info_msg += f"\n{'='*40}Migration versions{'='*40}\n" + for revision in script.walk_revisions(base="base"): + current_marker = "(current)" if revision.revision == current_rev else "" + script_path = script.get_revision(revision.revision).path + script_info_msg += f"\n{revision.revision} {current_marker}: {revision.doc} (Path: {script_path})" + script_info_msg += f"\n{'='*90}" + + logger.info(script_info_msg) + + if current_rev != head_rev: + logger.error( + "Database is not at the latest revision. " + f"Current revision: {current_rev}, latest revision: {head_rev}\n" + "Please apply existing migration scripts before generating new ones. " + "Check the listed file paths for migration scripts.\n" + f"You can also try the following solutions:\n{_MIGRATION_SOLUTION}\n" + ) + raise Exception( + "Check database migration status failed, you can see the error and solutions above" + ) + + +def _get_latest_revision(alembic_cfg: AlembicConfig, engine: Engine) -> str: + """Get the current migration revision of the database. + + Args: + alembic_cfg: Alembic configuration object. + engine: SQLAlchemy engine instance. + + Returns: + The current revision as a string. + """ + from alembic.runtime.migration import MigrationContext + + with engine.connect() as connection: + context = MigrationContext.configure(connection=connection) + return context.get_current_revision() + + +def _delete_migration_script(script_path: str): + """Delete a migration script. + + Args: + script_path: The path of the migration script to delete. + """ + if os.path.exists(script_path): + os.remove(script_path) + logger.info(f"Deleted migration script at: {script_path}") + else: + logger.warning(f"Migration script not found at: {script_path}") + + def _ddl_init_and_upgrade( default_meta_data_path: str, disable_alembic_upgrade: bool, @@ -203,7 +350,19 @@ def _ddl_init_and_upgrade( script_location, ) try: - create_migration_script(alembic_cfg, db.engine) + _check_database_migration_status(alembic_cfg, db.engine) + except Exception as e: + logger.error(f"Failed to check database migration status: {e}") + raise + latest_revision_before = "__latest_revision_before__" + new_script_path = None + try: + latest_revision_before = _get_latest_revision(alembic_cfg, db.engine) + # create_new_revision_if_noting_to_update=False avoids creating a lot of empty migration scripts + # TODO: Setting create_new_revision_if_noting_to_update=False does not work yet.
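+ # Flow (summarized): record the current revision, autogenerate a migration
+ # script, then upgrade the database; if the upgrade fails after a new script
+ # was generated, the error handler below logs the path of that script so it
+ # can be reviewed manually.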
+ new_script_path = create_migration_script( + alembic_cfg, db.engine, create_new_revision_if_noting_to_update=True + ) upgrade_database(alembic_cfg, db.engine) except CommandError as e: if "Target database is not up to date" in str(e): @@ -216,4 +375,10 @@ def _ddl_init_and_upgrade( "you can see the error and solutions above" ) from e else: + latest_revision_after = _get_latest_revision(alembic_cfg, db.engine) + if latest_revision_before != latest_revision_after: + logger.error( + f"Upgrade database failed. Please review the migration script manually. " + f"Failed script path: {new_script_path}\nError: {e}" + ) raise e diff --git a/dbgpt/util/config_utils.py b/dbgpt/util/config_utils.py new file mode 100644 index 000000000..5d07561ee --- /dev/null +++ b/dbgpt/util/config_utils.py @@ -0,0 +1,32 @@ +from functools import cache +from typing import Any, Dict, Optional + + +class AppConfig: + def __init__(self): + self.configs = {} + + def set(self, key: str, value: Any) -> None: + """Set config value by key + Args: + key (str): The key of config + value (Any): The value of config + """ + self.configs[key] = value + + def get(self, key, default: Optional[Any] = None) -> Any: + """Get config value by key + + Args: + key (str): The key of config + default (Optional[Any], optional): The default value if key not found. Defaults to None. + """ + return self.configs.get(key, default) + + @cache + def get_all_by_prefix(self, prefix) -> Dict[str, Any]: + """Get all config values by prefix + Args: + prefix (str): The prefix of config + """ + return {k: v for k, v in self.configs.items() if k.startswith(prefix)} diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py index 23a78120f..0d39bc4b6 100644 --- a/dbgpt/util/utils.py +++ b/dbgpt/util/utils.py @@ -6,7 +6,6 @@ from typing import Any, List import os -import sys import asyncio from dbgpt.configs.model_config import LOGDIR @@ -81,17 +80,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None): setup_logging_level(logging_level=logging_level) logging.getLogger().handlers[0].setFormatter(formatter) - # Redirect stdout and stderr to loggers - # stdout_logger = logging.getLogger("stdout") - # stdout_logger.setLevel(logging.INFO) - # sl_1 = StreamToLogger(stdout_logger, logging.INFO) - # sys.stdout = sl_1 - # - # stderr_logger = logging.getLogger("stderr") - # stderr_logger.setLevel(logging.ERROR) - # sl = StreamToLogger(stderr_logger, logging.ERROR) - # sys.stderr = sl - # Add a file handler for all loggers if handler is None and logger_filename: os.makedirs(LOGDIR, exist_ok=True) diff --git a/pilot/meta_data/alembic/env.py b/pilot/meta_data/alembic/env.py index 5ed4384a9..1af5180f4 100644 --- a/pilot/meta_data/alembic/env.py +++ b/pilot/meta_data/alembic/env.py @@ -33,9 +33,10 @@ def run_migrations_offline() -> None: script output. """ - engine = db.engine target_metadata = db.metadata - url = config.get_main_option(engine.url) + url = config.get_main_option("sqlalchemy.url") + assert target_metadata is not None + assert url is not None context.configure( url=url, target_metadata=target_metadata,
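
For reference, a quick sketch of how the new `AppConfig` behaves (illustrative keys; only `set`, `get`, and `get_all_by_prefix` come from this patch):

```python
from dbgpt.util.config_utils import AppConfig

config = AppConfig()
# Hypothetical keys, following the "dbgpt.serve.{app}." prefix convention
config.set("dbgpt.serve.prompt.default_user", "admin")
config.set("dbgpt.serve.prompt.default_sys_code", "dbgpt")

# Serve apps resolve their settings by prefix, as ServeConfig.from_app_config does
values = config.get_all_by_prefix("dbgpt.serve.prompt.")
print(values)  # {'dbgpt.serve.prompt.default_user': 'admin', ...}

# Caveat: get_all_by_prefix is wrapped in functools.cache, so keys set after the
# first lookup for a given prefix will not appear in subsequent lookups.
config.set("dbgpt.serve.prompt.late_key", "1")
print(config.get_all_by_prefix("dbgpt.serve.prompt."))  # still the cached result
```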