From e033bdaa150b92bf5235e0528169f76b28fccc2c Mon Sep 17 00:00:00 2001
From: Fangyin Cheng
Date: Wed, 10 Apr 2024 22:50:02 +0800
Subject: [PATCH] feat(model): Support moonshot proxy LLM

---
 .env.template                        |   4 ++
 .mypy.ini                            |   3 +
 dbgpt/_private/config.py             |  10 +++
 dbgpt/configs/model_config.py        |   2 +
 dbgpt/core/interface/llm.py          |  28 ++++++++
 dbgpt/model/adapter/proxy_adapter.py |  26 +++++++
 dbgpt/model/proxy/__init__.py        |   2 +
 dbgpt/model/proxy/llms/chatgpt.py    |  10 +++
 dbgpt/model/proxy/llms/moonshot.py   | 101 +++++++++++++++++++++++++++
 9 files changed, 186 insertions(+)
 create mode 100644 dbgpt/model/proxy/llms/moonshot.py

diff --git a/.env.template b/.env.template
index f57e0c372..27e6c6b7c 100644
--- a/.env.template
+++ b/.env.template
@@ -214,6 +214,10 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
 #YI_API_BASE=https://api.lingyiwanwu.com/v1
 #YI_API_KEY={your-yi-api-key}
 
+## Moonshot Proxyllm, https://platform.moonshot.cn/docs/
+# MOONSHOT_MODEL_VERSION=moonshot-v1-8k
+# MOONSHOT_API_BASE=https://api.moonshot.cn/v1
+# MOONSHOT_API_KEY={your-moonshot-api-key}
 
 
 #*******************************************************************#
diff --git a/.mypy.ini b/.mypy.ini
index 5bfab1fce..d9dfd8233 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -94,3 +94,6 @@ ignore_missing_imports = True
 
 [mypy-unstructured.*]
 ignore_missing_imports = True
+
+[mypy-rich.*]
+ignore_missing_imports = True
diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index 280cf873b..6e41f565c 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -118,6 +118,16 @@ def __init__(self) -> None:
             os.environ["yi_proxyllm_proxy_api_base"] = os.getenv(
                 "YI_API_BASE", "https://api.lingyiwanwu.com/v1"
             )
+        # Moonshot proxy
+        self.moonshot_proxy_api_key = os.getenv("MOONSHOT_API_KEY")
+        if self.moonshot_proxy_api_key:
+            os.environ["moonshot_proxyllm_proxy_api_key"] = self.moonshot_proxy_api_key
+            os.environ["moonshot_proxyllm_proxyllm_backend"] = os.getenv(
+                "MOONSHOT_MODEL_VERSION", "moonshot-v1-8k"
+            )
+            os.environ["moonshot_proxyllm_proxy_api_base"] = os.getenv(
+                "MOONSHOT_API_BASE", "https://api.moonshot.cn/v1"
+            )
 
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
 
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index ea8af4e63..3ff9a01a5 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -67,6 +67,8 @@ def get_device() -> str:
     "spark_proxyllm": "spark_proxyllm",
     # https://platform.lingyiwanwu.com/docs/
     "yi_proxyllm": "yi_proxyllm",
+    # https://platform.moonshot.cn/docs/
+    "moonshot_proxyllm": "moonshot_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py
index d87482af2..4d26ddb84 100644
--- a/dbgpt/core/interface/llm.py
+++ b/dbgpt/core/interface/llm.py
@@ -857,3 +857,31 @@ async def get_model_metadata(self, model: str) -> ModelMetadata:
         if not model_metadata:
             raise ValueError(f"Model {model} not found")
         return model_metadata
+
+    def __call__(self, *args, **kwargs) -> ModelOutput:
+        """Return the model output.
+
+        Call the LLM client to generate the response for the given message.
+
+        Please do not use this method in the production environment, it is only used
+        for debugging.
+        """
+        from dbgpt.util import get_or_create_event_loop
+
+        # Pop both keys so they are not forwarded to build_request a second time.
+        messages = kwargs.pop("messages", None)
+        model = kwargs.pop("model", None)
+        if messages:
+            model_messages = ModelMessage.from_openai_messages(messages)
+        elif args:
+            model_messages = [ModelMessage.build_human_message(args[0])]
+        else:
+            raise ValueError("Either 'messages' or a prompt string is required")
+        if not model:
+            if hasattr(self, "default_model"):
+                model = getattr(self, "default_model")
+            else:
+                raise ValueError("The default model is not set")
+        req = ModelRequest.build_request(model, model_messages, **kwargs)
+        loop = get_or_create_event_loop()
+        return loop.run_until_complete(self.generate(req))
diff --git a/dbgpt/model/adapter/proxy_adapter.py b/dbgpt/model/adapter/proxy_adapter.py
index 3da91eb0d..42c4480f6 100644
--- a/dbgpt/model/adapter/proxy_adapter.py
+++ b/dbgpt/model/adapter/proxy_adapter.py
@@ -252,6 +252,31 @@ def get_async_generate_stream_function(self, model, model_path: str):
         return yi_generate_stream
 
 
+class MoonshotProxyLLMModelAdapter(ProxyLLMModelAdapter):
+    """Moonshot proxy LLM model adapter.
+
+    See Also: `Moonshot Documentation <https://platform.moonshot.cn/docs/>`_
+    """
+
+    def support_async(self) -> bool:
+        return True
+
+    def do_match(self, lower_model_name_or_path: Optional[str] = None):
+        return lower_model_name_or_path in ["moonshot_proxyllm"]
+
+    def get_llm_client_class(
+        self, params: ProxyModelParameters
+    ) -> Type[ProxyLLMClient]:
+        from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
+
+        return MoonshotLLMClient
+
+    def get_async_generate_stream_function(self, model, model_path: str):
+        from dbgpt.model.proxy.llms.moonshot import moonshot_generate_stream
+
+        return moonshot_generate_stream
+
+
 register_model_adapter(OpenAIProxyLLMModelAdapter)
 register_model_adapter(TongyiProxyLLMModelAdapter)
 register_model_adapter(ZhipuProxyLLMModelAdapter)
