feat(model): Support claude proxy models (#2155)
fangyinc authored Nov 26, 2024
1 parent 9d8673a commit 61509dc
Showing 20 changed files with 508 additions and 157 deletions.
9 changes: 9 additions & 0 deletions dbgpt/_private/config.py
@@ -131,6 +131,15 @@ def __init__(self) -> None:
os.environ["deepseek_proxyllm_api_base"] = os.getenv(
"DEEPSEEK_API_BASE", "https://api.deepseek.com/v1"
)
self.claude_proxy_api_key = os.getenv("ANTHROPIC_API_KEY")
if self.claude_proxy_api_key:
os.environ["claude_proxyllm_proxy_api_key"] = self.claude_proxy_api_key
os.environ["claude_proxyllm_proxyllm_backend"] = os.getenv(
"ANTHROPIC_MODEL_VERSION", "claude-3-5-sonnet-20241022"
)
os.environ["claude_proxyllm_api_base"] = os.getenv(
"ANTHROPIC_BASE_URL", "https://api.anthropic.com"
)

self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

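
Only ANTHROPIC_API_KEY is required for the Claude branch above to take effect; the other two variables fall back to the defaults shown in the diff. A minimal sketch of the expected environment, assuming the variables are exported (for example via DB-GPT's .env file) before Config() is instantiated:

import os

# Required: presence of the key is what enables the claude_proxyllm settings above.
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder value
# Optional overrides; the values below match the fallbacks in config.py.
os.environ["ANTHROPIC_MODEL_VERSION"] = "claude-3-5-sonnet-20241022"
os.environ["ANTHROPIC_BASE_URL"] = "https://api.anthropic.com"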
76 changes: 71 additions & 5 deletions dbgpt/core/interface/llm.py
@@ -6,7 +6,7 @@
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Tuple, Union

from cachetools import TTLCache

@@ -394,6 +394,29 @@ def messages_to_string(self) -> str:
"""
return ModelMessage.messages_to_string(self.get_messages())

def split_messages(self) -> Tuple[List[Dict[str, Any]], List[str]]:
"""Split the messages.
Returns:
Tuple[List[Dict[str, Any]], List[str]]: The common messages and system
messages.
"""
messages = self.get_messages()
common_messages = []
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
common_messages.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
common_messages.append(
{"role": "assistant", "content": message.content}
)
else:
pass
return common_messages, system_messages
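
To illustrate what the new helper returns, here is a hypothetical sketch; the ModelMessage and ModelMessageRoleType exports from dbgpt.core are assumptions, only split_messages itself comes from this diff:

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest  # assumed import path

request = ModelRequest.build_request(
    "claude_proxyllm",
    messages=[
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello!"),
    ],
)
common_messages, system_messages = request.split_messages()
# common_messages -> [{"role": "user", "content": "Hello!"}]
# system_messages -> ["You are a helpful assistant."]
# This split is what lets Anthropic-style APIs receive the system prompt
# separately from the user/assistant turns.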


@dataclass
class ModelExtraMedata(BaseParameters):
@@ -861,30 +884,73 @@ async def get_model_metadata(self, model: str) -> ModelMetadata:
raise ValueError(f"Model {model} not found")
return model_metadata

def __call__(self, *args, **kwargs) -> ModelOutput:
def __call__(
self, *args, **kwargs
) -> Coroutine[Any, Any, ModelOutput] | ModelOutput:
"""Return the model output.
Call the LLM client to generate the response for the given message.
Please do not use this method in the production environment, it is only used
for debugging.
"""
import asyncio

from dbgpt.util import get_or_create_event_loop

try:
# Check if we are in an event loop
loop = asyncio.get_running_loop()
# If we are in an event loop, use async call
if loop.is_running():
# Because we are in an async environment, but this is a sync method,
# we need to return a coroutine object for the caller to use await
return self.async_call(*args, **kwargs)
else:
loop = get_or_create_event_loop()
return loop.run_until_complete(self.async_call(*args, **kwargs))
except RuntimeError:
# If we are not in an event loop, use sync call
loop = get_or_create_event_loop()
return loop.run_until_complete(self.async_call(*args, **kwargs))

async def async_call(self, *args, **kwargs) -> ModelOutput:
"""Return the model output asynchronously.
Please do not use this method in the production environment, it is only used
for debugging.
"""
req = self._build_call_request(*args, **kwargs)
return await self.generate(req)

async def async_call_stream(self, *args, **kwargs) -> AsyncIterator[ModelOutput]:
"""Return the model output stream asynchronously.
Please do not use this method in the production environment, it is only used
for debugging.
"""
req = self._build_call_request(*args, **kwargs)
async for output in self.generate_stream(req): # type: ignore
yield output

def _build_call_request(self, *args, **kwargs) -> ModelRequest:
"""Build the model request for the call method."""
messages = kwargs.get("messages")
model = kwargs.get("model")

if messages:
del kwargs["messages"]
model_messages = ModelMessage.from_openai_messages(messages)
else:
model_messages = [ModelMessage.build_human_message(args[0])]

if not model:
if hasattr(self, "default_model"):
model = getattr(self, "default_model")
else:
raise ValueError("The default model is not set")

if "model" in kwargs:
del kwargs["model"]
req = ModelRequest.build_request(model, model_messages, **kwargs)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.generate(req))

return ModelRequest.build_request(model, model_messages, **kwargs)
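
A quick usage sketch of the reworked call path (for debugging only, as the docstrings note). The ClaudeLLMClient constructor arguments are assumptions; only the __call__, async_call, and async_call_stream semantics come from this diff:

from dbgpt.model.proxy import ClaudeLLMClient  # re-exported via the lazy-import table added below

# Hypothetical construction; check the client class for its actual parameters.
client = ClaudeLLMClient(model="claude-3-5-sonnet-20241022", api_key="sk-ant-...")

# Plain synchronous code: __call__ drives its own event loop and returns a ModelOutput.
output = client("Hello, Claude")

# Inside a running event loop __call__ returns a coroutine instead, so await it,
# or use async_call / async_call_stream directly.
async def demo():
    output = await client.async_call("Hello, Claude")
    async for chunk in client.async_call_stream("Hello, Claude"):
        print(chunk.text)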
21 changes: 21 additions & 0 deletions dbgpt/model/adapter/proxy_adapter.py
@@ -97,6 +97,26 @@ def get_async_generate_stream_function(self, model, model_path: str):
return chatgpt_generate_stream


class ClaudeProxyLLMModelAdapter(ProxyLLMModelAdapter):
def support_async(self) -> bool:
return True

def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "claude_proxyllm"

def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.claude import ClaudeLLMClient

return ClaudeLLMClient

def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.claude import claude_generate_stream

return claude_generate_stream
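
The adapter is selected purely by the lowered model name, and because support_async() returns True the async claude_generate_stream function handles streaming. A small sketch of that contract (the params value is a placeholder ProxyModelParameters instance):

adapter = ClaudeProxyLLMModelAdapter()
assert adapter.do_match("claude_proxyllm")  # only this exact name matches
assert adapter.support_async()              # so the async stream function is used
client_cls = adapter.get_llm_client_class(params)  # lazily imports and returns ClaudeLLMClient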


class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "tongyi_proxyllm"
@@ -320,6 +340,7 @@ def get_async_generate_stream_function(self, model, model_path: str):


register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(ClaudeProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(OllamaLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
6 changes: 4 additions & 2 deletions dbgpt/model/cluster/manager_base.py
@@ -3,7 +3,7 @@
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Iterator, List, Optional
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelMetadata, ModelOutput
@@ -113,7 +113,9 @@ async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
"""Shutdown model instance"""

@abstractmethod
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async def generate_stream(
self, params: Dict, **kwargs
) -> AsyncIterator[ModelOutput]:
"""Generate stream result, chat scene"""

@abstractmethod
8 changes: 5 additions & 3 deletions dbgpt/model/cluster/worker/manager.py
@@ -9,7 +9,7 @@
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Iterator
from typing import AsyncIterator, Awaitable, Callable, Iterator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
@@ -327,7 +327,7 @@ def _sync_get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunDa

async def generate_stream(
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
) -> AsyncIterator[ModelOutput]:
"""Generate stream result, chat scene"""
with root_tracer.start_span(
"WorkerManager.generate_stream", params.get("span_id")
@@ -693,7 +693,9 @@ def sync_select_one_instance(
worker_type, model_name, healthy_only
)

async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async def generate_stream(
self, params: Dict, **kwargs
) -> AsyncIterator[ModelOutput]:
async for output in self.worker_manager.generate_stream(params, **kwargs):
yield output
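
Because generate_stream is now declared (and implemented) as an async generator, callers consume it with async for. A minimal consumer sketch, assuming a ready worker_manager and a prepared params dict:

from typing import Dict

async def stream_chat(worker_manager, params: Dict):
    # Each yielded item is a ModelOutput chunk from the selected worker.
    async for output in worker_manager.generate_stream(params):
        print(output.text, end="", flush=True)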

17 changes: 17 additions & 0 deletions dbgpt/model/proxy/__init__.py
@@ -1,9 +1,25 @@
"""Proxy models."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
from dbgpt.model.proxy.llms.spark import SparkLLMClient
from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient
from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient
from dbgpt.model.proxy.llms.yi import YiLLMClient
from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient


def __lazy_import(name):
module_path = {
"OpenAILLMClient": "dbgpt.model.proxy.llms.chatgpt",
"ClaudeLLMClient": "dbgpt.model.proxy.llms.claude",
"GeminiLLMClient": "dbgpt.model.proxy.llms.gemini",
"SparkLLMClient": "dbgpt.model.proxy.llms.spark",
"TongyiLLMClient": "dbgpt.model.proxy.llms.tongyi",
@@ -28,6 +44,7 @@ def __getattr__(name):

__all__ = [
"OpenAILLMClient",
"ClaudeLLMClient",
"GeminiLLMClient",
"TongyiLLMClient",
"ZhipuLLMClient",
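
For reference, __lazy_import and the module-level __getattr__ follow the standard PEP 562 lazy-attribute pattern; a minimal sketch of that mechanism (the error-message wording is an assumption):

from importlib import import_module

def __lazy_import(name):
    module_path = {
        "ClaudeLLMClient": "dbgpt.model.proxy.llms.claude",
        # ... one entry per client, as in the table above
    }
    if name in module_path:
        module = import_module(module_path[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __getattr__(name):
    return __lazy_import(name)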
35 changes: 34 additions & 1 deletion dbgpt/model/proxy/base.py
@@ -34,6 +34,25 @@ def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
List[int]: token count, -1 if failed
"""

def support_async(self) -> bool:
"""Check if the tokenizer supports asynchronous counting token.
Returns:
bool: True if supports, False otherwise
"""
return False

async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
"""Count token of given prompts asynchronously.
Args:
model_name (str): model name
prompts (List[str]): prompts to count token
Returns:
List[int]: token count, -1 if failed
"""
raise NotImplementedError()


class TiktokenProxyTokenizer(ProxyTokenizer):
def __init__(self):
@@ -92,7 +111,7 @@ def __init__(
self.model_names = model_names
self.context_length = context_length
self.executor = executor or ThreadPoolExecutor()
self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()
self._proxy_tokenizer = proxy_tokenizer

def __getstate__(self):
"""Customize the serialization of the object"""
@@ -105,6 +124,17 @@ def __setstate__(self, state):
self.__dict__.update(state)
self.executor = ThreadPoolExecutor()

@property
def proxy_tokenizer(self) -> ProxyTokenizer:
"""Get proxy tokenizer
Returns:
ProxyTokenizer: proxy tokenizer
"""
if not self._proxy_tokenizer:
self._proxy_tokenizer = TiktokenProxyTokenizer()
return self._proxy_tokenizer

@classmethod
@abstractmethod
def new_client(
@@ -257,6 +287,9 @@ async def count_token(self, model: str, prompt: str) -> int:
Returns:
int: token count, -1 if failed
"""
if self.proxy_tokenizer.support_async():
cnts = await self.proxy_tokenizer.count_token_async(model, [prompt])
return cnts[0]
counts = await blocking_func_to_async(
self.executor, self.proxy_tokenizer.count_token, model, [prompt]
)
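
To illustrate the new hook: a tokenizer opts into the asynchronous path by overriding support_async() and count_token_async(), after which LLMClient.count_token awaits it directly instead of offloading the sync method to the thread pool. A hypothetical subclass; the characters-per-token heuristic is purely illustrative:

from typing import List

class ApproxAsyncTokenizer(ProxyTokenizer):
    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        # Rough synchronous estimate: ~4 characters per token.
        return [max(1, len(p) // 4) for p in prompts]

    def support_async(self) -> bool:
        return True

    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
        # A real implementation could call a provider's async token-counting API here;
        # this sketch just reuses the synchronous estimate.
        return self.count_token(model_name, prompts)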
22 changes: 5 additions & 17 deletions dbgpt/model/proxy/llms/chatgpt.py
@@ -5,17 +5,11 @@
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union

from dbgpt.core import (
MessageConverter,
ModelMetadata,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
from dbgpt.util.i18n_utils import _

@@ -32,15 +26,7 @@ async def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: OpenAILLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r

@@ -191,6 +177,8 @@ def _build_request(
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
if request.top_p:
payload["top_p"] = request.top_p
return payload

async def generate(