Skip to content

Commit

Permalink
feat(model): Support database model registry (#1656)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Jun 24, 2024
1 parent c57ee02 commit 47d205f
Show file tree
Hide file tree
Showing 35 changed files with 2,014 additions and 792 deletions.
23 changes: 23 additions & 0 deletions assets/schema/dbgpt.sql
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ CREATE TABLE `dbgpt_serve_flow` (
`source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
`source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
`version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
`define_type` varchar(32) null comment 'Flow define type(json or python)',
`label` varchar(128) DEFAULT NULL COMMENT 'Flow label',
`editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable',
PRIMARY KEY (`id`),
Expand Down Expand Up @@ -340,6 +341,28 @@ CREATE TABLE `gpts_app_detail` (
UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;


-- Model instance registry used when deploying a DB-GPT model cluster
-- (StorageModelRegistry). One row per registered model instance.
-- NOTE(review): `sys_code` is nullable and is part of `uk_model_instance`; MySQL
-- unique indexes accept multiple NULLs, so rows whose `sys_code` is NULL are
-- presumably not deduplicated by this key -- confirm that this is intended.
CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` (
    `id`              INT(11)      NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id',
    `model_name`      VARCHAR(128) NOT NULL COMMENT 'Model name',
    `host`            VARCHAR(128) NOT NULL COMMENT 'Host of the model',
    `port`            INT(11)      NOT NULL COMMENT 'Port of the model',
    `weight`          FLOAT        DEFAULT 1.0 COMMENT 'Weight of the model',
    `check_healthy`   TINYINT(1)   DEFAULT 1 COMMENT 'Whether to check the health of the model',
    `healthy`         TINYINT(1)   DEFAULT 0 COMMENT 'Whether the model is healthy',
    `enabled`         TINYINT(1)   DEFAULT 1 COMMENT 'Whether the model is enabled',
    `prompt_template` VARCHAR(128) DEFAULT NULL COMMENT 'Prompt template for the model instance',
    `last_heartbeat`  DATETIME     DEFAULT NULL COMMENT 'Last heartbeat time of the model instance',
    `user_name`       VARCHAR(128) DEFAULT NULL COMMENT 'User name',
    `sys_code`        VARCHAR(128) DEFAULT NULL COMMENT 'System code',
    `gmt_created`     DATETIME     DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',
    `gmt_modified`    DATETIME     DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time',
    PRIMARY KEY (`id`),
    -- An instance is identified by model name, endpoint (host + port) and system code.
    UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`)
)
    ENGINE = InnoDB
    AUTO_INCREMENT = 1
    DEFAULT CHARSET = utf8mb4
    COMMENT = 'Cluster model instance table, for registering and managing model instances';


-- Create the EXAMPLE_1 database when it is missing, then make it the current database.
CREATE DATABASE IF NOT EXISTS EXAMPLE_1;
USE EXAMPLE_1;
Expand Down
File renamed without changes.
File renamed without changes.
22 changes: 22 additions & 0 deletions assets/schema/upgrade/v0_5_9/upgrade_to_v0.5.9.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- Upgrade script: adds the model instance registry table to the `dbgpt` database.
USE dbgpt;

-- Model instance registry used when deploying a DB-GPT model cluster
-- (StorageModelRegistry). One row per registered model instance.
-- NOTE(review): `sys_code` is nullable and is part of `uk_model_instance`; MySQL
-- unique indexes accept multiple NULLs, so rows whose `sys_code` is NULL are
-- presumably not deduplicated by this key -- confirm that this is intended.
CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` (
    `id`              INT(11)      NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id',
    `model_name`      VARCHAR(128) NOT NULL COMMENT 'Model name',
    `host`            VARCHAR(128) NOT NULL COMMENT 'Host of the model',
    `port`            INT(11)      NOT NULL COMMENT 'Port of the model',
    `weight`          FLOAT        DEFAULT 1.0 COMMENT 'Weight of the model',
    `check_healthy`   TINYINT(1)   DEFAULT 1 COMMENT 'Whether to check the health of the model',
    `healthy`         TINYINT(1)   DEFAULT 0 COMMENT 'Whether the model is healthy',
    `enabled`         TINYINT(1)   DEFAULT 1 COMMENT 'Whether the model is enabled',
    `prompt_template` VARCHAR(128) DEFAULT NULL COMMENT 'Prompt template for the model instance',
    `last_heartbeat`  DATETIME     DEFAULT NULL COMMENT 'Last heartbeat time of the model instance',
    `user_name`       VARCHAR(128) DEFAULT NULL COMMENT 'User name',
    `sys_code`        VARCHAR(128) DEFAULT NULL COMMENT 'System code',
    `gmt_created`     DATETIME     DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',
    `gmt_modified`    DATETIME     DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time',
    PRIMARY KEY (`id`),
    -- An instance is identified by model name, endpoint (host + port) and system code.
    UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`)
)
    ENGINE = InnoDB
    AUTO_INCREMENT = 1
    DEFAULT CHARSET = utf8mb4
    COMMENT = 'Cluster model instance table, for registering and managing model instances';

