fix(model): Fix remote embedding model error in some cases
fangyinc committed Sep 14, 2023
1 parent f304f97 commit 8da8148
Showing 7 changed files with 105 additions and 26 deletions.
1 change: 1 addition & 0 deletions pilot/componet.py
@@ -152,6 +152,7 @@ def _build(self):
@self.app.on_event("startup")
async def startup_event():
"""ASGI app startup event handler."""
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(self.async_after_start())
self.after_start()

14 changes: 10 additions & 4 deletions pilot/embedding_engine/embedding_factory.py
@@ -18,9 +18,11 @@ def create(


class DefaultEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None:
def __init__(
self, system_app=None, default_model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = model_name
self._default_model_name = default_model_name
self.kwargs = kwargs

def init_app(self, system_app):
@@ -31,9 +33,13 @@ def create(
) -> "Embeddings":
if not model_name:
model_name = self._default_model_name

new_kwargs = {k: v for k, v in self.kwargs.items()}
new_kwargs["model_name"] = model_name

if embedding_cls:
return embedding_cls(model_name=model_name, **self.kwargs)
return embedding_cls(**new_kwargs)
else:
from langchain.embeddings import HuggingFaceEmbeddings

return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs)
return HuggingFaceEmbeddings(**new_kwargs)
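
For context, a minimal usage sketch of the updated factory (not part of the commit): the model path is a placeholder, and it assumes `create()` can be called without arguments since a missing model_name falls back to `self._default_model_name`. With no `embedding_cls`, the call ends in langchain's `HuggingFaceEmbeddings`, which now receives the model name through `new_kwargs`.

from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

# Placeholder model path used purely for illustration
factory = DefaultEmbeddingFactory(default_model_name="/models/text2vec-large-chinese")
embeddings = factory.create()  # falls back to default_model_name
vectors = embeddings.embed_documents(["hello", "world"])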
40 changes: 27 additions & 13 deletions pilot/model/cluster/worker/embedding_worker.py
@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Type
from typing import Dict, List, Type, Optional

from pilot.configs.model_config import get_device
from pilot.model.loader import _get_model_real_path
@@ -45,21 +45,12 @@ def parse_parameters(
self, command_args: List[str] = None
) -> EmbeddingModelParameters:
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,
return _parse_embedding_params(
model_name=self.model_name,
model_path=self.model_path,
command_args=command_args,
param_cls=param_cls,
)
if not model_params.device:
model_params.device = get_device()
logger.info(
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
)
return model_params

def start(
self,
@@ -100,3 +91,26 @@ def embeddings(self, params: Dict) -> List[List[float]]:
logger.info(f"Receive embeddings request, model: {model}")
input: List[str] = params["input"]
return self._embeddings_impl.embed_documents(input)


def _parse_embedding_params(
model_name: str,
model_path: str,
command_args: List[str] = None,
param_cls: Optional[Type] = EmbeddingModelParameters,
):
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,
model_name=model_name,
model_path=model_path,
)
if not model_params.device:
model_params.device = get_device()
logger.info(
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
)
return model_params
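
A hedged sketch of calling the extracted helper directly; the model name and path are placeholders, and the imports mirror the ones used in componet_configs.py further below.

from pilot.model.parameter import EmbeddingModelParameters
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params

# Placeholder name/path; individual fields can also come from environment
# variables under the prefix derived from the model name.
model_params = _parse_embedding_params(
    model_name="text2vec",
    model_path="/models/text2vec-large-chinese",
    param_cls=EmbeddingModelParameters,
)
print(model_params.device)  # falls back to get_device() when no device is configured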
1 change: 1 addition & 0 deletions pilot/model/cluster/worker/manager.py
@@ -635,6 +635,7 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):

@app.on_event("startup")
async def startup_event():
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(worker_manager.start())

@app.on_event("shutdown")
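
As an aside, a hedged sketch of what the TODO above could grow into — not part of the commit, and the shutdown mechanism (`os._exit`) is purely illustrative; `app` and `worker_manager` come from the surrounding `_setup_fastapi` scope.

import asyncio
import logging
import os

logger = logging.getLogger(__name__)

@app.on_event("startup")
async def startup_event():
    async def _start_and_watch():
        try:
            await worker_manager.start()
        except Exception:
            # Illustrative only: log the failure and stop the process instead
            # of keeping the server alive with a dead worker manager.
            logger.exception("Worker manager failed to start")
            os._exit(1)

    asyncio.create_task(_start_and_watch())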
6 changes: 6 additions & 0 deletions pilot/server/base.py
@@ -102,6 +102,12 @@ class WebWerverParameters(BaseParameters):
"help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
},
)
remote_embedding: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`"
},
)
log_level: Optional[str] = field(
default="INFO",
metadata={
56 changes: 50 additions & 6 deletions pilot/server/componet_configs.py
@@ -1,21 +1,65 @@
from typing import Any, Type, TYPE_CHECKING

from pilot.componet import SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
import logging
from pilot.configs.model_config import get_device
from pilot.embedding_engine.embedding_factory import (
EmbeddingFactory,
DefaultEmbeddingFactory,
)
from pilot.server.base import WebWerverParameters
from pilot.utils.parameter_utils import EnvArgumentParser

if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings


def initialize_componets(system_app: SystemApp, embedding_model_name: str):
from pilot.model.cluster import worker_manager
logger = logging.getLogger(__name__)


def initialize_componets(
param: WebWerverParameters,
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
):
from pilot.model.cluster.controller.controller import controller

system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
system_app.register_instance(controller)

_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)


def _initialize_embedding_model(
param: WebWerverParameters,
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
):
from pilot.model.cluster import worker_manager

if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory")
system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
else:
from pilot.model.parameter import EmbeddingModelParameters
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params

model_params: EmbeddingModelParameters = _parse_embedding_params(
model_name=embedding_model_name,
model_path=embedding_model_path,
param_cls=EmbeddingModelParameters,
)
kwargs = model_params.build_kwargs(model_name=embedding_model_path)
logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
system_app.register(
DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
)


class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(
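
The RemoteEmbeddingFactory body is truncated above. Purely as a hedged illustration of the concept (not the commit's actual implementation), a remote embedding adapter boils down to forwarding `embed_documents` to a worker that exposes the `embeddings(params)` API shown in embedding_worker.py; the `worker_manager.embeddings` call below is an assumption.

from typing import List

from langchain.embeddings.base import Embeddings


class RemoteEmbeddingsSketch(Embeddings):
    """Illustrative adapter only; not the project's actual class."""

    def __init__(self, model_name: str, worker_manager) -> None:
        self._model_name = model_name
        self._worker_manager = worker_manager

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Mirrors the params dict accepted by EmbeddingsModelWorker.embeddings
        params = {"model": self._model_name, "input": texts}
        return self._worker_manager.embeddings(params)  # assumed API

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]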
13 changes: 10 additions & 3 deletions pilot/server/dbgpt_server.py
@@ -109,20 +109,27 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
# Before start
system_app.before_start()

print(param)

embedding_model_name = CFG.EMBEDDING_MODEL
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]

server_init(param, system_app)
model_start_listener = _create_model_start_listener(system_app)
initialize_componets(system_app, CFG.EMBEDDING_MODEL)
initialize_componets(param, system_app, embedding_model_name, embedding_model_path)

model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
embedding_model_name, embedding_model_path = None, None
initialize_worker_manager_in_client(
app=app,
model_name=CFG.LLM_MODEL,
model_path=model_path,
local_port=param.port,
embedding_model_name=CFG.EMBEDDING_MODEL,
embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
embedding_model_name=embedding_model_name,
embedding_model_path=embedding_model_path,
start_listener=model_start_listener,
)

