feat(model): Support moonshot proxy LLM
fangyinc committed Apr 10, 2024
1 parent d55c51a commit e033bda
Showing 9 changed files with 186 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .env.template
@@ -214,6 +214,10 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
#YI_API_BASE=https://api.lingyiwanwu.com/v1
#YI_API_KEY={your-yi-api-key}

## Moonshot Proxyllm, https://platform.moonshot.cn/docs/
# MOONSHOT_MODEL_VERSION=moonshot-v1-8k
# MOONSHOT_API_BASE=https://api.moonshot.cn/v1
# MOONSHOT_API_KEY={your-moonshot-api-key}


#*******************************************************************#
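For reference, a complete Moonshot proxy setup in .env might look as follows. This is a sketch: the LLM_MODEL line assumes DB-GPT selects the active model through that existing setting, and the API key is a placeholder.

LLM_MODEL=moonshot_proxyllm
MOONSHOT_MODEL_VERSION=moonshot-v1-8k
MOONSHOT_API_BASE=https://api.moonshot.cn/v1
MOONSHOT_API_KEY=sk-your-key-here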
3 changes: 3 additions & 0 deletions .mypy.ini
@@ -94,3 +94,6 @@ ignore_missing_imports = True

[mypy-unstructured.*]
ignore_missing_imports = True

[mypy-rich.*]
ignore_missing_imports = True
10 changes: 10 additions & 0 deletions dbgpt/_private/config.py
@@ -118,6 +118,16 @@ def __init__(self) -> None:
os.environ["yi_proxyllm_proxy_api_base"] = os.getenv(
"YI_API_BASE", "https://api.lingyiwanwu.com/v1"
)
# Moonshot proxy
self.moonshot_proxy_api_key = os.getenv("MOONSHOT_API_KEY")
if self.moonshot_proxy_api_key:
os.environ["moonshot_proxyllm_proxy_api_key"] = self.moonshot_proxy_api_key
os.environ["moonshot_proxyllm_proxyllm_backend"] = os.getenv(
"MOONSHOT_MODEL_VERSION", "moonshot-v1-8k"
)
os.environ["moonshot_proxyllm_api_base"] = os.getenv(
"MOONSHOT_API_BASE", "https://api.moonshot.cn/v1"
)

self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

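A quick sketch of the resulting mapping (the key value is a hypothetical placeholder; Config is the class whose __init__ is extended above):

import os

os.environ["MOONSHOT_API_KEY"] = "sk-placeholder"  # hypothetical key for illustration

from dbgpt._private.config import Config

Config()  # __init__ re-exports the user-facing variables under model-scoped names
assert os.environ["moonshot_proxyllm_proxy_api_key"] == "sk-placeholder"
assert os.environ["moonshot_proxyllm_proxyllm_backend"] == "moonshot-v1-8k"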
2 changes: 2 additions & 0 deletions dbgpt/configs/model_config.py
@@ -67,6 +67,8 @@ def get_device() -> str:
"spark_proxyllm": "spark_proxyllm",
# https://platform.lingyiwanwu.com/docs/
"yi_proxyllm": "yi_proxyllm",
# https://platform.moonshot.cn/docs/
"moonshot_proxyllm": "moonshot_proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
28 changes: 28 additions & 0 deletions dbgpt/core/interface/llm.py
@@ -857,3 +857,31 @@ async def get_model_metadata(self, model: str) -> ModelMetadata:
if not model_metadata:
raise ValueError(f"Model {model} not found")
return model_metadata

def __call__(self, *args, **kwargs) -> ModelOutput:
"""Return the model output.
Call the LLM client to generate the response for the given message.
Please do not use this method in the production environment, it is only used
for debugging.
"""
from dbgpt.util import get_or_create_event_loop

messages = kwargs.get("messages")
model = kwargs.get("model")
if messages:
del kwargs["messages"]
model_messages = ModelMessage.from_openai_messages(messages)
else:
model_messages = [ModelMessage.build_human_message(args[0])]
if not model:
if hasattr(self, "default_model"):
model = getattr(self, "default_model")
else:
raise ValueError("The default model is not set")
if "model" in kwargs:
del kwargs["model"]
req = ModelRequest.build_request(model, model_messages, **kwargs)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.generate(req))
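A usage sketch for this debugging helper, assuming a concrete client such as the MoonshotLLMClient added later in this commit (the key is a placeholder; not for production):

from dbgpt.model.proxy import MoonshotLLMClient

client = MoonshotLLMClient(api_key="sk-placeholder")
# A positional string is wrapped as a single human message.
print(client("Hello, Moonshot!"))
# OpenAI-style messages with an explicit model name also work.
print(client(messages=[{"role": "user", "content": "Ping"}], model="moonshot-v1-8k"))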
26 changes: 26 additions & 0 deletions dbgpt/model/adapter/proxy_adapter.py
@@ -252,6 +252,31 @@ def get_async_generate_stream_function(self, model, model_path: str):
return yi_generate_stream


class MoonshotProxyLLMModelAdapter(ProxyLLMModelAdapter):
"""Moonshot proxy LLM model adapter.
See Also: `Moonshot Documentation <https://platform.moonshot.cn/docs/>`_
"""

def support_async(self) -> bool:
return True

def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path in ["moonshot_proxyllm"]

def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient

return MoonshotLLMClient

def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.moonshot import moonshot_generate_stream

return moonshot_generate_stream


register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
@@ -261,3 +286,4 @@ def get_async_generate_stream_function(self, model, model_path: str):
register_model_adapter(BardProxyLLMModelAdapter)
register_model_adapter(BaichuanProxyLLMModelAdapter)
register_model_adapter(YiProxyLLMModelAdapter)
register_model_adapter(MoonshotProxyLLMModelAdapter)
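The new adapter is selected purely by model name and defers to MoonshotLLMClient for I/O; a minimal sketch of the matching behavior:

adapter = MoonshotProxyLLMModelAdapter()
assert adapter.do_match("moonshot_proxyllm")
assert adapter.support_async()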
2 changes: 2 additions & 0 deletions dbgpt/model/proxy/__init__.py
@@ -10,6 +10,7 @@ def __lazy_import(name):
"WenxinLLMClient": "dbgpt.model.proxy.llms.wenxin",
"ZhipuLLMClient": "dbgpt.model.proxy.llms.zhipu",
"YiLLMClient": "dbgpt.model.proxy.llms.yi",
"MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
}

if name in module_path:
@@ -31,4 +32,5 @@ def __getattr__(name):
"WenxinLLMClient",
"SparkLLMClient",
"YiLLMClient",
"MoonshotLLMClient",
]
10 changes: 10 additions & 0 deletions dbgpt/model/proxy/llms/chatgpt.py
@@ -99,6 +99,8 @@ def __init__(
) from exc
self._openai_version = metadata.version("openai")
self._openai_less_then_v1 = not self._openai_version >= "1.0.0"
self.check_sdk_version(self._openai_version)

self._init_params = OpenAIParameters(
api_type=api_type,
api_base=api_base,
@@ -141,6 +143,14 @@ def new_client(
full_url=model_params.proxy_server_url,
)

def check_sdk_version(self, version: str) -> None:
"""Check the sdk version of the client.
Raises:
ValueError: If check failed.
"""
pass

@property
def client(self) -> ClientType:
if self._openai_less_then_v1:
101 changes: 101 additions & 0 deletions dbgpt/model/proxy/llms/moonshot.py
@@ -0,0 +1,101 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast

from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

from .chatgpt import OpenAILLMClient

if TYPE_CHECKING:
from httpx._types import ProxiesTypes
from openai import AsyncAzureOpenAI, AsyncOpenAI

ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]

_MOONSHOT_DEFAULT_MODEL = "moonshot-v1-8k"


async def moonshot_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client)
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
async for r in client.generate_stream(request):
yield r


class MoonshotLLMClient(OpenAILLMClient):
"""Moonshot LLM Client.
Moonshot's API is compatible with OpenAI's API, so we inherit from OpenAILLMClient.
"""

def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
model: Optional[str] = _MOONSHOT_DEFAULT_MODEL,
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "moonshot_proxyllm",
context_length: Optional[int] = None,
openai_client: Optional["ClientType"] = None,
openai_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
api_base = (
api_base or os.getenv("MOONSHOT_API_BASE") or "https://api.moonshot.cn/v1"
)
api_key = api_key or os.getenv("MOONSHOT_API_KEY")
model = model or _MOONSHOT_DEFAULT_MODEL
if not context_length:
if "128k" in model:
context_length = 1024 * 128
elif "32k" in model:
context_length = 1024 * 32
else:
# 8k
context_length = 1024 * 8

if not api_key:
raise ValueError(
"Moonshot API key is required, please set 'MOONSHOT_API_KEY' in "
"environment variable or pass it to the client."
)
super().__init__(
api_key=api_key,
api_base=api_base,
api_type=api_type,
api_version=api_version,
model=model,
proxies=proxies,
timeout=timeout,
model_alias=model_alias,
context_length=context_length,
openai_client=openai_client,
openai_kwargs=openai_kwargs,
**kwargs,
)

def check_sdk_version(self, version: str) -> None:
if not version >= "1.0":
raise ValueError(
"Moonshot API requires openai>=1.0, please upgrade it by "
"`pip install --upgrade 'openai>=1.0'`"
)

@property
def default_model(self) -> str:
model = self._model
if not model:
model = _MOONSHOT_DEFAULT_MODEL
return model
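A minimal async streaming sketch for the new client (assumes MOONSHOT_API_KEY is exported; ModelMessage and the output's text field are used as elsewhere in this diff):

import asyncio

from dbgpt.core import ModelMessage, ModelRequest
from dbgpt.model.proxy import MoonshotLLMClient


async def main():
    # Reads MOONSHOT_API_KEY / MOONSHOT_API_BASE from the environment.
    client = MoonshotLLMClient()
    messages = [ModelMessage.build_human_message("Hello, Kimi!")]
    request = ModelRequest.build_request(client.default_model, messages=messages)
    async for output in client.generate_stream(request):
        print(output.text)


asyncio.run(main())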
