diff --git a/tests/streaming/agent/test_base_agent.py b/tests/streaming/agent/test_base_agent.py index c6a0adb95..972040579 100644 --- a/tests/streaming/agent/test_base_agent.py +++ b/tests/streaming/agent/test_base_agent.py @@ -177,3 +177,18 @@ async def test_action_response_agent_input(mocker: MockerFixture): # TODO: assert that the canned response is optionally sent if the action is not quiet # and that it goes through the normal flow when the action is not quiet pass + + +@pytest.fixture +def agent_config(): + return ChatGPTAgentConfig( + openai_api_key="test_key", + model_name="llama3-8b-8192", + base_url_override="https://api.groq.com/openai/v1/", + prompt_preamble="Test prompt", + ) + + +def test_chat_gpt_agent_base_url(agent_config): + agent = ChatGPTAgent(agent_config) + assert str(agent.openai_client.base_url) == "https://api.groq.com/openai/v1/" diff --git a/vocode/streaming/agent/chat_gpt_agent.py b/vocode/streaming/agent/chat_gpt_agent.py index 047aa6c90..11e7c6586 100644 --- a/vocode/streaming/agent/chat_gpt_agent.py +++ b/vocode/streaming/agent/chat_gpt_agent.py @@ -39,9 +39,11 @@ def instantiate_openai_client(agent_config: ChatGPTAgentConfig, model_fallback: else: if agent_config.openai_api_key is not None: logger.info("Using OpenAI API key override") + if agent_config.base_url_override is not None: + logger.info(f"Using OpenAI base URL override: {agent_config.base_url_override}") return AsyncOpenAI( api_key=agent_config.openai_api_key or os.environ["OPENAI_API_KEY"], - base_url="https://api.openai.com/v1", + base_url=agent_config.base_url_override or "https://api.openai.com/v1", max_retries=0 if model_fallback else OPENAI_DEFAULT_MAX_RETRIES, ) diff --git a/vocode/streaming/agent/token_utils.py b/vocode/streaming/agent/token_utils.py index 7cd80f8cc..bf014b2b7 100644 --- a/vocode/streaming/agent/token_utils.py +++ b/vocode/streaming/agent/token_utils.py @@ -90,7 +90,7 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]: try: encoding = tiktoken.encoding_for_model(model) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning(f"Warning: model not found. Using cl100k_base encoding for {model}.") encoding = tiktoken.get_encoding("cl100k_base") if model in { "gpt-3.5-turbo-0613", @@ -117,6 +117,14 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]: ) tokens_per_message = 3 tokens_per_name = 1 + elif "llama" in model.lower(): + logger.warning( + f"Warning: you are using a llama model with an OpenAI compatible endpoint. \ + Llama models are not supported natively support for token counting in tiktoken. \ + Using cl100k_base encoding for {model} as an APPROXIMATION of token usage." + ) + tokens_per_message = 3 + tokens_per_name = 1 else: return None diff --git a/vocode/streaming/models/agent.py b/vocode/streaming/models/agent.py index b023e1e23..aa9b63c8f 100644 --- a/vocode/streaming/models/agent.py +++ b/vocode/streaming/models/agent.py @@ -122,6 +122,7 @@ class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value): # type: i openai_api_key: Optional[str] = None prompt_preamble: str model_name: str = CHAT_GPT_AGENT_DEFAULT_MODEL_NAME + base_url_override: Optional[str] = None temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS azure_params: Optional[AzureOpenAIConfig] = None