Skip to content

Commit

Permalink
feat(model): support InternLM (#583)
Browse files Browse the repository at this point in the history
- Close #575
- Fix the bug that the webserver cannot return model instances
  • Loading branch information
Aries-ckt authored Sep 14, 2023
2 parents 6555b67 + f304f97 commit 94c4f4a
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 21 deletions.
Binary file modified assets/wechat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 22 additions & 3 deletions pilot/componet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import logging
import asyncio

# Checking for type hints during runtime
if TYPE_CHECKING:
from fastapi import FastAPI

logger = logging.getLogger(__name__)


class LifeCycle:
"""This class defines hooks for lifecycle events of a component."""
Expand Down Expand Up @@ -37,6 +41,11 @@ async def async_before_stop(self):
pass


class ComponetType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
MODEL_CONTROLLER = "dbgpt_model_controller"


class BaseComponet(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this."""

Expand Down Expand Up @@ -80,11 +89,21 @@ def register(self, componet: Type[BaseComponet], *args, **kwargs):

def register_instance(self, instance: T):
"""Register an already initialized component."""
self.componets[instance.name] = instance
name = instance.name
if isinstance(name, ComponetType):
name = name.value
if name in self.componets:
raise RuntimeError(
f"Componse name {name} already exists: {self.componets[name]}"
)
logger.info(f"Register componet with name {name} and instance: {instance}")
self.componets[name] = instance
instance.init_app(self)

def get_componet(self, name: str, componet_type: Type[T]) -> T:
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
"""Retrieve a registered component by its name and type."""
if isinstance(name, ComponetType):
name = name.value
component = self.componets.get(name)
if not component:
raise ValueError(f"No component found with name {name}")
Expand Down
3 changes: 3 additions & 0 deletions pilot/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def get_device() -> str:
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
# https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
}

EMBEDDING_MODEL_CONFIG = {
Expand Down
24 changes: 24 additions & 0 deletions pilot/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,29 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class InternLMAdapter(BaseLLMAdaper):
"""The model adapter for internlm/internlm-chat-7b"""

def match(self, model_path: str):
return "internlm" in model_path.lower()

def loader(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
)
model = model.eval()
if "8k" in model_path.lower():
model.config.max_sequence_length = 8192
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, trust_remote_code=True, revision=revision
)
return model, tokenizer


register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
Expand All @@ -421,6 +444,7 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
register_llm_model_adapters(InternLMAdapter)
# TODO Default support vicuna, other model need to tests and Evaluate

# just for test_py, remove this later
Expand Down
16 changes: 11 additions & 5 deletions pilot/model/cluster/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from fastapi import APIRouter, FastAPI
from pilot.componet import BaseComponet, ComponetType, SystemApp
from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
Expand All @@ -14,7 +15,12 @@
)


class BaseModelController(ABC):
class BaseModelController(BaseComponet, ABC):
name = ComponetType.MODEL_CONTROLLER

def init_app(self, system_app: SystemApp):
pass

@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""Register a given model instance"""
Expand All @@ -25,7 +31,7 @@ async def deregister_instance(self, instance: ModelInstance) -> bool:

@abstractmethod
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""

Expand All @@ -51,7 +57,7 @@ async def deregister_instance(self, instance: ModelInstance) -> bool:
return await self.registry.deregister_instance(instance)

async def get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
logging.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
Expand Down Expand Up @@ -94,7 +100,7 @@ async def get_all_model_instances(self) -> List[ModelInstance]:

@sync_api_remote(path="/api/controller/models")
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
pass

Expand All @@ -110,7 +116,7 @@ async def deregister_instance(self, instance: ModelInstance) -> bool:
return await self.backend.deregister_instance(instance)

async def get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
return await self.backend.get_all_instances(model_name, healthy_only)

Expand Down
19 changes: 19 additions & 0 deletions pilot/model/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Conversation prompt templates.
TODO Using fastchat core package
"""

import dataclasses
Expand Down Expand Up @@ -366,4 +368,21 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Internlm-chat template
register_conv_template(
Conversation(
name="internlm-chat",
system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
roles=("<|User|>", "<|Bot|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.CHATINTERN,
sep="<eoh>",
sep2="<eoa>",
stop_token_ids=[1, 103028],
stop_str="<eoa>",
)
)


# TODO Support other model conversation template
24 changes: 11 additions & 13 deletions pilot/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import List
import tempfile

from pilot.componet import ComponetType
from pilot.openapi.api_view_model import (
Result,
ConversationVo,
Expand Down Expand Up @@ -352,20 +353,17 @@ async def chat_completions(dialogue: ConversationVo = Body()):
async def model_types(request: Request):
print(f"/controller/model/types")
try:
import httpx

async with httpx.AsyncClient() as client:
base_url = request.base_url
response = await client.get(
f"{base_url}api/controller/models?healthy_only=true",
)
types = set()
if response.status_code == 200:
models = json.loads(response.text)
for model in models:
worker_type = model["model_name"].split("@")[1]
if worker_type == "llm":
types.add(model["model_name"].split("@")[0])
from pilot.model.cluster.controller.controller import BaseModelController

controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
)
models = await controller.get_all_instances(healthy_only=True)
for model in models:
worker_name, worker_type = model.model_name.split("@")
if worker_type == "llm":
types.add(worker_name)
return Result.succ(list(types))

except Exception as e:
Expand Down
11 changes: 11 additions & 0 deletions pilot/server/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,16 @@ def get_generate_stream_func(self, model_path: str):
return generate_stream


class InternLMChatAdapter(BaseChatAdpter):
"""The model adapter for internlm/internlm-chat-7b"""

def match(self, model_path: str):
return "internlm" in model_path.lower()

def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("internlm-chat")


register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
Expand All @@ -257,6 +267,7 @@ def get_generate_stream_func(self, model_path: str):
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
register_llm_model_chat_adapter(LlamaCppChatAdapter)
register_llm_model_chat_adapter(InternLMChatAdapter)

# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
Expand Down
2 changes: 2 additions & 0 deletions pilot/server/componet_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

def initialize_componets(system_app: SystemApp, embedding_model_name: str):
from pilot.model.cluster import worker_manager
from pilot.model.cluster.controller.controller import controller

system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
system_app.register_instance(controller)


class RemoteEmbeddingFactory(EmbeddingFactory):
Expand Down

0 comments on commit 94c4f4a

Please sign in to comment.