Skip to content

Commit

Permalink
Update chatgpt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghy-sketchzh authored Nov 16, 2023
1 parent e198cd3 commit 1b615ae
Showing 1 changed file with 37 additions and 30 deletions.
67 changes: 37 additions & 30 deletions pilot/model/proxy/llms/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@ def _initialize_openai(params: ProxyModelParameters):

return openai_params


def _initialize_openai_v1(params: ProxyModelParameters):
try:
from openai import OpenAI
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai"
)
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai"
)

api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

base_url = params.proxy_api_base or os.getenv(
Expand All @@ -80,7 +81,7 @@ def _initialize_openai_v1(params: ProxyModelParameters):
if not base_url and params.proxy_server_url:
# Adapt previous proxy_server_url configuration
base_url = params.proxy_server_url.split("/chat/completions")[0]

if params.http_proxy:
openai.proxies = params.http_proxy
openai_params = {
Expand All @@ -89,19 +90,14 @@ def _initialize_openai_v1(params: ProxyModelParameters):
"proxies": params.http_proxy,
}

return openai_params,api_type,api_version




return openai_params, api_type, api_version


def _build_request(model: ProxyModel, params):
history = []

model_params = model.get_params()
logger.info(f"Model: {model}, model_params: {model_params}")


messages: List[ModelMessage] = params["messages"]
# Add history conversation
Expand Down Expand Up @@ -134,10 +130,10 @@ def _build_request(model: ProxyModel, params):
proxyllm_backend = model_params.proxyllm_backend

if metadata.version("openai") >= "1.0.0":
openai_params,api_type,api_version = _initialize_openai_v1(model_params)
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
else:
else:
openai_params = _initialize_openai(model_params)
if openai_params["api_type"] == "azure":
# engine = "deployment_name".
Expand All @@ -158,20 +154,24 @@ def chatgpt_generate_stream(
):
if metadata.version("openai") >= "1.0.0":
model_params = model.get_params()
openai_params,api_type,api_version = _initialize_openai_v1(model_params)
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
history, payloads = _build_request(model, params)
if api_type == "azure":
from openai import AzureOpenAI

client = AzureOpenAI(
api_key = openai_params["api_key"],
api_version = api_version,
azure_endpoint = openai_params["base_url"] # Your Azure OpenAI resource's endpoint value.
)
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params[
"base_url"
], # Your Azure OpenAI resource's endpoint value.
)
else:
from openai import OpenAI

client = OpenAI(**openai_params)
print("openai_params",openai_params)
print("payloads",payloads)
print("openai_params", openai_params)
print("payloads", payloads)
res = client.chat.completions.create(messages=history, **payloads)
print(res)
text = ""
Expand All @@ -182,14 +182,14 @@ def chatgpt_generate_stream(
yield text

else:

import openai

history, payloads = _build_request(model, params)

res = openai.ChatCompletion.create(messages=history, **payloads)

text = ""
print("res",res)
print("res", res)
for r in res:
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
Expand All @@ -202,22 +202,29 @@ async def async_chatgpt_generate_stream(
):
if metadata.version("openai") >= "1.0.0":
model_params = model.get_params()
openai_params,api_type,api_version = _initialize_openai_v1(model_params)
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
history, payloads = _build_request(model, params)
if api_type == "azure":
from openai import AsyncAzureOpenAI

client = AsyncAzureOpenAI(
api_key = openai_params["api_key"],
end_point = openai_params["base_url"],
api_version = api_version,
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") # Your Azure OpenAI resource's endpoint value.
)
api_key=openai_params["api_key"],
end_point=openai_params["base_url"],
api_version=api_version,
azure_endpoint=os.getenv(
"AZURE_OPENAI_ENDPOINT"
), # Your Azure OpenAI resource's endpoint value.
)
else:
from openai import AsyncOpenAI

client = AsyncOpenAI(**openai_params)
res = await client.chat.completions.create(messages=history, **payloads).model_dump()
res = await client.chat.completions.create(
messages=history, **payloads
).model_dump()
else:
import openai

history, payloads = _build_request(model, params)

res = await openai.ChatCompletion.acreate(messages=history, **payloads)
Expand Down

0 comments on commit 1b615ae

Please sign in to comment.