
p2l stuff (#3660)
efrick2002 authored Jan 11, 2025
1 parent 8664268 commit 6f4258a
Showing 4 changed files with 115 additions and 5 deletions.
2 changes: 1 addition & 1 deletion fastchat/model/model_adapter.py
@@ -2489,7 +2489,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:

 class NoSystemAdapter(BaseModelAdapter):
     def match(self, model_path: str):
-        keyword_list = ["athene-70b"]
+        keyword_list = ["athene-70b", "p2l"]
 
         for keyword in keyword_list:
             if keyword == model_path.lower():
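
Since the adapter compares each keyword against the lower-cased model path with equality rather than substring matching, only an endpoint whose model path is exactly "p2l" picks up the no-system-prompt template. A minimal sketch of that behavior, assuming match returns True on an exact keyword hit as the other adapters do:

    adapter = NoSystemAdapter()
    adapter.match("p2l")      # True: the lower-cased path equals a keyword
    adapter.match("P2L")      # True: comparison is case-insensitive via .lower()
    adapter.match("p2l-7b")   # False: equality check, so substrings do not match
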
76 changes: 76 additions & 0 deletions fastchat/serve/api_provider.py
@@ -246,6 +246,17 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             conversation_id=state.conv_id,
         )
+    elif model_api_dict["api_type"] == "p2l":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = p2l_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_base=model_api_dict["api_base"],
+            api_key=model_api_dict["api_key"],
+        )
     else:
         raise NotImplementedError()
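
The new branch is selected through the endpoint registration that populates model_api_dict. A hypothetical entry, sketched as the Python dict read by the dispatch above (values are assumptions, not part of this commit):

    model_api_dict = {
        "model_name": "p2l-router",               # assumed name forwarded to the P2L backend
        "api_type": "p2l",                        # routes the request to p2l_api_stream_iter
        "api_base": "http://localhost:10250/v1",  # assumed OpenAI-compatible P2L endpoint
        "api_key": "-",                           # placeholder; the backend may not check it
    }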

@@ -412,6 +423,71 @@ def column_api_stream_iter(
         }
 
 
+def p2l_api_stream_iter(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_base=None,
+    api_key=None,
+):
+    import openai
+
+    client = openai.OpenAI(
+        base_url=api_base,
+        api_key=api_key or "-",
+        timeout=180,
+    )
+
+    # Make requests for logging
+    text_messages = []
+    for message in messages:
+        if type(message["content"]) == str:  # text-only model
+            text_messages.append(message)
+        else:  # vision model
+            filtered_content_list = [
+                content for content in message["content"] if content["type"] == "text"
+            ]
+            text_messages.append(
+                {"role": message["role"], "content": filtered_content_list}
+            )
+
+    gen_params = {
+        "model": model_name,
+        "prompt": text_messages,
+        "temperature": None,
+        "top_p": None,
+        "max_new_tokens": max_new_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    res = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=max_new_tokens,
+        stream=True,
+    )
+    text = ""
+    for chunk_idx, chunk in enumerate(res):
+        if len(chunk.choices) > 0:
+            text += chunk.choices[0].delta.content or ""
+
+            data = {
+                "text": text,
+                "error_code": 0,
+            }
+
+            if chunk_idx == 0:
+                if hasattr(chunk.choices[0].delta, "model"):
+                    data["ans_model"] = chunk.choices[0].delta.model
+
+                if hasattr(chunk, "router_outputs"):
+                    data["router_outputs"] = chunk.router_outputs
+
+            yield data
+
+
 def upload_openai_file_to_gcs(file_id):
     import openai
     from google.cloud import storage
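
A minimal sketch of consuming the generator above, run in a context where the module's logger is available (endpoint values and the model name are assumptions):

    messages = [{"role": "user", "content": "Hello"}]
    last = {}
    for data in p2l_api_stream_iter(
        "p2l-router",                 # assumed model name
        messages,
        temperature=0.7,              # accepted but not forwarded to the request above
        top_p=1.0,                    # accepted but not forwarded to the request above
        max_new_tokens=256,
        api_base="http://localhost:10250/v1",  # assumed P2L endpoint
        api_key="-",
    ):
        last = data                   # each chunk carries the accumulated text so far
    # The first yielded dict may additionally carry "ans_model" and "router_outputs"
    # when the backend includes them in its first streamed chunk.
    print(last.get("text", ""))
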
38 changes: 36 additions & 2 deletions fastchat/serve/gradio_web_server.py
@@ -11,7 +11,7 @@
 import random
 import time
 import uuid
-from typing import List
+from typing import List, Dict
 
 import gradio as gr
 import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
         self.model_name = model_name
         self.oai_thread_id = None
         self.is_vision = is_vision
+        self.ans_models = []
+        self.router_outputs = []
 
         # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
         self.has_csam_image = False
@@ -128,6 +130,12 @@ def __init__(self, model_name, is_vision=False):
         self.regen_support = False
         self.init_system_prompt(self.conv, is_vision)
 
+    def update_ans_models(self, ans: str) -> None:
+        self.ans_models.append(ans)
+
+    def update_router_outputs(self, outputs: Dict[str, float]) -> None:
+        self.router_outputs.append(outputs)
+
     def init_system_prompt(self, conv, is_vision):
         system_prompt = conv.get_system_message(is_vision)
         if len(system_prompt) == 0:
@@ -154,6 +162,20 @@ def dict(self):
             }
         )
 
+        if self.ans_models:
+            base.update(
+                {
+                    "ans_models": self.ans_models,
+                }
+            )
+
+        if self.router_outputs:
+            base.update(
+                {
+                    "router_outputs": self.router_outputs,
+                }
+            )
+
         if self.is_vision:
             base.update({"has_csam_image": self.has_csam_image})
         return base
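
For a conversation served through the P2L route, the per-turn answer models and router scores now ride along in the logged record. A sketch of the extra fields State.dict() might emit (model names and scores are illustrative only):

    example_record = {
        "conv_id": "...",                        # existing fields from the base dict
        "model_name": "p2l-router",              # assumed display name
        "ans_models": ["model-a"],               # one entry per assistant turn
        "router_outputs": [{"model-a": 0.7, "model-b": 0.3}],  # assumed per-turn score dicts
    }
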
@@ -420,7 +442,7 @@ def is_limit_reached(model_name, ip):


 def bot_response(
-    state,
+    state: State,
     temperature,
     top_p,
     max_new_tokens,
@@ -532,6 +554,18 @@ def bot_response(
     try:
         data = {"text": ""}
         for i, data in enumerate(stream_iter):
+            # Change for P2L:
+            if i == 0:
+                if "ans_model" in data:
+                    ans_model = data.get("ans_model")
+
+                    state.update_ans_models(ans_model)
+
+                if "router_outputs" in data:
+                    router_outputs = data.get("router_outputs")
+
+                    state.update_router_outputs(router_outputs)
+
             if data["error_code"] == 0:
                 output = data["text"].strip()
                 conv.update_last_message(output + "▌")
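
Only the first chunk (i == 0) is inspected for routing metadata, matching what p2l_api_stream_iter attaches to its first yield, and the values are recorded on the State object so they end up in the logged conversation. A small sketch of reading that history back after a streamed turn finishes (attribute names as defined above):

    if state.ans_models:
        print("answered by:", state.ans_models[-1])
    if state.router_outputs:
        print("router scores:", state.router_outputs[-1])
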
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -19,8 +19,8 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
-webui = ["gradio>=4.10"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"]
+webui = ["gradio>=4.10", "plotly", "scipy"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
 dev = ["black==23.3.0", "pylint==2.8.2"]
