Skip to content

Commit

Permalink
feat: Add ability to configure OpenAI base URL in ChatGPTAgentConfig (v…
Browse files Browse the repository at this point in the history
…ocodedev#577)

* feat: Add ability to configure OpenAI base URL in ChatGPTAgentConfig

- Added `base_url` parameter to `ChatGPTAgentConfig` to allow customization of the OpenAI API base URL.
- Updated `instantiate_openai_client` function to use the `base_url` parameter from the configuration.
- Modified `ChatGPTAgent` to utilize the updated `instantiate_openai_client` function.
- Added tests to verify the new `base_url` functionality in `tests/streaming/agent/test_base_agent.py`.

This enhancement allows users to specify a custom OpenAI API base URL, providing greater flexibility in agent configuration.

* Add the capability to use an OpenAI-compatible endpoint, with approximate token estimation for Llama models

* lint fix

* Change the OpenAI `base_url` parameter so that fewer overall code changes are needed

* missed logging update

* Update vocode/streaming/agent/chat_gpt_agent.py

* Update tests/streaming/agent/test_base_agent.py

* fix test

---------

Co-authored-by: Ajay Raj <[email protected]>
  • Loading branch information
celmore25 and ajar98 authored Jul 3, 2024
1 parent af2849c commit 918412c
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 2 deletions.
15 changes: 15 additions & 0 deletions tests/streaming/agent/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,18 @@ async def test_action_response_agent_input(mocker: MockerFixture):
# TODO: assert that the canned response is optionally sent if the action is not quiet
# and that it goes through the normal flow when the action is not quiet
pass


@pytest.fixture
def agent_config():
    """Provide a ChatGPTAgentConfig that points at a non-default base URL.

    The config sets ``base_url_override`` to an OpenAI-compatible endpoint
    hosted at api.groq.com and names a llama model, alongside a dummy API
    key and a minimal prompt preamble.
    """
    config_kwargs = {
        "openai_api_key": "test_key",
        "model_name": "llama3-8b-8192",
        "base_url_override": "https://api.groq.com/openai/v1/",
        "prompt_preamble": "Test prompt",
    }
    return ChatGPTAgentConfig(**config_kwargs)


def test_chat_gpt_agent_base_url(agent_config):
    """Check that the agent's OpenAI client uses the configured base URL override."""
    expected_base_url = "https://api.groq.com/openai/v1/"
    client = ChatGPTAgent(agent_config).openai_client
    # Compare via str() because base_url may not be a plain string object.
    assert str(client.base_url) == expected_base_url
4 changes: 3 additions & 1 deletion vocode/streaming/agent/chat_gpt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def instantiate_openai_client(agent_config: ChatGPTAgentConfig, model_fallback:
else:
if agent_config.openai_api_key is not None:
logger.info("Using OpenAI API key override")
if agent_config.base_url_override is not None:
logger.info(f"Using OpenAI base URL override: {agent_config.base_url_override}")
return AsyncOpenAI(
api_key=agent_config.openai_api_key or os.environ["OPENAI_API_KEY"],
base_url="https://api.openai.com/v1",
base_url=agent_config.base_url_override or "https://api.openai.com/v1",
max_retries=0 if model_fallback else OPENAI_DEFAULT_MAX_RETRIES,
)

Expand Down
10 changes: 9 additions & 1 deletion vocode/streaming/agent/token_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
logger.warning(f"Warning: model not found. Using cl100k_base encoding for {model}.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
Expand All @@ -117,6 +117,14 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]:
)
tokens_per_message = 3
tokens_per_name = 1
elif "llama" in model.lower():
logger.warning(
f"Warning: you are using a llama model with an OpenAI compatible endpoint. \
Llama models are not supported natively support for token counting in tiktoken. \
Using cl100k_base encoding for {model} as an APPROXIMATION of token usage."
)
tokens_per_message = 3
tokens_per_name = 1
else:
return None

Expand Down
1 change: 1 addition & 0 deletions vocode/streaming/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value): # type: i
openai_api_key: Optional[str] = None
prompt_preamble: str
model_name: str = CHAT_GPT_AGENT_DEFAULT_MODEL_NAME
base_url_override: Optional[str] = None
temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
azure_params: Optional[AzureOpenAIConfig] = None
Expand Down

0 comments on commit 918412c

Please sign in to comment.