Skip to content

Commit

Permalink
Added support for Azure openAI and added API_AZURE_DEPLOYMENT environ…
Browse files Browse the repository at this point in the history
…ment variables to adapt to Azure openAI
  • Loading branch information
likenamehaojie authored Jan 30, 2024
1 parent 208d91d commit c6868aa
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion dbgpt/model/utils/chatgpt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class OpenAIParameters:
api_base: Optional[str] = None
api_key: Optional[str] = None
api_version: Optional[str] = None
api_azure_deployment: Optional[str] = None
full_url: Optional[str] = None
proxies: Optional["ProxiesTypes"] = None

Expand Down Expand Up @@ -70,7 +71,7 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = api_version or os.getenv("OPENAI_API_VERSION")

api_azure_deployment = init_params.api_azure_deployment or os.getenv("API_AZURE_DEPLOYMENT")
if not base_url and full_url:
base_url = full_url.split("/chat/completions")[0]

Expand All @@ -84,6 +85,7 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
openai_params = {
"api_key": api_key,
"base_url": base_url,
"api_azure_deployment": api_azure_deployment
}
return openai_params, api_type, api_version

Expand All @@ -109,6 +111,8 @@ def _initialize_openai(params: OpenAIParameters):
)
api_version = params.api_version or os.getenv("OPENAI_API_VERSION")

api_azure_deployment = params.api_azure_deployment or os.getenv("API_AZURE_DEPLOYMENT")

if not api_base and params.full_url:
# Adapt previous proxy_server_url configuration
api_base = params.full_url.split("/chat/completions")[0]
Expand All @@ -122,6 +126,8 @@ def _initialize_openai(params: OpenAIParameters):
openai.api_version = api_version
if params.proxies:
openai.proxy = params.proxies
if params.api_azure_deployment:
openai.api_azure_deployment =api_azure_deployment


def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]:
Expand All @@ -134,6 +140,9 @@ def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType
return api_type, AsyncAzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
#azure_deployment="siasmodel",
azure_deployment=openai_params["api_azure_deployment"],
# model_name="gpt-35-turbo",
azure_endpoint=openai_params["base_url"],
http_client=httpx.AsyncClient(proxies=init_params.proxies),
)
Expand Down

0 comments on commit c6868aa

Please sign in to comment.