diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 2afac5162..b4632003e 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -41,25 +41,12 @@ def get_device() -> str: # (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5 "vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"), "vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"), - "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"), - # https://huggingface.co/moka-ai/m3e-large - "m3e-base": os.path.join(MODEL_PATH, "m3e-base"), - # https://huggingface.co/moka-ai/m3e-base - "m3e-large": os.path.join(MODEL_PATH, "m3e-large"), - # https://huggingface.co/BAAI/bge-large-en - "bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"), - "bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"), - # https://huggingface.co/BAAI/bge-large-zh - "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"), - "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"), - "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), "codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"), "codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"), "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"), "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"), "chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"), "chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"), - "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"), "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"), @@ -84,6 +71,22 @@ def get_device() -> str: "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"), } +EMBEDDING_MODEL_CONFIG = { + "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"), + "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), + # https://huggingface.co/moka-ai/m3e-base + "m3e-base": os.path.join(MODEL_PATH, "m3e-base"), + # https://huggingface.co/moka-ai/m3e-large + "m3e-large": os.path.join(MODEL_PATH, "m3e-large"), + # https://huggingface.co/BAAI/bge-large-en + "bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"), + "bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"), + # https://huggingface.co/BAAI/bge-large-zh + "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"), + "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"), + "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), +} + # Load model config ISDEBUG = False diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 44d115820..5d3acf505 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -108,6 +108,17 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]: return [param_class] +def _parse_model_param_class(model_name: str, model_path: str) -> Type[ModelParameters]: + try: + llm_adapter = get_llm_model_adapter(model_name, model_path) + return llm_adapter.model_param_class() + except Exception as e: + logger.warn( + f"Parse model parameters with model name {model_name} and model path {model_path} failed: {str(e)}, returning `ModelParameters`" + ) + return ModelParameters + + # TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
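Note on the config split above: embedding model paths no longer live in `LLM_MODEL_CONFIG`, so any caller that used to resolve every model name through the single combined dict now has to consult `EMBEDDING_MODEL_CONFIG` as well. A minimal sketch of such a lookup (the `resolve_model_path` helper is hypothetical, not part of this patch):

```python
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG


def resolve_model_path(model_name: str) -> str:
    # Hypothetical helper: LLM and embedding entries now live in separate
    # registries, so a name lookup has to check both before failing.
    if model_name in LLM_MODEL_CONFIG:
        return LLM_MODEL_CONFIG[model_name]
    if model_name in EMBEDDING_MODEL_CONFIG:
        return EMBEDDING_MODEL_CONFIG[model_name]
    raise ValueError(f"Unknown model name: {model_name}")
```

This mirrors how the rest of the patch swaps `LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]` for `EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]` at each embedding call site.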
diff --git a/pilot/model/base.py b/pilot/model/base.py index c81279886..b6eb9da25 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -2,9 +2,10 @@ # -*- coding: utf-8 -*- from enum import Enum -from typing import TypedDict, Optional, Dict +from typing import TypedDict, Optional, Dict, List from dataclasses import dataclass from datetime import datetime +from pilot.utils.parameter_utils import ParameterDescription class Message(TypedDict): @@ -46,5 +47,40 @@ class ModelOutput: @dataclass class WorkerApplyOutput: message: str + success: Optional[bool] = True # The seconds cost to apply some action to worker instances timecost: Optional[int] = -1 + + +@dataclass +class SupportedModel: + model: str + path: str + worker_type: str + path_exist: bool + proxy: bool + enabled: bool + params: List[ParameterDescription] + + @classmethod + def from_dict(cls, model_data: Dict) -> "SupportedModel": + params = model_data.get("params", []) + if params: + params = [ParameterDescription(**param) for param in params] + model_data["params"] = params + return cls(**model_data) + + +@dataclass +class WorkerSupportedModel: + host: str + port: int + models: List[SupportedModel] + + @classmethod + def from_dict(cls, worker_data: Dict) -> "WorkerSupportedModel": + models = [ + SupportedModel.from_dict(model_data) for model_data in worker_data["models"] + ] + worker_data["models"] = models + return cls(**worker_data) diff --git a/pilot/model/cli.py b/pilot/model/cli.py index 4b981e504..2406f6920 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -2,19 +2,27 @@ import functools import logging import os -from typing import Callable, List, Type +from typing import Callable, List, Type, Optional -from pilot.model.controller.controller import ModelRegistryClient from pilot.configs.model_config import LOGDIR from pilot.model.base import WorkerApplyType from pilot.model.parameter import ( ModelControllerParameters, ModelWorkerParameters, ModelParameters, + BaseParameters, ) from pilot.utils import get_or_create_event_loop -from pilot.utils.parameter_utils import EnvArgumentParser -from pilot.utils.command_utils import _run_current_with_daemon, _stop_service +from pilot.utils.parameter_utils import ( + EnvArgumentParser, + _build_parameter_class, + build_lazy_click_command, +) +from pilot.utils.command_utils import ( + _run_current_with_daemon, + _stop_service, + _detect_controller_address, +) MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000" @@ -22,6 +30,14 @@ logger = logging.getLogger("dbgpt_cli") +def _get_worker_manager(address: str): + from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient + + registry = ModelRegistryClient(address) + worker_manager = RemoteWorkerManager(registry) + return worker_manager + + @click.group("model") @click.option( "--address", @@ -38,8 +54,6 @@ def model_cli_group(address: str): """Clients that manage model serving""" global MODEL_CONTROLLER_ADDRESS if not address: - from pilot.utils.command_utils import _detect_controller_address - MODEL_CONTROLLER_ADDRESS = _detect_controller_address() else: MODEL_CONTROLLER_ADDRESS = address @@ -55,6 +69,7 @@ def model_cli_group(address: str): def list(model_name: str, model_type: str): """List model instances""" from prettytable import PrettyTable + from pilot.model.cluster import ModelRegistryClient loop = get_or_create_event_loop() registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS) @@ -90,7 +105,7 @@ def list(model_name: str, model_type: str): instance.port, instance.healthy, instance.enabled, - 
instance.prompt_template, + instance.prompt_template if instance.prompt_template else "", instance.last_heartbeat, ] ) @@ -122,18 +137,157 @@ def wrapper(*args, **kwargs): @model_cli_group.command() @add_model_options -def stop(model_name: str, model_type: str): +@click.option( + "--host", + type=str, + required=True, + help=("The remote host of the model to stop"), +) +@click.option( + "--port", + type=int, + required=True, + help=("The remote port of the model to stop"), +) +def stop(model_name: str, model_type: str, host: str, port: int): """Stop model instances""" - worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP) + from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager + + worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS) + req = WorkerStartupRequest( + host=host, + port=port, + worker_type=model_type, + model=model_name, + params={}, + ) + loop = get_or_create_event_loop() + res = loop.run_until_complete(worker_manager.model_shutdown(req)) + print(res) -@model_cli_group.command() -@add_model_options -def start(model_name: str, model_type: str): +def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]: + from pilot.model.adapter import _dynamic_model_parser + from pilot.utils.parameter_utils import _SimpleArgParser + from pilot.model.cluster import RemoteWorkerManager + from pilot.model.parameter import WorkerType + from dataclasses import dataclass, field, fields, make_dataclass + + pre_args = _SimpleArgParser("model_name", "address", "host", "port") + pre_args.parse() + model_name = pre_args.get("model_name") + address = pre_args.get("address") + host = pre_args.get("host") + port = pre_args.get("port") + if port: + port = int(port) + + if not address: + address = _detect_controller_address() + + worker_manager: RemoteWorkerManager = _get_worker_manager(address) + loop = get_or_create_event_loop() + models = loop.run_until_complete(worker_manager.supported_models()) + + fields_dict = {} + fields_dict["model_name"] = ( + str, + field(default=None, metadata={"help": "The model name to deploy"}), + ) + fields_dict["host"] = ( + str, + field(default=None, metadata={"help": "The remote host to deploy model"}), + ) + fields_dict["port"] = ( + int, + field(default=None, metadata={"help": "The remote port to deploy model"}), + ) + result_class = make_dataclass( + "RemoteModelWorkerParameters", + [(name, typ, f) for name, (typ, f) in fields_dict.items()], + ) + + if not models: + return [result_class] + + valid_models = [] + valid_model_cls = [] + for model in models: + if host and host != model.host: + continue + if port and port != model.port: + continue + valid_models += [m.model for m in model.models] + valid_model_cls += [ + (m, _build_parameter_class(m.params)) for m in model.models if m.params + ] + real_model, real_params_cls = valid_model_cls[0] + real_path = None + real_worker_type = "llm" + if model_name: + params_cls_list = [m for m in valid_model_cls if m[0].model == model_name] + if not params_cls_list: + raise ValueError(f"Unsupported model name: {model_name}") + real_model, real_params_cls = params_cls_list[0] + real_path = real_model.path + real_worker_type = real_model.worker_type + + @dataclass + class RemoteModelWorkerParameters(BaseParameters): + model_name: str = field( + metadata={"valid_values": valid_models, "help": "The model name to deploy"} + ) + model_path: Optional[str] = field( + default=real_path, metadata={"help": "The model path to deploy"} + ) + host: Optional[str] = field( + default=models[0].host, + metadata={
+ "valid_values": [model.host for model in models], + "help": "The remote host to deploy model", + }, + ) + + port: Optional[int] = field( + default=models[0].port, + metadata={ + "valid_values": [model.port for model in models], + "help": "The remote port to deploy model", + }, + ) + worker_type: Optional[str] = field( + default=real_worker_type, + metadata={ + "valid_values": WorkerType.values(), + "help": "Worker type", + }, + ) + + return [RemoteModelWorkerParameters, real_params_cls] + + +@model_cli_group.command( + cls=build_lazy_click_command(_dynamic_factory=_remote_model_dynamic_factory) +) +def start(**kwargs): """Start model instances""" - worker_apply( - MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START + from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager + + worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS) + req = WorkerStartupRequest( + host=kwargs["host"], + port=kwargs["port"], + worker_type=kwargs["worker_type"], + model=kwargs["model_name"], + params={}, ) + del kwargs["host"] + del kwargs["port"] + del kwargs["worker_type"] + req.params = kwargs + loop = get_or_create_event_loop() + res = loop.run_until_complete(worker_manager.model_startup(req)) + print(res) @model_cli_group.command() @@ -165,25 +318,10 @@ def chat(model_name: str, system: str): _cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system) -# @model_cli_group.command() -# @add_model_options -# def modify(address: str, model_name: str, model_type: str): -# """Restart model instances""" -# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS) - - -def _get_worker_manager(address: str): - from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest - - registry = ModelRegistryClient(address) - worker_manager = RemoteWorkerManager(registry) - return worker_manager - - def worker_apply( address: str, model_name: str, model_type: str, apply_type: WorkerApplyType ): - from pilot.model.worker.manager import WorkerApplyRequest + from pilot.model.cluster import WorkerApplyRequest loop = get_or_create_event_loop() worker_manager = _get_worker_manager(address) @@ -201,7 +339,7 @@ def _cli_chat(address: str, model_name: str, system_prompt: str = None): async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None): - from pilot.model.worker.manager import PromptRequest + from pilot.model.cluster import PromptRequest from pilot.scene.base_message import ModelMessage, ModelMessageRoleType print(f"Chatbot started with model {model_name}. 
Type 'exit' to leave the chat.") @@ -249,13 +387,11 @@ def wrapper(*args, **kwargs): def start_model_controller(**kwargs): """Start model controller""" - from pilot.model.controller.controller import run_model_controller - if kwargs["daemon"]: log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log") _run_current_with_daemon("ModelController", log_file) else: - from pilot.model.controller.controller import run_model_controller + from pilot.model.cluster import run_model_controller run_model_controller() @@ -279,9 +415,8 @@ def _model_dynamic_factory() -> Callable[[None], List[Type]]: return fix_class -@click.command(name="worker") -@EnvArgumentParser.create_click_option( - ModelWorkerParameters, ModelParameters, _dynamic_factory=_model_dynamic_factory +@click.command( + name="worker", cls=build_lazy_click_command(_dynamic_factory=_model_dynamic_factory) ) def start_model_worker(**kwargs): """Start model worker""" @@ -291,7 +426,7 @@ def start_model_worker(**kwargs): log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log") _run_current_with_daemon("ModelWorker", log_file) else: - from pilot.model.worker.manager import run_worker_manager + from pilot.model.cluster import run_worker_manager run_worker_manager() diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py new file mode 100644 index 000000000..b518a756b --- /dev/null +++ b/pilot/model/cluster/__init__.py @@ -0,0 +1,35 @@ +from pilot.model.cluster.base import ( + EmbeddingsRequest, + PromptRequest, + WorkerApplyRequest, + WorkerParameterRequest, + WorkerStartupRequest, +) +from pilot.model.cluster.worker.manager import ( + initialize_worker_manager_in_client, + run_worker_manager, + worker_manager, +) + +from pilot.model.cluster.registry import ModelRegistry +from pilot.model.cluster.controller.controller import ( + ModelRegistryClient, + run_model_controller, +) + +from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager + +__all__ = [ + "EmbeddingsRequest", + "PromptRequest", + "WorkerApplyRequest", + "WorkerParameterRequest", + "WorkerStartupRequest", + "worker_manager", + "run_worker_manager", + "initialize_worker_manager_in_client", + "ModelRegistry", + "ModelRegistryClient", + "RemoteWorkerManager", + "run_model_controller", +] diff --git a/pilot/model/cluster/base.py b/pilot/model/cluster/base.py new file mode 100644 index 000000000..7d97e6bd9 --- /dev/null +++ b/pilot/model/cluster/base.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +from typing import Dict, List + +from pilot.model.base import WorkerApplyType +from pilot.model.parameter import WorkerType +from pilot.scene.base_message import ModelMessage +from pydantic import BaseModel + +WORKER_MANAGER_SERVICE_TYPE = "service" +WORKER_MANAGER_SERVICE_NAME = "WorkerManager" + + +class PromptRequest(BaseModel): + messages: List[ModelMessage] + model: str + prompt: str = None + temperature: float = None + max_new_tokens: int = None + stop: str = None + echo: bool = True + + +class EmbeddingsRequest(BaseModel): + model: str + input: List[str] + + +class WorkerApplyRequest(BaseModel): + model: str + apply_type: WorkerApplyType + worker_type: WorkerType = WorkerType.LLM + params: Dict = None + apply_user: str = None + + +class WorkerParameterRequest(BaseModel): + model: str + worker_type: WorkerType = WorkerType.LLM + + +class WorkerStartupRequest(BaseModel): + host: str + port: int + model: str + worker_type: WorkerType + params: Dict diff --git a/pilot/model/controller/__init__.py
b/pilot/model/cluster/controller/__init__.py similarity index 100% rename from pilot/model/controller/__init__.py rename to pilot/model/cluster/controller/__init__.py diff --git a/pilot/model/controller/controller.py b/pilot/model/cluster/controller/controller.py similarity index 98% rename from pilot/model/controller/controller.py rename to pilot/model/cluster/controller/controller.py index bb19e31df..d48e1362f 100644 --- a/pilot/model/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, FastAPI from pilot.model.base import ModelInstance from pilot.model.parameter import ModelControllerParameters -from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry +from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from pilot.utils.parameter_utils import EnvArgumentParser from pilot.utils.api_utils import _api_remote as api_remote diff --git a/pilot/model/controller/ray_controller.py b/pilot/model/cluster/controller/ray_controller.py similarity index 100% rename from pilot/model/controller/ray_controller.py rename to pilot/model/cluster/controller/ray_controller.py diff --git a/pilot/model/controller/tests/__init__.py b/pilot/model/cluster/controller/tests/__init__.py similarity index 100% rename from pilot/model/controller/tests/__init__.py rename to pilot/model/cluster/controller/tests/__init__.py diff --git a/pilot/model/controller/tests/test_registry.py b/pilot/model/cluster/controller/tests/test_registry.py similarity index 88% rename from pilot/model/controller/tests/test_registry.py rename to pilot/model/cluster/controller/tests/test_registry.py index b69b08508..ed366c058 100644 --- a/pilot/model/controller/tests/test_registry.py +++ b/pilot/model/cluster/controller/tests/test_registry.py @@ -2,9 +2,8 @@ from datetime import datetime, timedelta import asyncio -from unittest.mock import patch from pilot.model.base import ModelInstance -from pilot.model.controller.registry import ModelRegistry, EmbeddedModelRegistry +from pilot.model.cluster.registry import EmbeddedModelRegistry @pytest.fixture @@ -16,7 +15,7 @@ def model_registry(): def model_instance(): return ModelInstance( model_name="test_model", - ip="192.168.1.1", + host="192.168.1.1", port=5000, ) @@ -89,12 +88,7 @@ async def test_send_heartbeat(model_registry, model_instance): await model_registry.register_instance(model_instance) last_heartbeat = datetime.now() - timedelta(seconds=10) model_instance.last_heartbeat = last_heartbeat - assert ( - await model_registry.send_heartbeat( - model_instance.model_name, model_instance.ip, model_instance.port - ) - == True - ) + assert await model_registry.send_heartbeat(model_instance) == True assert ( model_registry.registry[model_instance.model_name][0].last_heartbeat > last_heartbeat @@ -125,7 +119,7 @@ async def test_multiple_instances(model_registry, model_instance): """ model_instance2 = ModelInstance( model_name="test_model", - ip="192.168.1.2", + host="192.168.1.2", port=5000, ) await model_registry.register_instance(model_instance) @@ -138,11 +132,11 @@ async def test_same_model_name_different_ip_port(model_registry): """ Test if instances with the same model name but different IP and port are handled correctly """ - instance1 = ModelInstance(model_name="test_model", ip="192.168.1.1", port=5000) - instance2 = ModelInstance(model_name="test_model", ip="192.168.1.2", port=6000) + instance1 = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000) + 
instance2 = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000) await model_registry.register_instance(instance1) await model_registry.register_instance(instance2) instances = await model_registry.get_all_instances("test_model") assert len(instances) == 2 - assert instances[0].ip != instances[1].ip + assert instances[0].host != instances[1].host assert instances[0].port != instances[1].port diff --git a/pilot/model/worker/__init__.py b/pilot/model/cluster/controller_base.py similarity index 100% rename from pilot/model/worker/__init__.py rename to pilot/model/cluster/controller_base.py diff --git a/pilot/model/cluster/manager_base.py b/pilot/model/cluster/manager_base.py new file mode 100644 index 000000000..4f3fd27e3 --- /dev/null +++ b/pilot/model/cluster/manager_base.py @@ -0,0 +1,82 @@ +import asyncio +from dataclasses import dataclass +from typing import List, Optional, Dict, Iterator +from abc import ABC, abstractmethod +from datetime import datetime +from concurrent.futures import Future +from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput +from pilot.model.cluster.worker_base import ModelWorker +from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest +from pilot.model.parameter import ModelWorkerParameters, ModelParameters +from pilot.utils.parameter_utils import ParameterDescription + + +@dataclass +class WorkerRunData: + host: str + port: int + worker_key: str + worker: ModelWorker + worker_params: ModelWorkerParameters + model_params: ModelParameters + stop_event: asyncio.Event + semaphore: asyncio.Semaphore = None + command_args: List[str] = None + _heartbeat_future: Optional[Future] = None + _last_heartbeat: Optional[datetime] = None + + +class WorkerManager(ABC): + @abstractmethod + async def start(self): + """Start worker manager""" + + @abstractmethod + async def stop(self): + """Stop worker manager""" + + @abstractmethod + async def get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + """Get model instances by worker type and model name""" + + @abstractmethod + async def select_one_instance( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + """Select one instance""" + + @abstractmethod + async def supported_models(self) -> List[WorkerSupportedModel]: + """List supported models""" + + @abstractmethod + async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + """Create and start a model instance""" + + @abstractmethod + async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + """Shutdown model instance""" + + @abstractmethod + async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]: + """Generate stream result, chat scene""" + + @abstractmethod + async def generate(self, params: Dict) -> ModelOutput: + """Generate non stream result""" + + @abstractmethod + async def embeddings(self, params: Dict) -> List[List[float]]: + """Embed input""" + + @abstractmethod + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: + """Worker apply""" + + @abstractmethod + async def parameter_descriptions( + self, worker_type: str, model_name: str + ) -> List[ParameterDescription]: + """Get parameter descriptions of model""" diff --git a/pilot/model/controller/registry.py b/pilot/model/cluster/registry.py similarity index 100% rename from pilot/model/controller/registry.py rename to pilot/model/cluster/registry.py diff --git 
a/pilot/model/worker/ray_worker.py b/pilot/model/cluster/worker/__init__.py similarity index 100% rename from pilot/model/worker/ray_worker.py rename to pilot/model/cluster/worker/__init__.py diff --git a/pilot/model/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py similarity index 98% rename from pilot/model/worker/default_worker.py rename to pilot/model/cluster/worker/default_worker.py index 157745ffe..cb7686566 100644 --- a/pilot/model/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -7,7 +7,7 @@ from pilot.model.base import ModelOutput from pilot.model.loader import ModelLoader, _get_model_real_path from pilot.model.parameter import ModelParameters -from pilot.model.worker.base import ModelWorker +from pilot.model.cluster.worker_base import ModelWorker from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter from pilot.utils.model_utils import _clear_torch_cache from pilot.utils.parameter_utils import EnvArgumentParser diff --git a/pilot/model/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py similarity index 98% rename from pilot/model/worker/embedding_worker.py rename to pilot/model/cluster/worker/embedding_worker.py index 0fd3c4593..a8824f228 100644 --- a/pilot/model/worker/embedding_worker.py +++ b/pilot/model/cluster/worker/embedding_worker.py @@ -7,7 +7,7 @@ EmbeddingModelParameters, WorkerType, ) -from pilot.model.worker.base import ModelWorker +from pilot.model.cluster.worker_base import ModelWorker from pilot.utils.model_utils import _clear_torch_cache from pilot.utils.parameter_utils import EnvArgumentParser diff --git a/pilot/model/worker/manager.py b/pilot/model/cluster/worker/manager.py similarity index 70% rename from pilot/model/worker/manager.py rename to pilot/model/cluster/worker/manager.py index 76367cc79..7e4852527 100644 --- a/pilot/model/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -4,13 +4,11 @@ import os import random import time -from abc import ABC, abstractmethod -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import asdict, dataclass -from datetime import datetime -from typing import Awaitable, Callable, Dict, Iterator, List, Optional +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict +from typing import Awaitable, Callable, Dict, Iterator, List -from fastapi import APIRouter, FastAPI, Request +from fastapi import APIRouter, FastAPI from fastapi.responses import StreamingResponse from pilot.configs.model_config import LOGDIR from pilot.model.base import ( @@ -18,103 +16,34 @@ ModelOutput, WorkerApplyOutput, WorkerApplyType, + WorkerSupportedModel, ) -from pilot.model.controller.registry import ModelRegistry +from pilot.model.cluster.registry import ModelRegistry +from pilot.model.llm_utils import list_supported_models from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType -from pilot.model.worker.base import ModelWorker -from pilot.scene.base_message import ModelMessage +from pilot.model.cluster.worker_base import ModelWorker +from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData +from pilot.model.cluster.base import * from pilot.utils import build_logger -from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription -from pydantic import BaseModel +from pilot.utils.parameter_utils import ( + EnvArgumentParser, + ParameterDescription, + _dict_to_command_args, +) logger = build_logger("model_worker", LOGDIR + 
"/model_worker.log") -class PromptRequest(BaseModel): - messages: List[ModelMessage] - model: str - prompt: str = None - temperature: float = None - max_new_tokens: int = None - stop: str = None - echo: bool = True - - -class EmbeddingsRequest(BaseModel): - model: str - input: List[str] - - -class WorkerApplyRequest(BaseModel): - model: str - apply_type: WorkerApplyType - worker_type: WorkerType = WorkerType.LLM - params: Dict = None - apply_user: str = None - - -class WorkerParameterRequest(BaseModel): - model: str - worker_type: WorkerType = WorkerType.LLM - - -@dataclass -class WorkerRunData: - worker_key: str - worker: ModelWorker - worker_params: ModelWorkerParameters - model_params: ModelParameters - stop_event: asyncio.Event - semaphore: asyncio.Semaphore = None - command_args: List[str] = None - _heartbeat_future: Optional[Future] = None - _last_heartbeat: Optional[datetime] = None - - RegisterFunc = Callable[[WorkerRunData], Awaitable[None]] DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]] SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]] ApplyFunction = Callable[[WorkerRunData], Awaitable[None]] -class WorkerManager(ABC): - @abstractmethod - async def get_model_instances( - self, worker_type: str, model_name: str, healthy_only: bool = True - ) -> List[WorkerRunData]: - """Get model instances by worker type and model name""" - - @abstractmethod - async def select_one_instanes( - self, worker_type: str, model_name: str, healthy_only: bool = True - ) -> WorkerRunData: - """Select one instances""" - - @abstractmethod - async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]: - """Generate stream result, chat scene""" - - @abstractmethod - async def generate(self, params: Dict) -> ModelOutput: - """Generate non stream result""" - - @abstractmethod - async def embeddings(self, params: Dict) -> List[List[float]]: - """Embed input""" - - @abstractmethod - async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: - """Worker apply""" - - @abstractmethod - async def parameter_descriptions( - self, worker_type: str, model_name: str - ) -> List[ParameterDescription]: - """Get parameter descriptions of model""" - - async def _async_heartbeat_sender( - worker_run_data: WorkerRunData, send_heartbeat_func: SendHeartbeatFunc + worker_run_data: WorkerRunData, + heartbeat_interval, + send_heartbeat_func: SendHeartbeatFunc, ): while not worker_run_data.stop_event.is_set(): try: @@ -122,7 +51,7 @@ async def _async_heartbeat_sender( except Exception as e: logger.warn(f"Send heartbeat func error: {str(e)}") finally: - await asyncio.sleep(worker_run_data.worker_params.heartbeat_interval) + await asyncio.sleep(heartbeat_interval) class LocalWorkerManager(WorkerManager): @@ -132,6 +61,8 @@ def __init__( deregister_func: DeregisterFunc = None, send_heartbeat_func: SendHeartbeatFunc = None, model_registry: ModelRegistry = None, + host: str = None, + port: int = None, ) -> None: self.workers: Dict[str, List[WorkerRunData]] = dict() self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5) @@ -139,19 +70,58 @@ def __init__( self.deregister_func = deregister_func self.send_heartbeat_func = send_heartbeat_func self.model_registry = model_registry + self.host = host + self.port = port + + self.run_data = WorkerRunData( + host=self.host, + port=self.port, + worker_key=self._worker_key( + WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME + ), + worker=None, + worker_params=None, + model_params=None, + 
stop_event=asyncio.Event(), + semaphore=None, + command_args=None, + ) def _worker_key(self, worker_type: str, model_name: str) -> str: if isinstance(worker_type, WorkerType): worker_type = worker_type.value return f"{model_name}@{worker_type}" + async def run_blocking_func(self, func, *args): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.executor, func, *args) + + async def start(self): + if len(self.workers) > 0: + await self._start_all_worker(apply_req=None) + if self.register_func: + await self.register_func(self.run_data) + if self.send_heartbeat_func: + asyncio.create_task( + _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func) + ) + + async def stop(self): + if not self.run_data.stop_event.is_set(): + logger.info("Stop all workers") + self.run_data.stop_event.set() + stop_tasks = [] + stop_tasks.append(self._stop_all_worker(apply_req=None)) + if self.deregister_func: + stop_tasks.append(self.deregister_func(self.run_data)) + await asyncio.gather(*stop_tasks) + def add_worker( self, worker: ModelWorker, worker_params: ModelWorkerParameters, - embedded_mod: bool = True, command_args: List[str] = None, - ): + ) -> bool: if not command_args: import sys @@ -179,6 +149,8 @@ def add_worker( model_params = worker.parse_parameters(command_args=command_args) worker_run_data = WorkerRunData( + host=self.host, + port=self.port, worker_key=worker_key, worker=worker, worker_params=worker_params, @@ -187,14 +159,66 @@ def add_worker( semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency), command_args=command_args, ) - if not embedded_mod: - exist_instances = [ - (w, p) for w, p in instances if p.host == host and p.port == port - ] - if not exist_instances: - instances.append(worker_run_data) - else: + exist_instances = [ + ins for ins in instances if ins.host == host and ins.port == port + ] + if not exist_instances: instances.append(worker_run_data) + return True + else: + # TODO Update worker + return False + + async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + """Start model""" + model_name = startup_req.model + worker_type = startup_req.worker_type + params = startup_req.params + logger.debug( + f"start model, model name {model_name}, worker type {worker_type}, params: {params}" + ) + worker_params: ModelWorkerParameters = ModelWorkerParameters.from_dict( + params, ignore_extra_fields=True + ) + if not worker_params.model_name: + worker_params.model_name = model_name + assert model_name == worker_params.model_name + worker = _build_worker(worker_params) + command_args = _dict_to_command_args(params) + success = await self.run_blocking_func( + self.add_worker, worker, worker_params, command_args + ) + if not success: + logger.warn( + f"Add worker failed, worker instance already exists, worker_params: {worker_params}" + ) + return False + supported_types = WorkerType.values() + if worker_type not in supported_types: + raise ValueError( + f"Unsupported worker type: {worker_type}, supported worker types: {supported_types}" + ) + start_apply_req = WorkerApplyRequest( + model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type + ) + await self.worker_apply(start_apply_req) + return True + + async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}") + apply_req = WorkerApplyRequest( + model=shutdown_req.model, + apply_type=WorkerApplyType.STOP, + worker_type=shutdown_req.worker_type, + ) + out = await
self._stop_all_worker(apply_req) + if out.success: + return True + raise Exception(out.message) + + async def supported_models(self) -> List[WorkerSupportedModel]: + models = await self.run_blocking_func(list_supported_models) + return [WorkerSupportedModel(host=self.host, port=self.port, models=models)] async def get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True @@ -202,7 +226,7 @@ async def get_model_instances( worker_key = self._worker_key(worker_type, model_name) return self.workers.get(worker_key) - async def select_one_instanes( + async def select_one_instance( self, worker_type: str, model_name: str, healthy_only: bool = True ) -> WorkerRunData: worker_instances = await self.get_model_instances( @@ -219,7 +243,7 @@ async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunD model = params.get("model") if not model: raise Exception("Model name count not be empty") - return await self.select_one_instanes(worker_type, model, healthy_only=True) + return await self.select_one_instance(worker_type, model, healthy_only=True) async def generate_stream( self, params: Dict, async_wrapper=None, **kwargs @@ -262,9 +286,8 @@ async def generate(self, params: Dict) -> ModelOutput: if worker_run_data.worker.support_async(): return await worker_run_data.worker.async_generate(params) else: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, worker_run_data.worker.generate, params + return await self.run_blocking_func( + worker_run_data.worker.generate, params ) async def embeddings(self, params: Dict) -> List[List[float]]: @@ -277,9 +300,8 @@ async def embeddings(self, params: Dict) -> List[List[float]]: if worker_run_data.worker.support_async(): return await worker_run_data.worker.async_embeddings(params) else: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, worker_run_data.worker.embeddings, params + return await self.run_blocking_func( + worker_run_data.worker.embeddings, params ) async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: @@ -342,8 +364,10 @@ async def _start_all_worker( logger.info(f"Begin start all worker, apply_req: {apply_req}") async def _start_worker(worker_run_data: WorkerRunData): - worker_run_data.worker.start( - worker_run_data.model_params, worker_run_data.command_args + await self.run_blocking_func( + worker_run_data.worker.start, + worker_run_data.model_params, + worker_run_data.command_args, ) worker_run_data.stop_event.clear() if worker_run_data.worker_params.register and self.register_func: @@ -355,7 +379,9 @@ async def _start_worker(worker_run_data: WorkerRunData): ): asyncio.create_task( _async_heartbeat_sender( - worker_run_data, self.send_heartbeat_func + worker_run_data, + worker_run_data.worker_params.heartbeat_interval, + self.send_heartbeat_func, ) ) @@ -371,7 +397,7 @@ async def _stop_all_worker( start_time = time.time() async def _stop_worker(worker_run_data: WorkerRunData): - worker_run_data.worker.stop() + await self.run_blocking_func(worker_run_data.worker.stop) # Set stop event worker_run_data.stop_event.set() if worker_run_data._heartbeat_future: @@ -422,62 +448,24 @@ async def update_params(worker_run_data: WorkerRunData): return WorkerApplyOutput(message=message, timecost=timecost) -class RemoteWorkerManager(LocalWorkerManager): - def __init__(self, model_registry: ModelRegistry = None) -> None: - super().__init__(model_registry=model_registry) - - async def get_model_instances( - self, 
worker_type: str, model_name: str, healthy_only: bool = True - ) -> List[WorkerRunData]: - from pilot.model.worker.remote_worker import RemoteModelWorker +class WorkerManagerAdapter(WorkerManager): + def __init__(self, worker_manager: WorkerManager = None) -> None: + self.worker_manager = worker_manager - worker_key = self._worker_key(worker_type, model_name) - instances: List[ModelInstance] = await self.model_registry.get_all_instances( - worker_key, healthy_only - ) - worker_instances = [] - for ins in instances: - worker = RemoteModelWorker() - worker.load_worker(model_name, model_name, host=ins.host, port=ins.port) - wr = WorkerRunData( - worker_key=ins.model_name, - worker=worker, - worker_params=None, - model_params=None, - stop_event=asyncio.Event(), - semaphore=asyncio.Semaphore(100), # Not limit in client - ) - worker_instances.append(wr) - return worker_instances + async def start(self): + return await self.worker_manager.start() - async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: - import httpx - - async def _remote_apply_func(worker_run_data: WorkerRunData): - worker_addr = worker_run_data.worker.worker_addr - async with httpx.AsyncClient() as client: - response = await client.post( - worker_addr + "/apply", - headers=worker_run_data.worker.headers, - json=apply_req.dict(), - timeout=worker_run_data.worker.timeout, - ) - if response.status_code == 200: - output = WorkerApplyOutput(**response.json()) - logger.info(f"worker_apply success: {output}") - else: - output = WorkerApplyOutput(message=response.text) - logger.warn(f"worker_apply failed: {output}") - return output + async def stop(self): + return await self.worker_manager.stop() - results = await self._apply_worker(apply_req, _remote_apply_func) - if results: - return results[0] + async def supported_models(self) -> List[WorkerSupportedModel]: + return await self.worker_manager.supported_models() + async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + return await self.worker_manager.model_startup(startup_req) -class WorkerManagerAdapter(WorkerManager): - def __init__(self, worker_manager: WorkerManager = None) -> None: - self.worker_manager = worker_manager + async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + return await self.worker_manager.model_shutdown(shutdown_req) async def get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True @@ -486,10 +474,10 @@ async def get_model_instances( worker_type, model_name, healthy_only ) - async def select_one_instanes( + async def select_one_instance( self, worker_type: str, model_name: str, healthy_only: bool = True ) -> WorkerRunData: - return await self.worker_manager.select_one_instanes( + return await self.worker_manager.select_one_instance( worker_type, model_name, healthy_only ) @@ -535,37 +523,58 @@ async def api_generate_stream(request: PromptRequest): @router.post("/worker/generate") async def api_generate(request: PromptRequest): params = request.dict(exclude_none=True) - output = await worker_manager.generate(params) - return output + return await worker_manager.generate(params) @router.post("/worker/embeddings") async def api_embeddings(request: EmbeddingsRequest): params = request.dict(exclude_none=True) - output = await worker_manager.embeddings(params) - return output + return await worker_manager.embeddings(params) @router.post("/worker/apply") async def api_worker_apply(request: WorkerApplyRequest): - output = await worker_manager.worker_apply(request) - 
return output + return await worker_manager.worker_apply(request) @router.get("/worker/parameter/descriptions") async def api_worker_parameter_descs( model: str, worker_type: str = WorkerType.LLM.value ): - output = await worker_manager.parameter_descriptions(worker_type, model) - return output + return await worker_manager.parameter_descriptions(worker_type, model) + + +@router.get("/worker/models/supports") +async def api_supported_models(): + """Get all supported models. + + This method reads all models from the configuration file and tries to perform some basic checks on each model (such as whether the path exists). + + If it's a RemoteWorkerManager, this method returns the list of models supported by the entire cluster. + """ + return await worker_manager.supported_models() + + +@router.post("/worker/models/startup") +async def api_model_startup(request: WorkerStartupRequest): + """Start up a specific model.""" + return await worker_manager.model_startup(request) + + +@router.post("/worker/models/shutdown") +async def api_model_shutdown(request: WorkerStartupRequest): + """Shut down a specific model.""" + return await worker_manager.model_shutdown(request) def _setup_fastapi(worker_params: ModelWorkerParameters, app=None): if not app: app = FastAPI() if worker_params.standalone: - from pilot.model.controller.controller import router as controller_router - from pilot.model.controller.controller import initialize_controller + from pilot.model.cluster.controller.controller import initialize_controller + from pilot.model.cluster.controller.controller import ( + router as controller_router, + ) if not worker_params.controller_addr: worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}" @@ -577,9 +586,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None): @app.on_event("startup") async def startup_event(): - asyncio.create_task( - worker_manager.worker_manager._start_all_worker(apply_req=None) - ) + asyncio.create_task(worker_manager.worker_manager.start()) + + @app.on_event("shutdown") + async def shutdown_event(): + await worker_manager.worker_manager.stop() return app @@ -609,22 +620,23 @@ def _parse_worker_params( def _create_local_model_manager( worker_params: ModelWorkerParameters, ) -> LocalWorkerManager: + from pilot.utils.net_utils import _get_ip_address + + host = ( + worker_params.worker_register_host + if worker_params.worker_register_host + else _get_ip_address() + ) + port = worker_params.port if not worker_params.register or not worker_params.controller_addr: logger.info( f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}" ) - return LocalWorkerManager() + return LocalWorkerManager(host=host, port=port) else: - from pilot.model.controller.controller import ModelRegistryClient - from pilot.utils.net_utils import _get_ip_address + from pilot.model.cluster.controller.controller import ModelRegistryClient client = ModelRegistryClient(worker_params.controller_addr) - host = ( - worker_params.worker_register_host - if worker_params.worker_register_host - else _get_ip_address() - ) - port = worker_params.port async def register_func(worker_run_data: WorkerRunData): instance = ModelInstance( @@ -648,31 +660,33 @@ async def send_heartbeat_func(worker_run_data: WorkerRunData): register_func=register_func, deregister_func=deregister_func, send_heartbeat_func=send_heartbeat_func, + host=host, + port=port, ) -def _start_local_worker( - worker_manager: WorkerManagerAdapter, - worker_params:
ModelWorkerParameters, - embedded_mod=True, -): - from pilot.utils.module_utils import import_from_checked_string - +def _build_worker(worker_params: ModelWorkerParameters): if worker_params.worker_class: + from pilot.utils.module_utils import import_from_checked_string + worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker) logger.info( f"Import worker class from {worker_params.worker_class} successfully" ) worker: ModelWorker = worker_cls() else: - from pilot.model.worker.default_worker import DefaultModelWorker + from pilot.model.cluster.worker.default_worker import DefaultModelWorker worker = DefaultModelWorker() + return worker + +def _start_local_worker( + worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters +): + worker = _build_worker(worker_params) worker_manager.worker_manager = _create_local_model_manager(worker_params) - worker_manager.worker_manager.add_worker( - worker, worker_params, embedded_mod=embedded_mod - ) + worker_manager.worker_manager.add_worker(worker, worker_params) def initialize_worker_manager_in_client( @@ -713,16 +727,13 @@ def initialize_worker_manager_in_client( worker_params.port = local_port logger.info(f"Worker params: {worker_params}") _setup_fastapi(worker_params, app) - _start_local_worker(worker_manager, worker_params, True) - # loop = asyncio.get_event_loop() - # loop.run_until_complete( - # worker_manager.worker_manager._start_all_worker(apply_req=None) - # ) + _start_local_worker(worker_manager, worker_params) else: - from pilot.model.controller.controller import ( - initialize_controller, + from pilot.model.cluster.controller.controller import ( ModelRegistryClient, + initialize_controller, ) + from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager if not worker_params.controller_addr: raise ValueError("Controller can`t be None") @@ -758,13 +769,11 @@ def run_worker_manager( # Run worker manager independently embedded_mod = False app = _setup_fastapi(worker_params) - _start_local_worker(worker_manager, worker_params, embedded_mod=False) + _start_local_worker(worker_manager, worker_params) else: - _start_local_worker(worker_manager, worker_params, embedded_mod=False) + _start_local_worker(worker_manager, worker_params) loop = asyncio.get_event_loop() - loop.run_until_complete( - worker_manager.worker_manager._start_all_worker(apply_req=None) - ) + loop.run_until_complete(worker_manager.worker_manager.start()) if include_router: app.include_router(router, prefix="/api") diff --git a/pilot/model/cluster/worker/ray_worker.py b/pilot/model/cluster/worker/ray_worker.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/worker/remote_manager.py b/pilot/model/cluster/worker/remote_manager.py new file mode 100644 index 000000000..afdf10774 --- /dev/null +++ b/pilot/model/cluster/worker/remote_manager.py @@ -0,0 +1,169 @@ +from typing import Callable, Any +import httpx +import asyncio +from pilot.model.cluster.registry import ModelRegistry +from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger +from pilot.model.cluster.base import * +from pilot.model.base import ( + ModelInstance, + WorkerApplyOutput, + WorkerSupportedModel, +) + + +class RemoteWorkerManager(LocalWorkerManager): + def __init__(self, model_registry: ModelRegistry = None) -> None: + super().__init__(model_registry=model_registry) + + async def start(self): + pass + + async def stop(self): + pass + + async def _fetch_from_worker( + self, + worker_run_data: 
WorkerRunData, + endpoint: str, + method: str = "GET", + json: dict = None, + params: dict = None, + additional_headers: dict = None, + success_handler: Callable = None, + error_handler: Callable = None, + ) -> Any: + url = worker_run_data.worker.worker_addr + endpoint + headers = {**worker_run_data.worker.headers, **(additional_headers or {})} + timeout = worker_run_data.worker.timeout + + async with httpx.AsyncClient() as client: + request = client.build_request( + method, + url, + json=json, # using json for data to ensure it sends as application/json + params=params, + headers=headers, + timeout=timeout, + ) + + response = await client.send(request) + if response.status_code != 200: + if error_handler: + return error_handler(response) + else: + error_msg = f"Request to {url} failed, error: {response.text}" + raise Exception(error_msg) + if success_handler: + return success_handler(response) + return response.json() + + async def _apply_to_worker_manager_instances(self): + pass + + async def supported_models(self) -> List[WorkerSupportedModel]: + worker_instances = await self.get_model_instances( + WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME + ) + + async def get_supported_models(worker_run_data) -> List[WorkerSupportedModel]: + def handler(response): + return list(WorkerSupportedModel.from_dict(m) for m in response.json()) + + return await self._fetch_from_worker( + worker_run_data, "/models/supports", success_handler=handler + ) + + models = [] + results = await asyncio.gather( + *(get_supported_models(worker) for worker in worker_instances) + ) + for res in results: + models += res + return models + + async def _get_worker_service_instance( + self, host: str = None, port: int = None + ) -> List[WorkerRunData]: + worker_instances = await self.get_model_instances( + WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME + ) + error_msg = "Could not find worker instances" + if host and port: + worker_instances = [ + ins for ins in worker_instances if ins.host == host and ins.port == port + ] + error_msg = f"Could not find worker instances for host {host} port {port}" + if not worker_instances: + raise Exception(error_msg) + return worker_instances + + async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + worker_instances = await self._get_worker_service_instance( + startup_req.host, startup_req.port + ) + worker_run_data = worker_instances[0] + logger.info(f"Start model remote, startup_req: {startup_req}") + return await self._fetch_from_worker( + worker_run_data, + "/models/startup", + method="POST", + json=startup_req.dict(), + success_handler=lambda x: True, + ) + + async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + worker_instances = await self._get_worker_service_instance( + shutdown_req.host, shutdown_req.port + ) + worker_run_data = worker_instances[0] + logger.info(f"Shutdown model remote, shutdown_req: {shutdown_req}") + return await self._fetch_from_worker( + worker_run_data, + "/models/shutdown", + method="POST", + json=shutdown_req.dict(), + success_handler=lambda x: True, + ) + + async def get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + from pilot.model.cluster.worker.remote_worker import RemoteModelWorker + + worker_key = self._worker_key(worker_type, model_name) + instances: List[ModelInstance] = await self.model_registry.get_all_instances( + worker_key, healthy_only + ) + worker_instances = [] + for ins in instances: + worker =
RemoteModelWorker() + worker.load_worker(model_name, model_name, host=ins.host, port=ins.port) + wr = WorkerRunData( + host=ins.host, + port=ins.port, + worker_key=ins.model_name, + worker=worker, + worker_params=None, + model_params=None, + stop_event=asyncio.Event(), + semaphore=asyncio.Semaphore(100), # Not limit in client + ) + worker_instances.append(wr) + return worker_instances + + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: + async def _remote_apply_func(worker_run_data: WorkerRunData): + return await self._fetch_from_worker( + worker_run_data, + "/apply", + method="POST", + json=apply_req.dict(), + success_handler=lambda res: WorkerApplyOutput(**res.json()), + error_handler=lambda res: WorkerApplyOutput( + message=res.text, success=False + ), + ) + + results = await self._apply_worker(apply_req, _remote_apply_func) + if results: + return results[0] diff --git a/pilot/model/worker/remote_worker.py b/pilot/model/cluster/worker/remote_worker.py similarity index 98% rename from pilot/model/worker/remote_worker.py rename to pilot/model/cluster/worker/remote_worker.py index c7538f6fc..b123f1aa7 100644 --- a/pilot/model/worker/remote_worker.py +++ b/pilot/model/cluster/worker/remote_worker.py @@ -3,7 +3,7 @@ import logging from pilot.model.base import ModelOutput from pilot.model.parameter import ModelParameters -from pilot.model.worker.base import ModelWorker +from pilot.model.cluster.worker_base import ModelWorker class RemoteModelWorker(ModelWorker): diff --git a/pilot/model/worker/base.py b/pilot/model/cluster/worker_base.py similarity index 100% rename from pilot/model/worker/base.py rename to pilot/model/cluster/worker_base.py diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py index ec50a7d34..9131490f5 100644 --- a/pilot/model/llm_utils.py +++ b/pilot/model/llm_utils.py @@ -2,14 +2,18 @@ # -*- coding:utf-8 -*- import traceback +from pathlib import Path from queue import Queue from threading import Thread import transformers -from typing import List, Optional +from typing import List, Optional, Dict +import cachetools from pilot.configs.config import Config -from pilot.model.base import Message +from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG +from pilot.model.base import Message, SupportedModel +from pilot.utils.parameter_utils import _get_parameter_descriptions def create_chat_completion( @@ -128,3 +132,49 @@ def is_partial_stop(output: str, stop_str: str): if stop_str.startswith(output[-i:]): return True return False + + +@cachetools.cached(cachetools.TTLCache(maxsize=100, ttl=60 * 5)) +def list_supported_models(): + from pilot.model.parameter import WorkerType + + models = _list_supported_models(WorkerType.LLM.value, LLM_MODEL_CONFIG) + models += _list_supported_models(WorkerType.TEXT2VEC.value, EMBEDDING_MODEL_CONFIG) + return models + + +def _list_supported_models( + worker_type: str, model_config: Dict[str, str] +) -> List[SupportedModel]: + from pilot.model.adapter import get_llm_model_adapter + from pilot.model.parameter import ModelParameters + from pilot.model.loader import _get_model_real_path + + ret = [] + for model_name, model_path in model_config.items(): + model_path = _get_model_real_path(model_name, model_path) + model = SupportedModel( + model=model_name, + path=model_path, + worker_type=worker_type, + path_exist=False, + proxy=False, + enabled=False, + params=None, + ) + if "proxyllm" in model_name: + model.proxy = True + else: + path = Path(model_path) + model.path_exist = 
path.exists() + param_cls = None + try: + llm_adapter = get_llm_model_adapter(model_name, model_path) + param_cls = llm_adapter.model_param_class() + model.enabled = True + params = _get_parameter_descriptions(param_cls) + model.params = params + except Exception: + pass + ret.append(model) + return ret diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 025e93091..a6019c129 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -353,5 +353,6 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters): from pilot.model.proxy.llms.proxy_model import ProxyModel + logger.info("Load proxyllm") model = ProxyModel(model_params) return model, model diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index a4067dab0..f8bc91a7c 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -33,9 +33,13 @@ class ModelControllerParameters(BaseParameters): @dataclass -class ModelWorkerParameters(BaseParameters): +class BaseModelParameters(BaseParameters): model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) model_path: str = field(metadata={"help": "Model path", "tags": "fixed"}) + + +@dataclass +class ModelWorkerParameters(BaseModelParameters): worker_type: Optional[str] = field( default=None, metadata={"valid_values": WorkerType.values(), "help": "Worker type"}, @@ -84,9 +88,7 @@ class ModelWorkerParameters(BaseParameters): @dataclass -class EmbeddingModelParameters(BaseParameters): - model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) - model_path: str = field(metadata={"help": "Model path", "tags": "fixed"}) +class EmbeddingModelParameters(BaseModelParameters): device: Optional[str] = field( default=None, metadata={ @@ -114,12 +116,6 @@ def build_kwargs(self, **kwargs) -> Dict: return kwargs -@dataclass -class BaseModelParameters(BaseParameters): - model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) - model_path: str = field(metadata={"help": "Model path", "tags": "fixed"}) - - @dataclass class ModelParameters(BaseModelParameters): device: Optional[str] = field( diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index a35c9963f..9880edcb2 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -139,7 +139,7 @@ async def stream_call(self): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - from pilot.model.worker.manager import worker_manager + from pilot.model.cluster import worker_manager async for output in worker_manager.generate_stream(payload): yield output @@ -157,7 +157,7 @@ async def nostream_call(self): logger.info(f"Request: \n{payload}") ai_response_text = "" try: - from pilot.model.worker.manager import worker_manager + from pilot.model.cluster import worker_manager model_output = await worker_manager.generate(payload) diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index dedb72f6c..65402b4e3 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -6,7 +6,7 @@ from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, - LLM_MODEL_CONFIG, + EMBEDDING_MODEL_CONFIG, ) from pilot.scene.chat_knowledge.v1.prompt import prompt @@ -48,7 +48,7 @@ def __init__(self, chat_session_id, user_input, select_param: str = None): "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = EmbeddingEngine( - 
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index dedb72f6c..65402b4e3 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -6,7 +6,7 @@
 
 from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
-    LLM_MODEL_CONFIG,
+    EMBEDDING_MODEL_CONFIG,
 )
 
 from pilot.scene.chat_knowledge.v1.prompt import prompt
@@ -48,7 +48,7 @@ def __init__(self, chat_session_id, user_input, select_param: str = None):
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = EmbeddingEngine(
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
         )
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 809743518..011b1146a 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -23,7 +23,7 @@ from pilot.openapi.base import validation_exception_handler
 from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
 from pilot.commands.disply_type.show_chart_gen import static_message_img_path
-from pilot.model.worker.manager import initialize_worker_manager_in_client
+from pilot.model.cluster import initialize_worker_manager_in_client
 from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
 
 static_file_path = os.path.join(os.getcwd(), "server/static")
diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py
index c2113ab3b..5d9c7522b 100644
--- a/pilot/server/knowledge/api.py
+++ b/pilot/server/knowledge/api.py
@@ -7,7 +7,10 @@ from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
+from pilot.configs.model_config import (
+    EMBEDDING_MODEL_CONFIG,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+)
 from pilot.openapi.api_view_model import Result
 
 from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@@ -29,7 +32,9 @@
 
 router = APIRouter()
 
-embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL])
+embeddings = HuggingFaceEmbeddings(
+    model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
+)
 
 knowledge_space_service = KnowledgeService()
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 9faad31ed..dc95cf9c9 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -5,7 +5,10 @@
 from pilot.vector_store.connector import VectorStoreConnector
 
 from pilot.configs.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
+from pilot.configs.model_config import (
+    EMBEDDING_MODEL_CONFIG,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+)
 from pilot.logs import logger
 from pilot.server.knowledge.chunk_db import (
     DocumentChunkEntity,
@@ -204,7 +207,7 @@ def sync_knowledge_document(self, space_name, doc_ids):
         client = EmbeddingEngine(
             knowledge_source=doc.content,
             knowledge_type=doc.doc_type.upper(),
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config={
                 "vector_store_name": space_name,
                 "vector_store_type": CFG.VECTOR_STORE_TYPE,
@@ -341,7 +344,7 @@ def _build_default_context(self):
             "topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
             "recall_score": 0.0,
             "recall_type": "TopK",
-            "model": LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
+            "model": EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
             "chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
             "chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
         },
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 3fdabc002..2e67b370f 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -9,7 +9,7 @@
 
 from pilot.configs.config import Config
 from pilot.configs.model_config import LLM_MODEL_CONFIG
-from pilot.model.worker.manager import run_worker_manager
+from pilot.model.cluster import run_worker_manager
 
 CFG = Config()
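All of the embedding call sites above now index `EMBEDDING_MODEL_CONFIG` instead of the combined `LLM_MODEL_CONFIG`. A small sketch of the effect, using a stand-in dict (the real one lives in pilot/configs/model_config.py and is keyed by names like "text2vec"):

```python
import os

MODEL_PATH = "/data/models"  # illustrative; the real value comes from model_config.py
EMBEDDING_MODEL_CONFIG = {
    "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    "m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
}

embedding_model = "text2vec"  # stand-in for CFG.EMBEDDING_MODEL
print(EMBEDDING_MODEL_CONFIG[embedding_model])
# A KeyError here now unambiguously means "unknown embedding model";
# before the split it could silently resolve to an LLM entry instead.
```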
diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py
index 964580591..1b6cc251f 100644
--- a/pilot/summary/db_summary_client.py
+++ b/pilot/summary/db_summary_client.py
@@ -5,7 +5,7 @@
 from pilot.configs.config import Config
 from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
-    LLM_MODEL_CONFIG,
+    EMBEDDING_MODEL_CONFIG,
     LOGDIR,
 )
 from pilot.scene.base import ChatScene
@@ -36,7 +36,7 @@ def db_summary_embedding(self, dbname, db_type):
         db_summary_client = RdbmsSummary(dbname, db_type)
         embeddings = HuggingFaceEmbeddings(
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
         )
         vector_store_config = {
             "vector_store_name": dbname + "_summary",
@@ -90,7 +90,7 @@ def get_db_summary(self, dbname, query, topk):
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         knowledge_embedding_client = EmbeddingEngine(
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
         )
         table_docs = knowledge_embedding_client.similar_search(query, topk)
@@ -108,7 +108,7 @@ def get_similar_tables(self, dbname, query, topk):
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         knowledge_embedding_client = EmbeddingEngine(
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
         )
         if CFG.SUMMARY_CONFIG == "FAST":
@@ -134,7 +134,7 @@ def get_similar_tables(self, dbname, query, topk):
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         knowledge_embedding_client = EmbeddingEngine(
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
         )
         table_summery = knowledge_embedding_client.similar_search(query, 1)
diff --git a/pilot/utils/command_utils.py b/pilot/utils/command_utils.py
index 815d13d5b..aeaad3f9e 100644
--- a/pilot/utils/command_utils.py
+++ b/pilot/utils/command_utils.py
@@ -4,6 +4,7 @@
 from typing import List, Dict
 import psutil
 import platform
+from functools import lru_cache
 
 
 def _get_abspath_of_current_command(command_path: str):
@@ -137,6 +138,7 @@ def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
     return ports
 
 
+@lru_cache()
 def _detect_controller_address() -> str:
     controller_addr = os.getenv("CONTROLLER_ADDRESS")
     if controller_addr:
diff --git a/pilot/utils/module_utils.py b/pilot/utils/module_utils.py
index cbc1db149..c2d857440 100644
--- a/pilot/utils/module_utils.py
+++ b/pilot/utils/module_utils.py
@@ -2,7 +2,7 @@
 from importlib import import_module
 
 
-def import_from_string(module_path: str):
+def import_from_string(module_path: str, ignore_import_error: bool = False):
     try:
         module_path, class_name = module_path.rsplit(".", 1)
     except ValueError:
@@ -12,6 +12,8 @@ def import_from_string(module_path: str):
     try:
         return getattr(module, class_name)
     except AttributeError:
+        if ignore_import_error:
+            return None
         raise ImportError(
             f'Module "{module_path}" does not define a "{class_name}" attribute/class'
         )
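`_detect_controller_address` gains `functools.lru_cache`, so the process scan behind it runs once per CLI invocation. The flip side is that the result is pinned for the life of the process, as this toy illustration shows (`detect_addr` is a hypothetical stand-in for the real function):

```python
import os
from functools import lru_cache

@lru_cache()
def detect_addr() -> str:
    # stand-in for _detect_controller_address: env var first, then a fallback
    return os.getenv("CONTROLLER_ADDRESS", "http://127.0.0.1:8000")

os.environ["CONTROLLER_ADDRESS"] = "http://10.0.0.5:8000"
print(detect_addr())  # http://10.0.0.5:8000
os.environ["CONTROLLER_ADDRESS"] = "http://10.0.0.6:8000"
print(detect_addr())  # still http://10.0.0.5:8000 -- the first result is cached
```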
diff --git a/pilot/utils/parameter_utils.py b/pilot/utils/parameter_utils.py
index 247cfb43d..b67282442 100644
--- a/pilot/utils/parameter_utils.py
+++ b/pilot/utils/parameter_utils.py
@@ -1,21 +1,50 @@
 import argparse
 import os
-from dataclasses import dataclass, fields, MISSING
+from dataclasses import dataclass, fields, MISSING, asdict, field
 from typing import Any, List, Optional, Type, Union, Callable, Dict
 from collections import OrderedDict
 
 
 @dataclass
 class ParameterDescription:
+    param_class: str
     param_name: str
     param_type: str
-    description: str
     default_value: Optional[Any]
+    description: str
     valid_values: Optional[List[Any]]
+    ext_metadata: Dict
 
 
 @dataclass
 class BaseParameters:
+    @classmethod
+    def from_dict(
+        cls, data: dict, ignore_extra_fields: bool = False
+    ) -> "BaseParameters":
+        """Create an instance of the dataclass from a dictionary.
+
+        Args:
+            data: A dictionary containing values for the dataclass fields.
+            ignore_extra_fields: If True, any extra fields in the data dictionary that are
+                not part of the dataclass will be ignored.
+                If False, extra fields will raise an error. Defaults to False.
+        Returns:
+            An instance of the dataclass with values populated from the given dictionary.
+
+        Raises:
+            TypeError: If `ignore_extra_fields` is False and there are fields in the
+                dictionary that aren't present in the dataclass.
+        """
+        all_field_names = {f.name for f in fields(cls)}
+        if ignore_extra_fields:
+            data = {key: value for key, value in data.items() if key in all_field_names}
+        else:
+            extra_fields = set(data.keys()) - all_field_names
+            if extra_fields:
+                raise TypeError(f"Unexpected fields: {', '.join(extra_fields)}")
+        return cls(**data)
+
     def update_from(self, source: Union["BaseParameters", dict]) -> bool:
         """
         Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
@@ -68,6 +97,35 @@ def __str__(self) -> str:
         )
         return "\n".join(parameters)
 
+    def to_command_args(self, args_prefix: str = "--") -> List[str]:
+        """Convert the fields of the dataclass to a list of command line arguments.
+
+        Args:
+            args_prefix: args prefix
+        Returns:
+            A list of strings where each field is represented by two items:
+            one for the field name prefixed by args_prefix, and one for its value.
+        """
+        return _dict_to_command_args(asdict(self), args_prefix=args_prefix)
+
+
+def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
+    """Convert dict to a list of command line arguments
+
+    Args:
+        obj: dict
+    Returns:
+        A list of strings where each field is represented by two items:
+        one for the field name prefixed by args_prefix, and one for its value.
+    """
+    args = []
+    for key, value in obj.items():
+        if value is None:
+            continue
+        args.append(f"{args_prefix}{key}")
+        args.append(str(value))
+    return args
+
 
 def _get_simple_privacy_field_value(obj, field_info):
     """Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
@@ -81,9 +139,9 @@ def _get_simple_privacy_field_value(obj, field_info):
     - str: if length > 5, masks the middle part and returns first and last char;
            otherwise, returns "******"
 
-    Parameters:
-    - obj: The dataclass instance.
-    - field_info: A Field object that contains information about the dataclass field.
+    Args:
+        obj: The dataclass instance.
+        field_info: A Field object that contains information about the dataclass field.
 
     Returns:
     The original or modified value of the field based on the privacy rules.
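Taken together, `from_dict` and `to_command_args` give `BaseParameters` a dict -> dataclass -> argv round trip. A hedged sketch on a toy subclass (`DemoParameters` is invented for illustration; the import path matches this diff):

```python
from dataclasses import dataclass
from pilot.utils.parameter_utils import BaseParameters

@dataclass
class DemoParameters(BaseParameters):
    model_name: str = "vicuna-13b-v1.5"
    port: int = 8000

# Extra keys are dropped when ignore_extra_fields=True
p = DemoParameters.from_dict(
    {"model_name": "chatglm2-6b", "unknown_key": 1}, ignore_extra_fields=True
)
print(p.to_command_args())
# ['--model_name', 'chatglm2-6b', '--port', '8000']
# With ignore_extra_fields=False (the default) the same dict raises
# TypeError: Unexpected fields: unknown_key
```

This is what lets a worker re-launch itself: parameters received as JSON become a dataclass, then become the argv for a daemonized subprocess.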
@@ -202,11 +260,42 @@ def _create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
             parser.add_argument(f"--{field.name}", **argument_kwargs)
         return parser
 
+    @staticmethod
+    def _create_click_option_from_field(field_name: str, field: Type, is_func=True):
+        import click
+
+        help_text = field.metadata.get("help", "")
+        valid_values = field.metadata.get("valid_values", None)
+        cli_params = {
+            "default": None if field.default is MISSING else field.default,
+            "help": help_text,
+            "show_default": True,
+            "required": field.default is MISSING,
+        }
+        if valid_values:
+            cli_params["type"] = click.Choice(valid_values)
+        real_type = EnvArgumentParser._get_argparse_type(field.type)
+        if real_type is int:
+            cli_params["type"] = click.INT
+        elif real_type is float:
+            cli_params["type"] = click.FLOAT
+        elif real_type is str:
+            cli_params["type"] = click.STRING
+        elif real_type is bool:
+            cli_params["is_flag"] = True
+        name = f"--{field_name}"
+        if is_func:
+            return click.option(
+                name,
+                **cli_params,
+            )
+        else:
+            return click.Option([name], **cli_params)
+
     @staticmethod
     def create_click_option(
         *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
     ):
-        import click
         import functools
         from collections import OrderedDict
 
@@ -222,30 +311,8 @@ def create_click_option(
 
         def decorator(func):
             for field_name, field in reversed(combined_fields.items()):
-                help_text = field.metadata.get("help", "")
-                valid_values = field.metadata.get("valid_values", None)
-                cli_params = {
-                    "default": None if field.default is MISSING else field.default,
-                    "help": help_text,
-                    "show_default": True,
-                    "required": field.default is MISSING,
-                }
-                if valid_values:
-                    cli_params["type"] = click.Choice(valid_values)
-                real_type = EnvArgumentParser._get_argparse_type(field.type)
-                if real_type is int:
-                    cli_params["type"] = click.INT
-                elif real_type is float:
-                    cli_params["type"] = click.FLOAT
-                elif real_type is str:
-                    cli_params["type"] = click.STRING
-                elif real_type is bool:
-                    cli_params["is_flag"] = True
-
-                option_decorator = click.option(
-                    # f"--{field_name.replace('_', '-')}", **cli_params
-                    f"--{field_name}",
-                    **cli_params,
+                option_decorator = EnvArgumentParser._create_click_option_from_field(
+                    field_name, field
                 )
                 func = option_decorator(func)
 
@@ -257,6 +324,23 @@ def wrapper(*args, **kwargs):
 
         return decorator
 
+    @staticmethod
+    def _create_raw_click_option(
+        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
+    ):
+        combined_fields = _merge_dataclass_types(
+            *dataclass_types, _dynamic_factory=_dynamic_factory
+        )
+        options = []
+
+        for field_name, field in reversed(combined_fields.items()):
+            options.append(
+                EnvArgumentParser._create_click_option_from_field(
+                    field_name, field, is_func=False
+                )
+            )
+        return options
+
     @staticmethod
     def create_argparse_option(
         *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
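`_create_click_option_from_field` centralizes the dataclass-field -> click-option translation; the `is_func=False` branch returns a bare `click.Option` object, which is what lazy commands can append at parse time. Roughly what one generated option looks like, hand-expanded for a single int field (values here are illustrative):

```python
import click
from dataclasses import MISSING

default = 8001  # field.default; if it were MISSING, the option would be required

@click.command()
@click.option(
    "--port",
    type=click.INT,  # chosen because the field's annotation maps to int
    default=default,
    show_default=True,
    required=default is MISSING,
    help="Worker port",
)
def serve(port: int):
    click.echo(f"serving on {port}")

if __name__ == "__main__":
    serve()
```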
@@ -366,21 +450,70 @@ def _merge_dataclass_types(
     return combined_fields
 
 
+def _type_str_to_python_type(type_str: str) -> Type:
+    type_mapping: Dict[str, Type] = {
+        "int": int,
+        "float": float,
+        "bool": bool,
+        "str": str,
+    }
+    return type_mapping.get(type_str, str)
+
+
 def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
     descriptions = []
     for field in fields(dataclass_type):
+        ext_metadata = {
+            k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
+        }
+
         descriptions.append(
             ParameterDescription(
+                param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
                 param_name=field.name,
                 param_type=EnvArgumentParser._get_argparse_type_str(field.type),
                 description=field.metadata.get("help", None),
-                default_value=field.default,  # TODO handle dataclasses._MISSING_TYPE
+                default_value=field.default if field.default != MISSING else None,
                 valid_values=field.metadata.get("valid_values", None),
+                ext_metadata=ext_metadata,
             )
         )
     return descriptions
 
 
+def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
+    from pilot.utils.module_utils import import_from_string
+
+    if not desc:
+        raise ValueError("Parameter descriptions cant be empty")
+    param_class_str = desc[0].param_class
+    param_class = import_from_string(param_class_str, ignore_import_error=True)
+    if param_class:
+        return param_class
+    module_name, _, class_name = param_class_str.rpartition(".")
+
+    fields_dict = {}  # This will store field names and their default values or field()
+    annotations = {}  # This will store the type annotations for the fields
+
+    for d in desc:
+        metadata = d.ext_metadata if d.ext_metadata else {}
+        metadata["help"] = d.description
+        metadata["valid_values"] = d.valid_values
+
+        annotations[d.param_name] = _type_str_to_python_type(
+            d.param_type
+        )  # Set type annotation
+        fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)
+
+    # Create the new class. Note the setting of __annotations__ for type hints
+    new_class = type(
+        class_name, (object,), {**fields_dict, "__annotations__": annotations}
+    )
+    result_class = dataclass(new_class)  # Make it a dataclass
+
+    return result_class
+
+
 class _SimpleArgParser:
     def __init__(self, *args):
         self.params = {arg.replace("_", "-"): None for arg in args}
@@ -422,3 +555,24 @@ def __str__(self):
         return "\n".join(
             [f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
         )
+
+
+def build_lazy_click_command(*dataclass_types: Type, _dynamic_factory=None):
+    import click
+
+    class LazyCommand(click.Command):
+        def __init__(self, *args, **kwargs):
+            super(LazyCommand, self).__init__(*args, **kwargs)
+            self.dynamic_params_added = False
+
+        def get_params(self, ctx):
+            if ctx and not self.dynamic_params_added:
+                dynamic_params = EnvArgumentParser._create_raw_click_option(
+                    *dataclass_types, _dynamic_factory=_dynamic_factory
+                )
+                for param in reversed(dynamic_params):
+                    self.params.append(param)
+                self.dynamic_params_added = True
+            return super(LazyCommand, self).get_params(ctx)
+
+    return LazyCommand
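`_build_parameter_class` falls back to reconstructing a parameter dataclass at runtime when `import_from_string` cannot resolve the original class (for instance, on a client that does not ship the worker's code). The core trick, reduced to a toy example with invented field names: build a class dict of `field(...)` objects plus `__annotations__`, then pass it through `dataclass()`:

```python
from dataclasses import dataclass, field

# In the real code these come from ParameterDescription entries
annotations = {"model_name": str, "device": str}
fields_dict = {
    "model_name": field(default="vicuna-13b-v1.5", metadata={"help": "Model name"}),
    "device": field(default="cuda", metadata={"help": "Device to run model"}),
}

# type() builds the plain class; dataclass() then generates __init__, __repr__, etc.
DynamicParams = dataclass(
    type("DynamicParams", (object,), {**fields_dict, "__annotations__": annotations})
)

p = DynamicParams(device="cpu")
print(p)  # DynamicParams(model_name='vicuna-13b-v1.5', device='cpu')
```

`build_lazy_click_command` plays a complementary role on the CLI side: `LazyCommand.get_params` defers the possibly expensive dynamic option construction until click actually asks for the command's parameters, so `--help` on the command group stays cheap.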
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index a79960477..c3cf2577f 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -29,10 +29,10 @@ def knownledge_tovec_st(filename):
     """Use sentence transformers to embedding the document.
     https://github.com/UKPLab/sentence-transformers
     """
-    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
 
     embeddings = HuggingFaceEmbeddings(
-        model_name=LLM_MODEL_CONFIG["sentence-transforms"]
+        model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
     )
 
     with open(filename, "r") as f:
@@ -57,10 +57,10 @@ def load_knownledge_from_doc():
             "Not Exists Local DataSets, We will answers the Question use model default."
         )
 
-    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
 
     embeddings = HuggingFaceEmbeddings(
-        model_name=LLM_MODEL_CONFIG["sentence-transforms"]
+        model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
     )
 
     files = os.listdir(DATASETS_DIR)
diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py
index cca027324..76411004d 100644
--- a/pilot/vector_store/file_loader.py
+++ b/pilot/vector_store/file_loader.py
@@ -16,7 +16,7 @@
 
 from pilot.configs.model_config import (
     DATASETS_DIR,
-    LLM_MODEL_CONFIG,
+    EMBEDDING_MODEL_CONFIG,
     VECTORE_PATH,
 )
 
@@ -39,7 +39,7 @@ class KnownLedge2Vector:
     """
 
     embeddings: object = None
-    model_name = LLM_MODEL_CONFIG["sentence-transforms"]
+    model_name = EMBEDDING_MODEL_CONFIG["sentence-transforms"]
 
     def __init__(self, model_name=None) -> None:
         if not model_name:
diff --git a/requirements.txt b/requirements.txt
index daec0dc85..130d7ccd2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -79,4 +79,5 @@ duckdb
 duckdb-engine
 
 # cli
-prettytable
\ No newline at end of file
+prettytable
+cachetools
\ No newline at end of file