diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py
index b6f0e1643..427d672e4 100644
--- a/dbgpt/model/utils/chatgpt_utils.py
+++ b/dbgpt/model/utils/chatgpt_utils.py
@@ -37,6 +37,7 @@ class OpenAIParameters:
     api_base: Optional[str] = None
     api_key: Optional[str] = None
     api_version: Optional[str] = None
+    api_azure_deployment: Optional[str] = None
     full_url: Optional[str] = None
     proxies: Optional["ProxiesTypes"] = None
 
@@ -70,7 +71,7 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
         os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
     )
     api_version = api_version or os.getenv("OPENAI_API_VERSION")
-
+    api_azure_deployment = init_params.api_azure_deployment or os.getenv("API_AZURE_DEPLOYMENT")
     if not base_url and full_url:
         base_url = full_url.split("/chat/completions")[0]
 
@@ -84,6 +85,7 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
     openai_params = {
         "api_key": api_key,
         "base_url": base_url,
+        "api_azure_deployment": api_azure_deployment,
     }
     return openai_params, api_type, api_version
 
@@ -109,6 +111,8 @@ def _initialize_openai(params: OpenAIParameters):
     )
     api_version = params.api_version or os.getenv("OPENAI_API_VERSION")
 
+    api_azure_deployment = params.api_azure_deployment or os.getenv("API_AZURE_DEPLOYMENT")
+
     if not api_base and params.full_url:
         # Adapt previous proxy_server_url configuration
         api_base = params.full_url.split("/chat/completions")[0]
@@ -122,6 +126,8 @@ def _initialize_openai(params: OpenAIParameters):
         openai.api_version = api_version
     if params.proxies:
         openai.proxy = params.proxies
+    if api_azure_deployment:
+        openai.api_azure_deployment = api_azure_deployment
 
 
 def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]:
@@ -134,6 +140,7 @@ def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType
         return api_type, AsyncAzureOpenAI(
             api_key=openai_params["api_key"],
             api_version=api_version,
+            azure_deployment=openai_params["api_azure_deployment"],
             azure_endpoint=openai_params["base_url"],
             http_client=httpx.AsyncClient(proxies=init_params.proxies),
         )