Skip to content

Commit

Permalink
feat(model): Support internlm2.5 models (#1735)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Jul 18, 2024
1 parent 083becd commit d389fdd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dbgpt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def get_device() -> str:
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
"internlm2_5-7b-chat": os.path.join(MODEL_PATH, "internlm2_5-7b-chat"),
"internlm2_5-7b-chat-1m": os.path.join(MODEL_PATH, "internlm2_5-7b-chat-1m"),
"codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
"codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
"codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
Expand Down
21 changes: 21 additions & 0 deletions dbgpt/model/adapter/hf_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,26 @@ def load(self, model_path: str, from_pretrained_kwargs: dict):
return super().load(model_path, from_pretrained_kwargs)


class Internlm2Adapter(NewHFChatModelAdapter):
"""
https://huggingface.co/internlm/internlm2_5-7b-chat
"""

def do_match(self, lower_model_name_or_path: Optional[str] = None):
return (
lower_model_name_or_path
and "internlm2" in lower_model_name_or_path
and "chat" in lower_model_name_or_path
)

def load(self, model_path: str, from_pretrained_kwargs: dict):
if not from_pretrained_kwargs:
from_pretrained_kwargs = {}
if "trust_remote_code" not in from_pretrained_kwargs:
from_pretrained_kwargs["trust_remote_code"] = True
return super().load(model_path, from_pretrained_kwargs)


# The following code is used to register the model adapter
# The last registered model adapter is matched first
register_model_adapter(YiAdapter)
Expand All @@ -602,3 +622,4 @@ def load(self, model_path: str, from_pretrained_kwargs: dict):
register_model_adapter(GLM4Adapter)
register_model_adapter(Codegeex4Adapter)
register_model_adapter(Qwen2Adapter)
register_model_adapter(Internlm2Adapter)

0 comments on commit d389fdd

Please sign in to comment.