@@ -261,3 +286,4 @@ def get_async_generate_stream_function(self, model, model_path: str):
 register_model_adapter(BardProxyLLMModelAdapter)
 register_model_adapter(BaichuanProxyLLMModelAdapter)
 register_model_adapter(YiProxyLLMModelAdapter)
+register_model_adapter(MoonshotProxyLLMModelAdapter)
diff --git a/dbgpt/model/proxy/__init__.py b/dbgpt/model/proxy/__init__.py
index 2412658a1..831456fbd 100644
--- a/dbgpt/model/proxy/__init__.py
+++ b/dbgpt/model/proxy/__init__.py
@@ -10,6 +10,7 @@ def __lazy_import(name):
         "WenxinLLMClient": "dbgpt.model.proxy.llms.wenxin",
         "ZhipuLLMClient": "dbgpt.model.proxy.llms.zhipu",
         "YiLLMClient": "dbgpt.model.proxy.llms.yi",
+        "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
     }
 
     if name in module_path:
@@ -31,4 +32,5 @@ def __getattr__(name):
     "WenxinLLMClient",
     "SparkLLMClient",
     "YiLLMClient",
+    "MoonshotLLMClient",
 ]
diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py
index 515b928d6..604ee83c7 100755
--- a/dbgpt/model/proxy/llms/chatgpt.py
+++ b/dbgpt/model/proxy/llms/chatgpt.py
@@ -99,6 +99,8 @@ def __init__(
             ) from exc
         self._openai_version = metadata.version("openai")
         self._openai_less_then_v1 = not self._openai_version >= "1.0.0"
+        self.check_sdk_version(self._openai_version)
+
         self._init_params = OpenAIParameters(
             api_type=api_type,
             api_base=api_base,
@@ -141,6 +143,14 @@ def new_client(
             full_url=model_params.proxy_server_url,
         )
 
+    def check_sdk_version(self, version: str) -> None:
+        """Check the sdk version of the client.
+
+        Raises:
+            ValueError: If check failed.
+        """
+        pass
+
     @property
     def client(self) -> ClientType:
         if self._openai_less_then_v1:
diff --git a/dbgpt/model/proxy/llms/moonshot.py b/dbgpt/model/proxy/llms/moonshot.py
new file mode 100644
index 000000000..e4eac390a
--- /dev/null
+++ b/dbgpt/model/proxy/llms/moonshot.py
@@ -0,0 +1,101 @@
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
+
+from dbgpt.core import ModelRequest, ModelRequestContext
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+
+from .chatgpt import OpenAILLMClient
+
+if TYPE_CHECKING:
+    from httpx._types import ProxiesTypes
+    from openai import AsyncAzureOpenAI, AsyncOpenAI
+
+    ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
+
+_MOONSHOT_DEFAULT_MODEL = "moonshot-v1-8k"
+
+
+async def moonshot_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client)
+    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
+    request = ModelRequest.build_request(
+        client.default_model,
+        messages=params["messages"],
+        temperature=params.get("temperature"),
+        context=context,
+        max_new_tokens=params.get("max_new_tokens"),
+    )
+    async for r in client.generate_stream(request):
+        yield r
+
+
+class MoonshotLLMClient(OpenAILLMClient):
+    """Moonshot LLM Client.
+
+    Moonshot's API is compatible with OpenAI's API, so we inherit from OpenAILLMClient.
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_type: Optional[str] = None,
+        api_version: Optional[str] = None,
+        model: Optional[str] = _MOONSHOT_DEFAULT_MODEL,
+        proxies: Optional["ProxiesTypes"] = None,
+        timeout: Optional[int] = 240,
+        model_alias: Optional[str] = "moonshot_proxyllm",
+        context_length: Optional[int] = None,
+        openai_client: Optional["ClientType"] = None,
+        openai_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        api_base = (
+            api_base or os.getenv("MOONSHOT_API_BASE") or "https://api.moonshot.cn/v1"
+        )
+        api_key = api_key or os.getenv("MOONSHOT_API_KEY")
+        model = model or _MOONSHOT_DEFAULT_MODEL
+        if not context_length:
+            if "128k" in model:
+                context_length = 1024 * 128
+            elif "32k" in model:
+                context_length = 1024 * 32
+            else:
+                # 8k
+                context_length = 1024 * 8
+
+        if not api_key:
+            raise ValueError(
+                "Moonshot API key is required, please set 'MOONSHOT_API_KEY' in "
+                "environment variable or pass it to the client."
+            )
+        super().__init__(
+            api_key=api_key,
+            api_base=api_base,
+            api_type=api_type,
+            api_version=api_version,
+            model=model,
+            proxies=proxies,
+            timeout=timeout,
+            model_alias=model_alias,
+            context_length=context_length,
+            openai_client=openai_client,
+            openai_kwargs=openai_kwargs,
+            **kwargs,
+        )
+
+    def check_sdk_version(self, version: str) -> None:
+        if not version >= "1.0":
+            raise ValueError(
+                "Moonshot API requires openai>=1.0, please upgrade it by "
+                "`pip install --upgrade 'openai>=1.0'`"
+            )
+
+    @property
+    def default_model(self) -> str:
+        model = self._model
+        if not model:
+            model = _MOONSHOT_DEFAULT_MODEL
+        return model