diff --git a/agentverse/llms/utils/token_counter.py b/agentverse/llms/utils/token_counter.py index b594011b7..94125ad73 100644 --- a/agentverse/llms/utils/token_counter.py +++ b/agentverse/llms/utils/token_counter.py @@ -8,7 +8,12 @@ def count_string_tokens(prompt: str = "", model: str = "gpt-3.5-turbo") -> int: - return len(tiktoken.encoding_for_model(model).encode(prompt)) + if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): + return len(tiktoken.encoding_for_model(model).encode(prompt)) + elif model.lower() in LOCAL_LLMS or model in LOCAL_LLMS: + from transformers import AutoTokenizer + encoding = AutoTokenizer.from_pretrained(LOCAL_LLMS_MAPPING[model.lower()]) + return len(encoding.encode(prompt)) def count_message_tokens(