fix(model): Fix openai adapt previous proxy_server_url configuration and support azure openai model (#668)
csunny authored Oct 12, 2023
2 parents 83cd90e + 7ffa45c commit a9241e1
Showing 2 changed files with 101 additions and 15 deletions.
21 changes: 21 additions & 0 deletions pilot/model/parameter.py
@@ -282,9 +282,30 @@ class ProxyModelParameters(BaseModelParameters):
"help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
},
)

proxy_api_key: str = field(
metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
)

proxy_api_base: str = field(
default=None,
metadata={
"help": "The base api address, such as: https://api.openai.com/v1. If None, we will use proxy_api_base first"
},
)

proxy_api_type: Optional[str] = field(
default=None,
metadata={
"help": "The api type of current proxy the current proxy model, if you use Azure, it can be: azure"
},
)

proxy_api_version: Optional[str] = field(
default=None,
metadata={"help": "The api version of current proxy the current model"},
)

http_proxy: Optional[str] = field(
default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
metadata={"help": "The http or https proxy to use openai"},
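To make the new fields concrete, here is a minimal, hypothetical sketch of the values an Azure OpenAI deployment would use (the endpoint, key, and API version below are illustrative placeholders, not project defaults; proxyllm_backend is an existing parameter that, for Azure, names the deployment):

# Hypothetical Azure OpenAI settings for the new ProxyModelParameters fields;
# every value below is a placeholder, not a project default.
azure_proxy_config = {
    "proxy_api_type": "azure",  # selects the Azure branch in chatgpt.py
    "proxy_api_base": "https://my-resource.openai.azure.com",
    "proxy_api_key": "<your-azure-openai-key>",
    "proxy_api_version": "2023-05-15",
    "proxyllm_backend": "gpt-35-turbo",  # Azure deployment name, not a model name
}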
95 changes: 80 additions & 15 deletions pilot/model/proxy/llms/chatgpt.py
@@ -1,31 +1,63 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import os
from typing import List
import logging

import openai

from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

logger = logging.getLogger(__name__)


def _initialize_openai(params: ProxyModelParameters):
    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

    api_base = params.proxy_api_base or os.getenv(
        "OPENAI_API_BASE",
        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
    )
    api_key = params.proxy_api_key or os.getenv(
        "OPENAI_API_KEY",
        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
    )
    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")

    if not api_base and params.proxy_server_url:
        # Adapt previous proxy_server_url configuration
        api_base = params.proxy_server_url.split("/chat/completions")[0]
    if api_type:
        openai.api_type = api_type
    if api_base:
        openai.api_base = api_base
    if api_key:
        openai.api_key = api_key
    if api_version:
        openai.api_version = api_version
    if params.http_proxy:
        openai.proxy = params.http_proxy

    openai_params = {
        "api_type": api_type,
        "api_base": api_base,
        "api_version": api_version,
        "proxy": params.http_proxy,
    }

    return openai_params
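To make the fallback above concrete, a small illustrative sketch (legacy_url is a hypothetical value): explicit parameters win over environment variables, and a legacy proxy_server_url is trimmed back to an API base only as a last resort.

# Resolution order implemented above (illustrative):
#   1. params.proxy_api_base           (explicit configuration)
#   2. OPENAI_API_BASE env var         (standard OpenAI variable)
#   3. AZURE_OPENAI_ENDPOINT env var   (only when api_type == "azure")
#   4. proxy_server_url, with the "/chat/completions" suffix stripped:
legacy_url = "https://api.openai.com/v1/chat/completions"  # hypothetical value
api_base = legacy_url.split("/chat/completions")[0]
assert api_base == "https://api.openai.com/v1"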


def _build_request(model: ProxyModel, params):
    history = []

    model_params = model.get_params()
    logger.info(f"Model: {model}, model_params: {model_params}")

    openai_params = _initialize_openai(model_params)

    messages: List[ModelMessage] = params["messages"]
    # Add history conversation
@@ -51,18 +83,51 @@ def chatgpt_generate_stream(
    history.append(last_user_input)

    payloads = {
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
        "stream": True,
    }
    proxyllm_backend = model_params.proxyllm_backend

    if openai_params["api_type"] == "azure":
        # engine = "deployment_name".
        proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
        payloads["engine"] = proxyllm_backend
    else:
        proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
        payloads["model"] = proxyllm_backend

    logger.info(
        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
    )
    return history, payloads
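For reference, a sketch of the two payload shapes _build_request produces (temperature and token values are hypothetical): the Azure API addresses a deployment via engine, while the standard OpenAI API uses model.

# Illustrative payloads (values hypothetical) for each api_type:
azure_payload = {
    "temperature": 0.7,
    "max_tokens": 1024,
    "stream": True,
    "engine": "gpt-35-turbo",  # Azure deployment name
}
openai_payload = {
    "temperature": 0.7,
    "max_tokens": 1024,
    "stream": True,
    "model": "gpt-3.5-turbo",  # OpenAI model name
}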


def chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    history, payloads = _build_request(model, params)

    res = openai.ChatCompletion.create(messages=history, **payloads)

    text = ""
    for r in res:
        if r["choices"][0]["delta"].get("content") is not None:
            content = r["choices"][0]["delta"]["content"]
            text += content
            yield text
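Below, a minimal hypothetical usage sketch of the synchronous generator; each yielded value is the full accumulated response text so far.

# Hypothetical usage; `model` is assumed to be an initialized ProxyModel whose
# params carry the proxy_* settings shown in parameter.py above, and
# ModelMessage(role=..., content=...) is assumed to match the message type.
def demo_sync_stream(model: ProxyModel) -> None:
    params = {
        "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        "temperature": 0.7,
        "max_new_tokens": 256,
    }
    # tokenizer and device are unused by the proxy path, so None is fine here.
    for partial_text in chatgpt_generate_stream(model, None, params, None):
        print(partial_text)  # accumulated text so far, grows per streamed chunk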


async def async_chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    history, payloads = _build_request(model, params)

    res = await openai.ChatCompletion.acreate(messages=history, **payloads)

    text = ""
    async for r in res:
        if r["choices"][0]["delta"].get("content") is not None:
            content = r["choices"][0]["delta"]["content"]
            text += content
            yield text
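An equivalent hypothetical sketch for the async variant:

import asyncio

# Hypothetical async counterpart of the sync demo above.
async def demo_async_stream(model: ProxyModel) -> None:
    params = {
        "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        "temperature": 0.7,
        "max_new_tokens": 256,
    }
    async for partial_text in async_chatgpt_generate_stream(model, None, params, None):
        print(partial_text)

# To drive it: asyncio.run(demo_async_stream(model))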
