Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ability to configure OpenAI base URL in ChatGPTAgentConfig #577

Merged
merged 8 commits into from
Jul 3, 2024
13 changes: 13 additions & 0 deletions tests/streaming/agent/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,16 @@ async def test_action_response_agent_input(mocker: MockerFixture):
# TODO: assert that the canned response is optionally sent if the action is not quiet
# and that it goes through the normal flow when the action is not quiet
pass

@pytest.fixture
def agent_config():
    """Build a ChatGPTAgentConfig that points at a non-default (Groq) OpenAI-compatible endpoint."""
    config = ChatGPTAgentConfig(
        prompt_preamble="Test prompt",
        openai_api_key="test_key",
        model_name="llama3-8b-8192",
        base_url="https://api.groq.com/openai/v1",
    )
    return config

def test_chat_gpt_agent_base_url(agent_config):
    """ChatGPTAgent should pass the configured base_url through to its OpenAI client.

    Asserts against the fixture's own ``base_url`` instead of a duplicated
    string literal, so the test cannot silently drift out of sync with the
    fixture's configured endpoint.
    """
    agent = ChatGPTAgent(agent_config)
    # openai's client exposes base_url as an httpx.URL, which supports
    # equality comparison with a plain string.
    assert agent.openai_client.base_url == agent_config.base_url
ajar98 marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 3 additions & 1 deletion vocode/streaming/agent/chat_gpt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def instantiate_openai_client(agent_config: ChatGPTAgentConfig, model_fallback:
else:
if agent_config.openai_api_key is not None:
logger.info("Using OpenAI API key override")
if agent_config.base_url is not None:
logger.info(f"Using OpenAI base URL override: {agent_config.base_url}")
ajar98 marked this conversation as resolved.
Show resolved Hide resolved
return AsyncOpenAI(
api_key=agent_config.openai_api_key or os.environ["OPENAI_API_KEY"],
base_url="https://api.openai.com/v1",
base_url=agent_config.base_url or "https://api.openai.com/v1",
max_retries=0 if model_fallback else OPENAI_DEFAULT_MAX_RETRIES,
)

Expand Down
11 changes: 10 additions & 1 deletion vocode/streaming/agent/token_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
logger.warning(f"Warning: model not found. Using cl100k_base encoding for {model}.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
Expand All @@ -117,6 +117,14 @@ def get_tokenizer_info(model: str) -> Optional[TokenizerInfo]:
)
tokens_per_message = 3
tokens_per_name = 1
elif "llama" in model.lower():
logger.warning(
f"Warning: you are using a llama model with an OpenAI compatible endpoint. \
Llama models are not supported natively support for token counting in tiktoken. \
Using cl100k_base encoding for {model} as an APPROXIMATION of token usage."
)
tokens_per_message = 3
tokens_per_name = 1
else:
return None

Expand Down Expand Up @@ -250,3 +258,4 @@ def format_default(schema):
result += "_: " + formatted
result += ") => any;\n\n"
return result

1 change: 1 addition & 0 deletions vocode/streaming/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value): # type: i
openai_api_key: Optional[str] = None
prompt_preamble: str
model_name: str = CHAT_GPT_AGENT_DEFAULT_MODEL_NAME
base_url: Optional[str] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's either name this base_url_override or default it to "https://api.openai.com/v1" — i'd prefer the former since it changes the code less

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ajar98 nice to meet you! Thanks for looking over this.

I just pushed some changes to go to the override option. Let me know if anything else needs to get changed!

temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
azure_params: Optional[AzureOpenAIConfig] = None
Expand Down