diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py index 234b44726..8e65ed8e2 100644 --- a/dbgpt/model/cluster/worker/manager.py +++ b/dbgpt/model/cluster/worker/manager.py @@ -217,7 +217,6 @@ async def model_startup(self, startup_req: WorkerStartupRequest): ) if not worker_params.model_name: worker_params.model_name = model_name - assert model_name == worker_params.model_name worker = _build_worker(worker_params) command_args = _dict_to_command_args(params) success = await self.run_blocking_func( @@ -235,7 +234,9 @@ async def model_startup(self, startup_req: WorkerStartupRequest): f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}" ) start_apply_req = WorkerApplyRequest( - model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type + model=worker_params.model_name, + apply_type=WorkerApplyType.START, + worker_type=worker_type, ) out: WorkerApplyOutput = None try: @@ -895,6 +896,8 @@ def _parse_worker_params( **kwargs, ) worker_params.update_from(new_worker_params) + if worker_params.model_alias: + worker_params.model_name = worker_params.model_alias # logger.info(f"Worker params: {worker_params}") return worker_params diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py index 3dc68afcc..7af18a0a9 100644 --- a/dbgpt/model/parameter.py +++ b/dbgpt/model/parameter.py @@ -164,6 +164,10 @@ class ModelWorkerParameters(BaseModelParameters): default=None, metadata={"valid_values": WorkerType.values(), "help": "Worker type"}, ) + model_alias: Optional[str] = field( + default=None, + metadata={"help": "model alias"}, + ) worker_class: Optional[str] = field( default=None, metadata={"help": "Model worker class, dbgpt.model.cluster.DefaultModelWorker"},