396 changes: 396 additions & 0 deletions assets/schema/upgrade/v0_5_9/v0.5.8.sql

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions dbgpt/app/initialization/db_model_initialization.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Import all models to make sure they are registered with SQLAlchemy.
"""

from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt.model.cluster.registry_impl.db_storage import ModelInstanceEntity
from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity
from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity
from dbgpt.storage.chat_history.chat_history_db import (
Expand All @@ -24,4 +27,6 @@
ConnectConfigEntity,
ChatHistoryEntity,
ChatHistoryMessageEntity,
ModelInstanceEntity,
FlowServeEntity,
]
1 change: 1 addition & 0 deletions dbgpt/model/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
build_lazy_click_command,
)

# You can set the environment variable CONTROLLER_ADDRESS to set the default address
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"

logger = logging.getLogger("dbgpt_cli")
Expand Down
30 changes: 14 additions & 16 deletions dbgpt/model/cluster/apiserver/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
import importlib.metadata as metadata

import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError
from httpx import ASGITransport, AsyncClient, HTTPError

from dbgpt.component import SystemApp
from dbgpt.model.cluster.apiserver.api import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ModelList,
UsageInfo,
api_settings,
initialize_apiserver,
)
Expand Down Expand Up @@ -56,12 +45,13 @@ async def client(request, system_app: SystemApp):
if api_settings:
# Clear global api keys
api_settings.api_keys = []
async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
async with AsyncClient(
transport=ASGITransport(app), base_url="http://test", headers=headers
) as client:
async with _new_cluster(**param) as cluster:
worker_manager, model_registry = cluster
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
system_app.register_instance(model_registry)
# print(f"Instances {model_registry.registry}")
initialize_apiserver(None, app, system_app, api_keys=api_keys)
yield client

Expand Down Expand Up @@ -113,7 +103,11 @@ async def test_chat_completions(client: AsyncClient, expected_messages):
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
"你好,我是张三。",
"abc",
),
],
indirect=["client"],
)
Expand Down Expand Up @@ -160,7 +154,11 @@ async def test_chat_completions_with_openai_lib_async_no_stream(
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
"你好,我是张三。",
"abc",
),
],
indirect=["client"],
)
Expand Down
131 changes: 121 additions & 10 deletions dbgpt/model/cluster/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from abc import ABC, abstractmethod
from typing import List
from typing import List, Literal, Optional

from fastapi import APIRouter

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from dbgpt.model.parameter import ModelControllerParameters
from dbgpt.util.api_utils import APIMixin
from dbgpt.util.api_utils import _api_remote as api_remote
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
from dbgpt.util.fastapi import create_app
Expand Down Expand Up @@ -46,9 +47,7 @@ async def model_apply(self) -> bool:


class LocalModelController(BaseModelController):
def __init__(self, registry: ModelRegistry = None) -> None:
if not registry:
registry = EmbeddedModelRegistry()
def __init__(self, registry: ModelRegistry) -> None:
self.registry = registry
self.deployment = None

Expand All @@ -75,9 +74,25 @@ async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.registry.send_heartbeat(instance)


class _RemoteModelController(BaseModelController):
def __init__(self, base_url: str) -> None:
self.base_url = base_url
class _RemoteModelController(APIMixin, BaseModelController):
def __init__(
self,
urls: str,
health_check_interval_secs: int = 5,
health_check_timeout_secs: int = 30,
check_health: bool = True,
choice_type: Literal["latest_first", "random"] = "latest_first",
) -> None:
APIMixin.__init__(
self,
urls=urls,
health_check_path="/api/health",
health_check_interval_secs=health_check_interval_secs,
health_check_timeout_secs=health_check_timeout_secs,
check_health=check_health,
choice_type=choice_type,
)
BaseModelController.__init__(self)

@api_remote(path="/api/controller/models", method="POST")
async def register_instance(self, instance: ModelInstance) -> bool:
Expand Down Expand Up @@ -139,13 +154,19 @@ async def model_apply(self) -> bool:


def initialize_controller(
app=None, remote_controller_addr: str = None, host: str = None, port: int = None
app=None,
remote_controller_addr: str = None,
host: str = None,
port: int = None,
registry: Optional[ModelRegistry] = None,
):
global controller
if remote_controller_addr:
controller.backend = _RemoteModelController(remote_controller_addr)
else:
controller.backend = LocalModelController()
if not registry:
registry = EmbeddedModelRegistry()
controller.backend = LocalModelController(registry=registry)

if app:
app.include_router(router, prefix="/api", tags=["Model"])
Expand All @@ -158,6 +179,12 @@ def initialize_controller(
uvicorn.run(app, host=host, port=port, log_level="info")


@router.get("/health")
async def api_health_check():
    """Health check API."""
    # The payload is constant: answering at all is the health signal.
    health_payload = {"status": "ok"}
    return health_payload


@router.post("/controller/models")
async def api_register_instance(request: ModelInstance):
    # Hand the posted instance to the module-level controller and return its
    # result unchanged.
    registered = await controller.register_instance(request)
    return registered
Expand All @@ -179,6 +206,87 @@ async def api_model_heartbeat(request: ModelInstance):
return await controller.send_heartbeat(request)


def _create_registry(controller_params: ModelControllerParameters) -> ModelRegistry:
"""Create a model registry based on the controller parameters.
Registry will store the metadata of all model instances, it will be a high
availability service for model instances if you use a database registry now. Also,
we can implement more registry types in the future.
"""
registry_type = controller_params.registry_type.strip()
if controller_params.registry_type == "embedded":
return EmbeddedModelRegistry(
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
)
elif controller_params.registry_type == "database":
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote

from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry

try_to_create_db = False

if controller_params.registry_db_type == "mysql":
db_name = controller_params.registry_db_name
db_host = controller_params.registry_db_host
db_port = controller_params.registry_db_port
db_user = controller_params.registry_db_user
db_password = controller_params.registry_db_password
if not db_name:
raise ValueError(
"Registry DB name is required when using MySQL registry."
)
if not db_host:
raise ValueError(
"Registry DB host is required when using MySQL registry."
)
if not db_port:
raise ValueError(
"Registry DB port is required when using MySQL registry."
)
if not db_user:
raise ValueError(
"Registry DB user is required when using MySQL registry."
)
if not db_password:
raise ValueError(
"Registry DB password is required when using MySQL registry."
)
db_url = (
f"mysql+pymysql://{quote(db_user)}:"
f"{urlquote(db_password)}@"
f"{db_host}:"
f"{str(db_port)}/"
f"{db_name}?charset=utf8mb4"
)
elif controller_params.registry_db_type == "sqlite":
db_name = controller_params.registry_db_name
if not db_name:
raise ValueError(
"Registry DB name is required when using SQLite registry."
)
db_url = f"sqlite:///{db_name}"
try_to_create_db = True
else:
raise ValueError(
f"Unsupported registry DB type: {controller_params.registry_db_type}"
)

registry = StorageModelRegistry.from_url(
db_url,
db_name,
pool_size=controller_params.registry_db_pool_size,
max_overflow=controller_params.registry_db_max_overflow,
try_to_create_db=try_to_create_db,
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
)
return registry
else:
raise ValueError(f"Unsupported registry type: {registry_type}")


def run_model_controller():
parser = EnvArgumentParser()
env_prefix = "controller_"
Expand All @@ -192,8 +300,11 @@ def run_model_controller():
logging_level=controller_params.log_level,
logger_filename=controller_params.log_file,
)
registry = _create_registry(controller_params)

initialize_controller(host=controller_params.host, port=controller_params.port)
initialize_controller(
host=controller_params.host, port=controller_params.port, registry=registry
)


if __name__ == "__main__":
Expand Down
File renamed without changes.
Loading

0 comments on commit 47d205f

Please sign in to comment.