diff --git a/assets/wechat.jpg b/assets/wechat.jpg index 40a9d988e..a50496546 100644 Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ diff --git a/pilot/componet.py b/pilot/componet.py index 705eb1193..2c3980cfc 100644 --- a/pilot/componet.py +++ b/pilot/componet.py @@ -1,13 +1,17 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING +from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING +from enum import Enum +import logging import asyncio # Checking for type hints during runtime if TYPE_CHECKING: from fastapi import FastAPI +logger = logging.getLogger(__name__) + class LifeCycle: """This class defines hooks for lifecycle events of a component.""" @@ -37,6 +41,11 @@ async def async_before_stop(self): pass +class ComponetType(str, Enum): + WORKER_MANAGER = "dbgpt_worker_manager" + MODEL_CONTROLLER = "dbgpt_model_controller" + + class BaseComponet(LifeCycle, ABC): """Abstract Base Component class. All custom components should extend this.""" @@ -80,11 +89,21 @@ def register(self, componet: Type[BaseComponet], *args, **kwargs): def register_instance(self, instance: T): """Register an already initialized component.""" - self.componets[instance.name] = instance + name = instance.name + if isinstance(name, ComponetType): + name = name.value + if name in self.componets: + raise RuntimeError( + f"Componse name {name} already exists: {self.componets[name]}" + ) + logger.info(f"Register componet with name {name} and instance: {instance}") + self.componets[name] = instance instance.init_app(self) - def get_componet(self, name: str, componet_type: Type[T]) -> T: + def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T: """Retrieve a registered component by its name and type.""" + if isinstance(name, ComponetType): + name = name.value component = self.componets.get(name) if not component: raise ValueError(f"No component found with name {name}") diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index b4632003e..e80f9c8e8 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -69,6 +69,9 @@ def get_device() -> str: # (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2 "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"), "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"), + # https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288 + "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"), + "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"), } EMBEDDING_MODEL_CONFIG = { diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 0b93faa40..8fe4a9057 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -411,6 +411,29 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict): return model, tokenizer +class InternLMAdapter(BaseLLMAdaper): + """The model adapter for internlm/internlm-chat-7b""" + + def match(self, model_path: str): + return "internlm" in model_path.lower() + + def loader(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + if "8k" in model_path.lower(): + model.config.max_sequence_length = 8192 + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + register_llm_model_adapters(VicunaLLMAdapater) register_llm_model_adapters(ChatGLMAdapater) register_llm_model_adapters(GuanacoAdapter) @@ -421,6 +444,7 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict): register_llm_model_adapters(BaichuanAdapter) register_llm_model_adapters(WizardLMAdapter) register_llm_model_adapters(LlamaCppAdapater) +register_llm_model_adapters(InternLMAdapter) # TODO Default support vicuna, other model need to tests and Evaluate # just for test_py, remove this later diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index e1d55eab7..54360e477 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -4,6 +4,7 @@ from typing import List from fastapi import APIRouter, FastAPI +from pilot.componet import BaseComponet, ComponetType, SystemApp from pilot.model.base import ModelInstance from pilot.model.parameter import ModelControllerParameters from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry @@ -14,7 +15,12 @@ ) -class BaseModelController(ABC): +class BaseModelController(BaseComponet, ABC): + name = ComponetType.MODEL_CONTROLLER + + def init_app(self, system_app: SystemApp): + pass + @abstractmethod async def register_instance(self, instance: ModelInstance) -> bool: """Register a given model instance""" @@ -25,7 +31,7 @@ async def deregister_instance(self, instance: ModelInstance) -> bool: @abstractmethod async def get_all_instances( - self, model_name: str, healthy_only: bool = False + self, model_name: str = None, healthy_only: bool = False ) -> List[ModelInstance]: """Fetch all instances of a given model. Optionally, fetch only the healthy instances.""" @@ -51,7 +57,7 @@ async def deregister_instance(self, instance: ModelInstance) -> bool: return await self.registry.deregister_instance(instance) async def get_all_instances( - self, model_name: str, healthy_only: bool = False + self, model_name: str = None, healthy_only: bool = False ) -> List[ModelInstance]: logging.info( f"Get all instances with {model_name}, healthy_only: {healthy_only}" @@ -94,7 +100,7 @@ async def get_all_model_instances(self) -> List[ModelInstance]: @sync_api_remote(path="/api/controller/models") def sync_get_all_instances( - self, model_name: str, healthy_only: bool = False + self, model_name: str = None, healthy_only: bool = False ) -> List[ModelInstance]: pass @@ -110,7 +116,7 @@ async def deregister_instance(self, instance: ModelInstance) -> bool: return await self.backend.deregister_instance(instance) async def get_all_instances( - self, model_name: str, healthy_only: bool = False + self, model_name: str = None, healthy_only: bool = False ) -> List[ModelInstance]: return await self.backend.get_all_instances(model_name, healthy_only) diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py index f3b85842f..c8cb2a74b 100644 --- a/pilot/model/conversation.py +++ b/pilot/model/conversation.py @@ -2,6 +2,8 @@ Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py Conversation prompt templates. + +TODO Using fastchat core package """ import dataclasses @@ -366,4 +368,21 @@ def get_conv_template(name: str) -> Conversation: ) ) +# Internlm-chat template +register_conv_template( + Conversation( + name="internlm-chat", + system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", + roles=("<|User|>", "<|Bot|>"), + messages=(), + offset=0, + sep_style=SeparatorStyle.CHATINTERN, + sep="", + sep2="", + stop_token_ids=[1, 103028], + stop_str="", + ) +) + + # TODO Support other model conversation template diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 1507a9a1f..3af082026 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -18,6 +18,7 @@ from typing import List import tempfile +from pilot.componet import ComponetType from pilot.openapi.api_view_model import ( Result, ConversationVo, @@ -352,20 +353,17 @@ async def chat_completions(dialogue: ConversationVo = Body()): async def model_types(request: Request): print(f"/controller/model/types") try: - import httpx - - async with httpx.AsyncClient() as client: - base_url = request.base_url - response = await client.get( - f"{base_url}api/controller/models?healthy_only=true", - ) types = set() - if response.status_code == 200: - models = json.loads(response.text) - for model in models: - worker_type = model["model_name"].split("@")[1] - if worker_type == "llm": - types.add(model["model_name"].split("@")[0]) + from pilot.model.cluster.controller.controller import BaseModelController + + controller = CFG.SYSTEM_APP.get_componet( + ComponetType.MODEL_CONTROLLER, BaseModelController + ) + models = await controller.get_all_instances(healthy_only=True) + for model in models: + worker_name, worker_type = model.model_name.split("@") + if worker_type == "llm": + types.add(worker_name) return Result.succ(list(types)) except Exception as e: diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 466a96c0f..80f22effe 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -247,6 +247,16 @@ def get_generate_stream_func(self, model_path: str): return generate_stream +class InternLMChatAdapter(BaseChatAdpter): + """The model adapter for internlm/internlm-chat-7b""" + + def match(self, model_path: str): + return "internlm" in model_path.lower() + + def get_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("internlm-chat") + + register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) @@ -257,6 +267,7 @@ def get_generate_stream_func(self, model_path: str): register_llm_model_chat_adapter(BaichuanChatAdapter) register_llm_model_chat_adapter(WizardLMChatAdapter) register_llm_model_chat_adapter(LlamaCppChatAdapter) +register_llm_model_chat_adapter(InternLMChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter) diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py index 745937068..b68b052fa 100644 --- a/pilot/server/componet_configs.py +++ b/pilot/server/componet_configs.py @@ -9,10 +9,12 @@ def initialize_componets(system_app: SystemApp, embedding_model_name: str): from pilot.model.cluster import worker_manager + from pilot.model.cluster.controller.controller import controller system_app.register( RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name ) + system_app.register_instance(controller) class RemoteEmbeddingFactory(EmbeddingFactory):