Skip to content

Commit

Permalink
fix: compile OpenAI 1.x.x version (#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt authored Nov 17, 2023
2 parents 1ad09c8 + ddabd01 commit 50b3f35
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions pilot/model/proxy/llms/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
import httpx

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,15 +83,12 @@ def _initialize_openai_v1(params: ProxyModelParameters):
# Adapt previous proxy_server_url configuration
base_url = params.proxy_server_url.split("/chat/completions")[0]

if params.http_proxy:
openai.proxies = params.http_proxy
proxies = params.http_proxy
openai_params = {
"api_key": api_key,
"base_url": base_url,
"proxies": params.http_proxy,
}

return openai_params, api_type, api_version
return openai_params, api_type, api_version, proxies


def _build_request(model: ProxyModel, params):
Expand Down Expand Up @@ -130,7 +128,9 @@ def _build_request(model: ProxyModel, params):
proxyllm_backend = model_params.proxyllm_backend

if metadata.version("openai") >= "1.0.0":
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
else:
Expand All @@ -154,22 +154,23 @@ def chatgpt_generate_stream(
):
if metadata.version("openai") >= "1.0.0":
model_params = model.get_params()
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
history, payloads = _build_request(model, params)
if api_type == "azure":
from openai import AzureOpenAI

client = AzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params[
"base_url"
], # Your Azure OpenAI resource's endpoint value.
azure_endpoint=openai_params["base_url"],
http_client=httpx.Client(proxies=proxies),
)
else:
from openai import OpenAI

client = OpenAI(**openai_params)
client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
res = client.chat.completions.create(messages=history, **payloads)
text = ""
for r in res:
Expand All @@ -186,7 +187,6 @@ def chatgpt_generate_stream(
res = openai.ChatCompletion.create(messages=history, **payloads)

text = ""
print("res", res)
for r in res:
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
Expand All @@ -199,22 +199,25 @@ async def async_chatgpt_generate_stream(
):
if metadata.version("openai") >= "1.0.0":
model_params = model.get_params()
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
history, payloads = _build_request(model, params)
if api_type == "azure":
from openai import AsyncAzureOpenAI

client = AsyncAzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params[
"base_url"
], # Your Azure OpenAI resource's endpoint value.
azure_endpoint=openai_params["base_url"],
http_client=httpx.AsyncClient(proxies=proxies),
)
else:
from openai import AsyncOpenAI

client = AsyncOpenAI(**openai_params)
client = AsyncOpenAI(
**openai_params, http_client=httpx.AsyncClient(proxies=proxies)
)

res = await client.chat.completions.create(messages=history, **payloads)
text = ""
Expand Down

0 comments on commit 50b3f35

Please sign in to comment.