From 8da81486f5e9f873dbd5dc903f8006ef2be7e94c Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Thu, 14 Sep 2023 20:11:41 +0800 Subject: [PATCH] fix(model): Fix remote embedding model error in some case --- pilot/componet.py | 1 + pilot/embedding_engine/embedding_factory.py | 14 +++-- .../model/cluster/worker/embedding_worker.py | 40 ++++++++----- pilot/model/cluster/worker/manager.py | 1 + pilot/server/base.py | 6 ++ pilot/server/componet_configs.py | 56 +++++++++++++++++-- pilot/server/dbgpt_server.py | 13 ++++- 7 files changed, 105 insertions(+), 26 deletions(-) diff --git a/pilot/componet.py b/pilot/componet.py index 2c3980cfc..0897b3365 100644 --- a/pilot/componet.py +++ b/pilot/componet.py @@ -152,6 +152,7 @@ def _build(self): @self.app.on_event("startup") async def startup_event(): """ASGI app startup event handler.""" + # TODO catch exception and shutdown if worker manager start failed asyncio.create_task(self.async_after_start()) self.after_start() diff --git a/pilot/embedding_engine/embedding_factory.py b/pilot/embedding_engine/embedding_factory.py index e7345d952..b9d17ad83 100644 --- a/pilot/embedding_engine/embedding_factory.py +++ b/pilot/embedding_engine/embedding_factory.py @@ -18,9 +18,11 @@ def create( class DefaultEmbeddingFactory(EmbeddingFactory): - def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None: + def __init__( + self, system_app=None, default_model_name: str = None, **kwargs: Any + ) -> None: super().__init__(system_app=system_app) - self._default_model_name = model_name + self._default_model_name = default_model_name self.kwargs = kwargs def init_app(self, system_app): @@ -31,9 +33,13 @@ def create( ) -> "Embeddings": if not model_name: model_name = self._default_model_name + + new_kwargs = {k: v for k, v in self.kwargs.items()} + new_kwargs["model_name"] = model_name + if embedding_cls: - return embedding_cls(model_name=model_name, **self.kwargs) + return embedding_cls(**new_kwargs) else: from 
langchain.embeddings import HuggingFaceEmbeddings - return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs) + return HuggingFaceEmbeddings(**new_kwargs) diff --git a/pilot/model/cluster/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py index 80f06b145..a0bfd66bc 100644 --- a/pilot/model/cluster/worker/embedding_worker.py +++ b/pilot/model/cluster/worker/embedding_worker.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Type +from typing import Dict, List, Type, Optional from pilot.configs.model_config import get_device from pilot.model.loader import _get_model_real_path @@ -45,21 +45,12 @@ def parse_parameters( self, command_args: List[str] = None ) -> EmbeddingModelParameters: param_cls = self.model_param_class() - model_args = EnvArgumentParser() - env_prefix = EnvArgumentParser.get_env_prefix(self.model_name) - model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass( - param_cls, - env_prefix=env_prefix, - command_args=command_args, + return _parse_embedding_params( model_name=self.model_name, model_path=self.model_path, + command_args=command_args, + param_cls=param_cls, ) - if not model_params.device: - model_params.device = get_device() - logger.info( - f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}" - ) - return model_params def start( self, @@ -100,3 +91,26 @@ def embeddings(self, params: Dict) -> List[List[float]]: logger.info(f"Receive embeddings request, model: {model}") input: List[str] = params["input"] return self._embeddings_impl.embed_documents(input) + + +def _parse_embedding_params( + model_name: str, + model_path: str, + command_args: List[str] = None, + param_cls: Optional[Type] = EmbeddingModelParameters, +): + model_args = EnvArgumentParser() + env_prefix = EnvArgumentParser.get_env_prefix(model_name) + model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass( + param_cls, + env_prefix=env_prefix, + 
command_args=command_args, + model_name=model_name, + model_path=model_path, + ) + if not model_params.device: + model_params.device = get_device() + logger.info( + f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}" + ) + return model_params diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index 93f2a373a..4791d8caa 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -635,6 +635,7 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None): @app.on_event("startup") async def startup_event(): + # TODO catch exception and shutdown if worker manager start failed asyncio.create_task(worker_manager.start()) @app.on_event("shutdown") diff --git a/pilot/server/base.py b/pilot/server/base.py index 59116e09c..888ebbf3d 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -102,6 +102,12 @@ class WebWerverParameters(BaseParameters): "help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. " }, ) + remote_embedding: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to enable remote embedding models. 
If it is True, you need to start an embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`" }, ) log_level: Optional[str] = field( default="INFO", metadata={ "help": diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py index b68b052fa..755f13b21 100644 --- a/pilot/server/componet_configs.py +++ b/pilot/server/componet_configs.py @@ -1,21 +1,65 @@ from typing import Any, Type, TYPE_CHECKING from pilot.componet import SystemApp -from pilot.embedding_engine.embedding_factory import EmbeddingFactory +import logging +from pilot.configs.model_config import get_device +from pilot.embedding_engine.embedding_factory import ( + EmbeddingFactory, + DefaultEmbeddingFactory, +) +from pilot.server.base import WebWerverParameters +from pilot.utils.parameter_utils import EnvArgumentParser if TYPE_CHECKING: from langchain.embeddings.base import Embeddings -def initialize_componets(system_app: SystemApp, embedding_model_name: str): - from pilot.model.cluster import worker_manager +logger = logging.getLogger(__name__) + + +def initialize_componets( + param: WebWerverParameters, + system_app: SystemApp, + embedding_model_name: str, + embedding_model_path: str, +): from pilot.model.cluster.controller.controller import controller - system_app.register( - RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name - ) system_app.register_instance(controller) + _initialize_embedding_model( + param, system_app, embedding_model_name, embedding_model_path + ) + + +def _initialize_embedding_model( + param: WebWerverParameters, + system_app: SystemApp, + embedding_model_name: str, + embedding_model_path: str, +): + from pilot.model.cluster import worker_manager + + if param.remote_embedding: + logger.info("Register remote RemoteEmbeddingFactory") + system_app.register( + RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name + ) + else: + from pilot.model.parameter import EmbeddingModelParameters 
+ from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params + + model_params: EmbeddingModelParameters = _parse_embedding_params( + model_name=embedding_model_name, + model_path=embedding_model_path, + param_cls=EmbeddingModelParameters, + ) + kwargs = model_params.build_kwargs(model_name=embedding_model_path) + logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}") + system_app.register( + DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs + ) + class RemoteEmbeddingFactory(EmbeddingFactory): def __init__( diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 3019f067f..d2307f06a 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -109,20 +109,27 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): # Before start system_app.before_start() + print(param) + + embedding_model_name = CFG.EMBEDDING_MODEL + embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] + server_init(param, system_app) model_start_listener = _create_model_start_listener(system_app) - initialize_componets(system_app, CFG.EMBEDDING_MODEL) + initialize_componets(param, system_app, embedding_model_name, embedding_model_path) model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] if not param.light: print("Model Unified Deployment Mode!") + if not param.remote_embedding: + embedding_model_name, embedding_model_path = None, None initialize_worker_manager_in_client( app=app, model_name=CFG.LLM_MODEL, model_path=model_path, local_port=param.port, - embedding_model_name=CFG.EMBEDDING_MODEL, - embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + embedding_model_name=embedding_model_name, + embedding_model_path=embedding_model_path, start_listener=model_start_listener, )