Skip to content

Commit

Permalink
feat(model): Proxy model support count token (#996)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Dec 29, 2023
1 parent ba0599e commit 0cdc77a
Show file tree
Hide file tree
Showing 16 changed files with 365 additions and 247 deletions.
10 changes: 8 additions & 2 deletions dbgpt/model/cluster/worker/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def generate(self, params: Dict) -> ModelOutput:
return output

def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer)
return _try_to_count_token(prompt, self.tokenizer, self.model)

async def async_count_token(self, prompt: str) -> int:
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
Expand Down Expand Up @@ -454,19 +454,25 @@ def _new_metrics_from_model_output(
return metrics


def _try_to_count_token(prompt: str, tokenizer) -> int:
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
"""Try to count token of prompt
Args:
prompt (str): prompt
tokenizer ([type]): tokenizer
model ([type]): model
Returns:
int: token count, if error return -1
TODO: More implementation
"""
try:
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

if isinstance(model, ProxyModel):
return model.count_token(prompt)
# Only support huggingface model now
return len(tokenizer(prompt).input_ids[0])
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/model/cluster/worker/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def add_worker(
return True
else:
# TODO Update worker
logger.warn(f"Instance {worker_key} exist")
logger.warning(f"Instance {worker_key} exist")
return False

def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
Expand Down Expand Up @@ -229,7 +229,7 @@ async def model_startup(self, startup_req: WorkerStartupRequest):
)
if not success:
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
logger.warn(f"{msg}, worker_params: {worker_params}")
logger.warning(f"{msg}, worker_params: {worker_params}")
self._remove_worker(worker_params)
raise Exception(msg)
supported_types = WorkerType.values()
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/model/proxy/llms/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):


def __convert_2_gpt_messages(messages: List[ModelMessage]):
chat_round = 0
gpt_messages = []
last_usr_message = ""
system_messages = []

# TODO: We can't change message order in low level
for message in messages:
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
last_usr_message = message.content
Expand Down
27 changes: 27 additions & 0 deletions dbgpt/model/proxy/llms/proxy_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING
import logging
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage

logger = logging.getLogger(__name__)


class ProxyModel:
def __init__(self, model_params: ProxyModelParameters) -> None:
self._model_params = model_params
self._tokenizer = ProxyTokenizerWrapper()

def get_params(self) -> ProxyModelParameters:
return self._model_params

def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[int] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[int], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
return self._tokenizer.count_token(messages, model_name)
9 changes: 6 additions & 3 deletions dbgpt/model/utils/chatgpt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt._private.pydantic import model_to_json
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

if TYPE_CHECKING:
import httpx
Expand Down Expand Up @@ -152,6 +153,7 @@ def __init__(
self._context_length = context_length
self._client = openai_client
self._openai_kwargs = openai_kwargs or {}
self._tokenizer = ProxyTokenizerWrapper()

@property
def client(self) -> ClientType:
Expand Down Expand Up @@ -238,10 +240,11 @@ async def get_context_length(self) -> int:
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
TODO: Get the real number of tokens from the openai api or tiktoken package
Args:
model (str): The model name.
prompt (str): The prompt.
"""

raise NotImplementedError()
return self._tokenizer.count_token(prompt, model)


class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
Expand Down
80 changes: 80 additions & 0 deletions dbgpt/model/utils/token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING
import logging

if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage

logger = logging.getLogger(__name__)


class ProxyTokenizerWrapper:
def __init__(self) -> None:
self._support_encoding = True
self._encoding_model = None

def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[str] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[str], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
if not self._support_encoding:
logger.warning(
"model does not support encoding model, can't count token, returning -1"
)
return -1
encoding = self._get_or_create_encoding_model(model_name)
cnt = 0
if isinstance(messages, str):
cnt = len(encoding.encode(messages, disallowed_special=()))
elif isinstance(messages, BaseMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, ModelMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, list):
for message in messages:
cnt += len(encoding.encode(message.content, disallowed_special=()))
else:
logger.warning(
"unsupported type of messages, can't count token, returning -1"
)
return -1
return cnt

def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
"""Get or create encoding model for given model name
More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
if self._encoding_model:
return self._encoding_model
try:
import tiktoken

logger.info(
"tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
"also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
)
except ImportError:
self._support_encoding = False
logger.warn("tiktoken not installed, cannot count tokens, returning -1")
return -1
try:
if not model_name:
model_name = "gpt-3.5-turbo"
self._encoding_model = tiktoken.model.encoding_for_model(model_name)
except KeyError:
logger.warning(
f"{model_name}'s tokenizer not found, using cl100k_base encoding."
)
self._encoding_model = tiktoken.get_encoding("cl100k_base")
return self._encoding_model
17 changes: 4 additions & 13 deletions dbgpt/serve/conversation/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import pytest

from dbgpt.storage.metadata import db
Expand Down Expand Up @@ -39,11 +37,9 @@ def test_table_exist():


def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
entity = ServeEntity(**default_entity_dict)
session.add(entity)


def test_entity_unique_key(default_entity_dict):
Expand All @@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):


def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case
pass


def test_entity_update(default_entity_dict):
Expand All @@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):

def test_entity_delete(default_entity_dict):
# TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
pass


def test_entity_all():
Expand Down
Loading

0 comments on commit 0cdc77a

Please sign in to comment.