From 6f4258a18d4579d3fced158b040549155c0b7c2e Mon Sep 17 00:00:00 2001
From: eFrick <77757559+efrick2002@users.noreply.github.com>
Date: Sat, 11 Jan 2025 13:46:34 -0700
Subject: [PATCH] p2l stuff (#3660)

---
 fastchat/model/model_adapter.py     |  2 +-
 fastchat/serve/api_provider.py      | 76 +++++++++++++++++++++++++++++
 fastchat/serve/gradio_web_server.py | 38 ++++++++++++++-
 pyproject.toml                      |  4 +-
 4 files changed, 115 insertions(+), 5 deletions(-)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 9625df6db..16cf5d2b6 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -2489,7 +2489,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 
 class NoSystemAdapter(BaseModelAdapter):
     def match(self, model_path: str):
-        keyword_list = ["athene-70b"]
+        keyword_list = ["athene-70b", "p2l"]
         for keyword in keyword_list:
             if keyword == model_path.lower():
                 return True
diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py
index e00326b30..f830d397a 100644
--- a/fastchat/serve/api_provider.py
+++ b/fastchat/serve/api_provider.py
@@ -246,6 +246,17 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             conversation_id=state.conv_id,
         )
+    elif model_api_dict["api_type"] == "p2l":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = p2l_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_base=model_api_dict["api_base"],
+            api_key=model_api_dict["api_key"],
+        )
     else:
         raise NotImplementedError()
 
@@ -412,6 +423,71 @@ def column_api_stream_iter(
         }
 
 
+def p2l_api_stream_iter(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_base=None,
+    api_key=None,
+):
+    import openai
+
+    client = openai.OpenAI(
+        base_url=api_base,
+        api_key=api_key or "-",
+        timeout=180,
+    )
+
+    # Make requests for logging
+    text_messages = []
+    for message in messages:
+        if type(message["content"]) == str:  # text-only model
+            text_messages.append(message)
+        else:  # vision model
+            filtered_content_list = [
+                content for content in message["content"] if content["type"] == "text"
+            ]
+            text_messages.append(
+                {"role": message["role"], "content": filtered_content_list}
+            )
+
+    gen_params = {
+        "model": model_name,
+        "prompt": text_messages,
+        "temperature": None,
+        "top_p": None,
+        "max_new_tokens": max_new_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    res = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=max_new_tokens,
+        stream=True,
+    )
+    text = ""
+    for chunk_idx, chunk in enumerate(res):
+        if len(chunk.choices) > 0:
+            text += chunk.choices[0].delta.content or ""
+
+        data = {
+            "text": text,
+            "error_code": 0,
+        }
+
+        if chunk_idx == 0:
+            if hasattr(chunk.choices[0].delta, "model"):
+                data["ans_model"] = chunk.choices[0].delta.model
+
+            if hasattr(chunk, "router_outputs"):
+                data["router_outputs"] = chunk.router_outputs
+
+        yield data
+
+
 def upload_openai_file_to_gcs(file_id):
     import openai
     from google.cloud import storage
diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py
index 4f0521da0..381e62a20 100644
--- a/fastchat/serve/gradio_web_server.py
+++ b/fastchat/serve/gradio_web_server.py
@@ -11,7 +11,7 @@
 import random
 import time
 import uuid
-from typing import List
+from typing import List, Dict
 
 import gradio as gr
 import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
         self.model_name = model_name
         self.oai_thread_id = None
         self.is_vision = is_vision
+        self.ans_models = []
+        self.router_outputs = []
 
         # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
         self.has_csam_image = False
@@ -128,6 +130,12 @@ def __init__(self, model_name, is_vision=False):
             self.regen_support = False
         self.init_system_prompt(self.conv, is_vision)
 
+    def update_ans_models(self, ans: str) -> None:
+        self.ans_models.append(ans)
+
+    def update_router_outputs(self, outputs: Dict[str, float]) -> None:
+        self.router_outputs.append(outputs)
+
     def init_system_prompt(self, conv, is_vision):
         system_prompt = conv.get_system_message(is_vision)
         if len(system_prompt) == 0:
@@ -154,6 +162,20 @@ def dict(self):
             }
         )
 
+        if self.ans_models:
+            base.update(
+                {
+                    "ans_models": self.ans_models,
+                }
+            )
+
+        if self.router_outputs:
+            base.update(
+                {
+                    "router_outputs": self.router_outputs,
+                }
+            )
+
         if self.is_vision:
             base.update({"has_csam_image": self.has_csam_image})
         return base
@@ -420,7 +442,7 @@ def is_limit_reached(model_name, ip):
 
 
 def bot_response(
-    state,
+    state: State,
     temperature,
     top_p,
     max_new_tokens,
@@ -532,6 +554,18 @@ def bot_response(
     try:
         data = {"text": ""}
         for i, data in enumerate(stream_iter):
+            # Change for P2L:
+            if i == 0:
+                if "ans_model" in data:
+                    ans_model = data.get("ans_model")
+
+                    state.update_ans_models(ans_model)
+
+                if "router_outputs" in data:
+                    router_outputs = data.get("router_outputs")
+
+                    state.update_router_outputs(router_outputs)
+
             if data["error_code"] == 0:
                 output = data["text"].strip()
                 conv.update_last_message(output + "▌")
diff --git a/pyproject.toml b/pyproject.toml
index fedd9e2dc..916aaeae0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,8 +19,8 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
-webui = ["gradio>=4.10"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"]
+webui = ["gradio>=4.10", "plotly", "scipy"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
 dev = ["black==23.3.0", "pylint==2.8.2"]
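
Usage sketch (illustrative; not part of the patch above). The new "p2l" api_type
is selected through the API-endpoint registration dict that
get_api_provider_stream_iter reads, and p2l_api_stream_iter yields the usual
{"text", "error_code"} payload plus, on the first chunk only, optional
"ans_model" and "router_outputs" fields, which State.update_ans_models and
State.update_router_outputs record. The endpoint name, URL, and prompt below
are made-up placeholders, and the snippet assumes a P2L server is reachable at
that URL.

# Hypothetical example only; endpoint name, URL, and prompt are placeholders.
from fastchat.serve.api_provider import p2l_api_stream_iter

# Shape of an API-endpoint entry with the keys the new "p2l" branch of
# get_api_provider_stream_iter reads:
P2L_ENDPOINT = {
    "p2l-router-example": {
        "model_name": "p2l-router-example",
        "api_type": "p2l",
        "api_base": "http://localhost:8080/v1",
        "api_key": "-",
    }
}

cfg = P2L_ENDPOINT["p2l-router-example"]
messages = [{"role": "user", "content": "Write a haiku about model routing."}]

for data in p2l_api_stream_iter(
    cfg["model_name"],
    messages,
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=256,
    api_base=cfg["api_base"],
    api_key=cfg["api_key"],
):
    # Each chunk repeats the accumulated text; only the first chunk may also
    # carry the routed model name and the router's outputs.
    if "ans_model" in data:
        print("answered by:", data["ans_model"])
    if "router_outputs" in data:
        print("router outputs:", data["router_outputs"])

print(data["text"])

When such a stream is consumed inside bot_response, the first chunk's
ans_model/router_outputs are appended to State.ans_models/State.router_outputs,
and State.dict() then writes both lists into the conversation log alongside the
usual fields.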