feat: Add ability to configure OpenAI base URL in ChatGPTAgentConfig #577

Merged 8 commits on Jul 3, 2024
13 changes: 13 additions & 0 deletions tests/streaming/agent/test_base_agent.py
@@ -177,3 +177,16 @@ async def test_action_response_agent_input(mocker: MockerFixture):
     # TODO: assert that the canned response is optionally sent if the action is not quiet
     # and that it goes through the normal flow when the action is not quiet
     pass
+
+@pytest.fixture
+def agent_config():
+    return ChatGPTAgentConfig(
+        openai_api_key="test_key",
+        model_name="llama3-8b-8192",
+        base_url_override="https://api.groq.com/openai/v1",
+        prompt_preamble="Test prompt",
+    )
+
+def test_chat_gpt_agent_base_url(agent_config):
+    agent = ChatGPTAgent(agent_config)
+    assert agent.openai_client.base_url == "https://api.groq.com/openai/v1"
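A hypothetical companion test (not part of this diff) could pin down the fallback behavior when no override is set; `startswith` sidesteps the trailing slash the openai client may append to `base_url`:

```python
from vocode.streaming.agent.chat_gpt_agent import ChatGPTAgent
from vocode.streaming.models.agent import ChatGPTAgentConfig


def test_chat_gpt_agent_default_base_url():
    # Sketch, not in the diff: with no base_url_override the client
    # should fall back to the stock OpenAI endpoint.
    config = ChatGPTAgentConfig(
        openai_api_key="test_key",
        prompt_preamble="Test prompt",
    )
    agent = ChatGPTAgent(config)
    assert str(agent.openai_client.base_url).startswith("https://api.openai.com/v1")
```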
4 changes: 3 additions & 1 deletion vocode/streaming/agent/chat_gpt_agent.py
@@ -39,9 +39,11 @@ def instantiate_openai_client(agent_config: ChatGPTAgentConfig, model_fallback:
     else:
         if agent_config.openai_api_key is not None:
             logger.info("Using OpenAI API key override")
+        if agent_config.base_url_override is not None:
+            logger.info(f"Using OpenAI base URL override: {agent_config.base_url_override}")
         return AsyncOpenAI(
             api_key=agent_config.openai_api_key or os.environ["OPENAI_API_KEY"],
-            base_url="https://api.openai.com/v1",
+            base_url=agent_config.base_url_override or "https://api.openai.com/v1",
             max_retries=0 if model_fallback else OPENAI_DEFAULT_MAX_RETRIES,
         )
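The resolution itself is a plain or-fallback. A minimal standalone sketch of what the constructor call amounts to (the Groq URL comes from the test above; `GROQ_API_KEY` is an illustrative env var, not vocode API):

```python
import os

from openai import AsyncOpenAI

# The override wins when present; otherwise the stock OpenAI endpoint is used.
base_url_override = "https://api.groq.com/openai/v1"
client = AsyncOpenAI(
    api_key=os.environ.get("GROQ_API_KEY", "test_key"),
    base_url=base_url_override or "https://api.openai.com/v1",
)
```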

10 changes: 9 additions & 1 deletion vocode/streaming/agent/token_utils.py
@@ -90,7 +90,7 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]:
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
-        logger.warning("Warning: model not found. Using cl100k_base encoding.")
+        logger.warning(f"Warning: model not found. Using cl100k_base encoding for {model}.")
         encoding = tiktoken.get_encoding("cl100k_base")
     if model in {
         "gpt-3.5-turbo-0613",
@@ -117,6 +117,14 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]:
         )
         tokens_per_message = 3
         tokens_per_name = 1
+    elif "llama" in model.lower():
+        logger.warning(
+            f"Warning: you are using a llama model with an OpenAI-compatible endpoint. \
+            Llama models are not natively supported for token counting in tiktoken. \
+            Using cl100k_base encoding for {model} as an APPROXIMATION of token usage."
+        )
+        tokens_per_message = 3
+        tokens_per_name = 1
     else:
         return None
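For reference, a rough sketch of what this approximation yields, following the OpenAI cookbook counting scheme (the message content is illustrative):

```python
import tiktoken

# Approximate token usage for a llama model behind an OpenAI-compatible
# endpoint, with cl100k_base standing in for the real llama tokenizer.
encoding = tiktoken.get_encoding("cl100k_base")
messages = [{"role": "user", "content": "Hello from an OpenAI-compatible endpoint!"}]

num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
for message in messages:
    num_tokens += 3  # tokens_per_message
    for key, value in message.items():
        num_tokens += len(encoding.encode(value))
        if key == "name":
            num_tokens += 1  # tokens_per_name
print(num_tokens)  # an estimate, not the llama tokenizer's exact count
```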

1 change: 1 addition & 0 deletions vocode/streaming/models/agent.py
@@ -115,6 +115,7 @@ class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value):  # type: ignore
     openai_api_key: Optional[str] = None
     prompt_preamble: str
     model_name: str = CHAT_GPT_AGENT_DEFAULT_MODEL_NAME
+    base_url_override: Optional[str] = None
     temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
     max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
     azure_params: Optional[AzureOpenAIConfig] = None
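With the new field, pointing the agent at any OpenAI-compatible provider is a one-line config change; the values below echo the test fixture above:

```python
from vocode.streaming.models.agent import ChatGPTAgentConfig

config = ChatGPTAgentConfig(
    openai_api_key="test_key",  # the provider's OpenAI-compatible API key
    model_name="llama3-8b-8192",
    base_url_override="https://api.groq.com/openai/v1",
    prompt_preamble="Test prompt",
)
```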