Skip to content

Commit

Permalink
feat: Support openchat-3.5-1210
Browse files — Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Dec 21, 2023
1 parent 2f14fbf commit 09f3c7f
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
4 changes: 3 additions & 1 deletion dbgpt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def get_device() -> str:
# https://huggingface.co/microsoft/Orca-2-13b
"orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
# https://huggingface.co/openchat/openchat_3.5
"openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
"openchat-3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
# https://huggingface.co/openchat/openchat-3.5-1210
"openchat-3.5-1210": os.path.join(MODEL_PATH, "openchat-3.5-1210"),
# https://huggingface.co/hfl/chinese-alpaca-2-7b
"chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
# https://huggingface.co/hfl/chinese-alpaca-2-13b
Expand Down
4 changes: 3 additions & 1 deletion dbgpt/model/adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def transform_model_messages(
Returns:
List[Dict[str, str]]: The transformed model messages
"""
logger.info(f"support_system_message: {self.support_system_message}")
if not self.support_system_message:
return self._transform_to_no_system_messages(messages)
else:
Expand Down Expand Up @@ -432,4 +433,5 @@ def get_model_adapter(
new_adapter.model_path = model_path
if conv_factory:
new_adapter.conv_factory = conv_factory
return adapter
return new_adapter
return None
10 changes: 8 additions & 2 deletions dbgpt/model/adapter/hf_adapter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, List, Any
import logging

from dbgpt.core import ModelMessage
from dbgpt.model.base import ModelType
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter

logger = logging.getLogger(__name__)

class NewHFChatModelAdapter(LLMModelAdapter):

class NewHFChatModelAdapter(LLMModelAdapter, ABC):
"""Model adapter for new huggingface chat models
See https://huggingface.co/docs/transformers/main/en/chat_templating
Expand All @@ -31,6 +35,7 @@ def match(
model_path = model_path.lower() if model_path else None
return self.do_match(model_name) or self.do_match(model_path)

@abstractmethod
def do_match(self, lower_model_name_or_path: Optional[str] = None):
    """Decide whether this adapter handles the given model.

    Called by ``match`` (see the hunk above) twice per check — once with the
    lower-cased model name and once with the lower-cased model path.

    Args:
        lower_model_name_or_path (Optional[str]): Lower-cased model name or
            model path to test against this adapter; may be ``None``.

    Raises:
        NotImplementedError: always — concrete subclasses must override
            (enforced by ``@abstractmethod`` on the ABC).
    """
    raise NotImplementedError()

Expand Down Expand Up @@ -89,7 +94,8 @@ def get_str_prompt(
raise ValueError("tokenizer is is None")
tokenizer: AutoTokenizer = tokenizer

messages = ModelMessage.to_openai_messages(messages)
messages = self.transform_model_messages(messages)
logger.debug(f"The messages after transform: \n{messages}")
str_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
Expand Down
1 change: 1 addition & 0 deletions dbgpt/model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def _try_load_default_quantization_model(
return _handle_model_and_tokenizer(
model, tokenizer, device, num_gpus, model_params
)
return None, None
except Exception as e:
logger.warning(
f"Load default quantization model {model_params.model_name} failed, error: {str(e)}"
Expand Down

0 comments on commit 09f3c7f

Please sign in to comment.