Add Multimodal Textbox (lm-sys#3297)
BabyChouSr authored May 10, 2024
1 parent be66155 commit e395370
Showing 15 changed files with 1,798 additions and 257 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -29,3 +29,8 @@ tests/state_of_the_union.txt

# Build
build

# Image data
serve_images
val2014
vqa_examples
13 changes: 13 additions & 0 deletions docs/arena.md
@@ -13,3 +13,16 @@ If you have a model hosted by a 3rd party API provider or yourself, please give
### Method 2: Hosted by LMSYS
1. Contribute the code to support this model in FastChat by submitting a pull request. See [instructions](model_support.md).
2. After the model is supported, we will try to schedule some compute resources to host the model in the arena. However, due to the limited resources we have, we may not be able to serve every model. We will select the models based on popularity, quality, diversity, and other factors.


## How to launch vision arena

1. Run `python3 -m fastchat.serve.controller` to start the controller and begin registering local model workers and API-provided workers.
2. Run `python3 -m fastchat.serve.sglang_worker --model-path <model-path> --tokenizer-path <tokenizer-path>` to run local vision-language models. Currently supported models include the LLaVA and Yi-VL series.
3. If you are using a model served by a 3rd-party API provider (e.g. GPT-4-V, Gemini 1.5), follow the instructions in [model_support.md](model_support.md) to add a JSON file `api_endpoints.json` (a minimal example entry is sketched below).
4. Run the Gradio server with the `--vision-arena` flag.

Example command:
```
python3 -m fastchat.serve.gradio_web_server_multi --share --register-api-endpoint-file api_endpoints.json --vision-arena
```
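
For step 3, a minimal `api_endpoints.json` entry might look like the sketch below. This is only an illustration: the outer key, API base, and API key are placeholders, and the full set of fields is described in [model_support.md](model_support.md).
```
{
    "gpt-4-vision-preview": {
        "api_type": "openai",
        "api_base": "https://api.openai.com/v1",
        "api_key": "sk-******",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": true
    }
}
```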
11 changes: 10 additions & 1 deletion docs/model_support.md
@@ -116,12 +116,21 @@ For custom protocols, implementation of a streaming generator in [fastchat/serve
"api_type": "openai",
"api_base": "https://api.openai.com/v1",
"api_key": "sk-******",
"anony_only": false
"anony_only": false,
"recommended_config": {
"temperature": 0.7,
"top_p": 1.0
},
"text-arena": true,
"vision-arena": false,
}
}
```
- "api_type" can be one of the following: openai, anthropic, gemini, mistral, yandexgpt or reka. For custom APIs, add a new type and implement it accordingly.
- "anony_only" indicates whether to display this model in anonymous mode only.
- "recommended_config" indicates the recommended generation parameters for temperature and top_p.
- "text-arena" indicates whether the model should be displayed in the Text Arena.
- "vision-arena" indicates whether the model should be displayed in the Vision Arena.

2. Launch the Gradio web server with the argument `--register api_endpoints.json`:
```
...
```
244 changes: 238 additions & 6 deletions fastchat/conversation.py
@@ -9,6 +9,7 @@
import dataclasses
from enum import auto, IntEnum
from io import BytesIO
import os
from typing import List, Any, Dict, Union, Tuple


@@ -311,6 +312,8 @@ def get_prompt(self) -> str:
ret = system_prompt + "\n"
for role, message in self.messages:
if message:
if type(message) is tuple:
message, images = message
ret += role + ": " + message + "\n"
else:
ret += role + ":"
@@ -391,14 +394,81 @@ def to_gradio_chatbot(self):
if type(msg) is tuple:
msg, image = msg
img_b64_str = image[0] # Only one image on gradio at one time
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
if img_b64_str.startswith("http://") or img_b64_str.startswith(
"https://"
):
img_str = f'<img src="{img_b64_str}" alt="user upload image" />'
else:
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>\n", "").strip()

ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret

def to_openai_image_format(self, image_urls):
import base64

openai_images = []
for image_url in image_urls:
if image_url.startswith("http://") or image_url.startswith(
"https://"
): # input is a url
openai_images.append(image_url)
elif image_url.lower().endswith(
("png", "jpg", "jpeg", "webp", "gif")
): # input is a local image
img_b64_str = self.convert_image_to_base64(image_url)
filetype = image_url.split(".")[-1].lower()
openai_images.append(f"data:image/{filetype};base64,{img_b64_str}")
else:
try:
assert (
base64.b64encode(base64.b64decode(image_url))
== image_url.encode()
), "The image data is not a valid base64 encoded string"
openai_images.append(f"data:image/jpeg;base64,{image_url}")
except Exception:
raise ValueError(
f"This file is not valid or not currently supported by the OpenAI API: {image_url}"
)
return openai_images

def to_openai_vision_api_messages(self):
"""Convert the conversation to OpenAI vision api completion format"""
ret = [
{
"role": "system",
"content": [{"type": "text", "text": self.system_message}],
}
]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
content_list = [{"type": "text", "text": msg[0]}]

image_urls = self.to_openai_image_format(msg[1])
for image_url in image_urls:
content_list.append(
{"type": "image_url", "image_url": {"url": image_url}}
)

ret.append({"role": "user", "content": content_list})
else:
ret.append(
{"role": "user", "content": [{"type": "text", "text": msg}]}
)
else:
if msg is not None:
ret.append(
{
"role": "assistant",
"content": [{"type": "text", "text": msg}],
}
)
return ret

def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
if self.system_message == "":
@@ -414,11 +484,163 @@ def to_openai_api_messages(self):
ret.append({"role": "assistant", "content": msg})
return ret

def extract_text_from_messages(self):
return [
(role, message[0]) if type(message) is tuple else (role, message)
for role, message in self.messages
def to_vertex_api_messages(self):
from vertexai.preview.generative_models import Image
import base64
import requests

if self.system_message == "":
ret = []
else:
ret = [self.system_message]

for role, msg in self.messages[self.offset :]:
if msg is not None:
if type(msg) is tuple:
text, images = msg[0], msg[1]
for image in images:
if image.startswith("http://") or image.startswith("https://"):
response = requests.get(image)
image = response.content
else: # base64
image = base64.b64decode(image)
ret.append(Image.from_bytes(image))
ret.append(text)
else:
ret.append(msg)

return ret

def to_anthropic_vision_api_messages(self):
"""Convert the conversation to Claude-3 Messages Vision API format"""
ret = [
{
"role": "system",
"content": [{"type": "text", "text": self.system_message}],
}
]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
content_list = [{"type": "text", "text": msg[0]}]

for image_url in msg[1]:
# Claude only supports base64
if image_url.startswith("http://") or image_url.startswith(
"https://"
):
image_url = self.convert_image_to_base64(image_url)

content_list.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": image_url,
},
}
)

ret.append({"role": "user", "content": content_list})
else:
ret.append(
{"role": "user", "content": [{"type": "text", "text": msg}]}
)
else:
if msg is not None:
ret.append(
{
"role": "assistant",
"content": [{"type": "text", "text": msg}],
}
)
return ret

def to_reka_api_messages(self):
ret = []
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) == tuple:
text, images = msg
for image in images:
if image.startswith("https://") or image.startswith("http://"):
ret.append(
{"type": "human", "text": text, "media_url": image}
)
else:
ret.append(
{
"type": "human",
"text": text,
"media_url": f"data:image/jpeg;base64,{image}",
}
)
else:
ret.append({"type": "human", "text": msg})
else:
if msg is not None:
ret.append({"type": "model", "text": msg})

return ret

def save_new_images(self, use_remote_storage=False):
import hashlib
from fastchat.constants import LOGDIR
from fastchat.utils import load_image, upload_image_file_to_gcs

_, last_user_message = self.messages[-2]

if type(last_user_message) == tuple:
text, images = last_user_message[0], last_user_message[1]
loaded_images = [load_image(image) for image in images]
image_hashes = [
hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images
]

image_filenames = []
for i, (loaded_image, hash_str) in enumerate(
zip(loaded_images, image_hashes)
):
filename = os.path.join(
"serve_images",
f"{hash_str}.jpg",
)

if use_remote_storage:
image_url = upload_image_file_to_gcs(loaded_image, filename)
# NOTE(chris): If the URL were public, we would set it here so future model calls use the link directly
# images[i] = image_url
else:
filename = os.path.join(LOGDIR, filename)
if not os.path.isfile(filename):
os.makedirs(os.path.dirname(filename), exist_ok=True)
loaded_image.save(filename)

def extract_text_and_image_hashes_from_messages(self):
import hashlib
from fastchat.utils import load_image

messages = []

for role, message in self.messages:
if type(message) is tuple:
text, images = message[0], message[1]

image_hashes = []
for image in images:
if image.startswith("http://") or image.startswith("https://"):
image_hashes.append(image)
else:
image = load_image(image)
image_hash = hashlib.md5(image.tobytes()).hexdigest()
image_hashes.append(image_hash)

messages.append((role, (text, image_hashes)))
else:
messages.append((role, message))

return messages

def copy(self):
return Conversation(
@@ -440,7 +662,7 @@ def dict(self):
"template_name": self.name,
"system_message": self.system_message,
"roles": self.roles,
"messages": self.extract_text_from_messages(),
"messages": self.extract_text_and_image_hashes_from_messages(),
"offset": self.offset,
}

@@ -1802,6 +2024,16 @@ def get_conv_template(name: str) -> Conversation:
)
)

register_conv_template(
Conversation(
name="reka",
system_message="",
roles=("user", "assistant"),
sep_style=SeparatorStyle.DEFAULT,
sep=None,
)
)


if __name__ == "__main__":
from fastchat.conversation import get_conv_template
16 changes: 14 additions & 2 deletions fastchat/model/model_registry.py
@@ -50,7 +50,7 @@ def get_model_info(name: str) -> ModelInfo:
"claude-1",
],
"Claude",
"https://www.anthropic.com/index/claude-2",
"https://www.anthropic.com/news/claude-3-family",
"Claude by Anthropic",
)

@@ -151,7 +151,12 @@ def get_model_info(name: str) -> ModelInfo:
)

register_model_info(
["gemini-pro", "gemini-pro-dev-api"],
[
"gemini-pro",
"gemini-pro-dev-api",
"gemini-1.0-pro-vision",
"gemini-1.5-pro-preview-0409",
],
"Gemini",
"https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
"Gemini by Google",
@@ -750,3 +755,10 @@ def get_model_info(name: str) -> ModelInfo:
"https://huggingface.co/cllm",
"consistency-llm is a new generation of parallel decoder LLMs with fast generation speed.",
)

register_model_info(
["reka-flash", "reka-flash-20240226"],
"Reka Flash",
"https://reka.ai/reka-flash",
"Multimodal model by Reka",
)