
[WIP] Llava #2639

Closed
wants to merge 41 commits
Changes from 10 commits
Commits
41 commits
e2d9db7
init commit
BabyChouSr Nov 4, 2023
ebe4013
loading model
BabyChouSr Nov 4, 2023
5b40cc2
text model test works
BabyChouSr Nov 4, 2023
ef81e0d
add ability to have images
BabyChouSr Nov 5, 2023
ce54a29
reset convo for new image
BabyChouSr Nov 5, 2023
cc75581
simply warn if file doesn't exist
BabyChouSr Nov 6, 2023
b3da05a
fix model worker endpoint
BabyChouSr Nov 6, 2023
a443383
init gradio single vision language model interface
BabyChouSr Nov 6, 2023
5f4e439
create new file for vision server
BabyChouSr Nov 7, 2023
9388048
format
BabyChouSr Nov 7, 2023
547ca26
format issues, refactor vision server, rename constants
BabyChouSr Nov 10, 2023
e5388df
remove load image from multimodal worker
BabyChouSr Nov 10, 2023
915b4cd
remove some unused imports
BabyChouSr Nov 10, 2023
f1602bf
add model support
BabyChouSr Nov 10, 2023
8173ee3
fix some imports
BabyChouSr Nov 11, 2023
86039cf
remove unused args and create utility function
BabyChouSr Nov 11, 2023
21ee79e
add conversion to openai format
BabyChouSr Nov 12, 2023
f9d8962
rename flag
BabyChouSr Nov 12, 2023
fd32e20
Merge branch 'main' of https://github.com/BabyChouSr/FastChat into pr…
BabyChouSr Nov 12, 2023
502cf23
clean and test
BabyChouSr Nov 12, 2023
ec21cb6
fix conversation description
BabyChouSr Nov 12, 2023
78181b1
fix and format
BabyChouSr Nov 16, 2023
b48faf4
fix get_images function for multiple images
BabyChouSr Nov 16, 2023
921fa3d
move imports into the function
BabyChouSr Nov 22, 2023
d942f3e
Merge remote-tracking branch 'origin' into pr-multimodal
BabyChouSr Nov 22, 2023
60960a0
format
BabyChouSr Nov 22, 2023
90af475
consolidate multimodal model worker and model worker
BabyChouSr Nov 23, 2023
919462a
change gradio server to just a tab in multi-view
BabyChouSr Nov 23, 2023
0ff62c0
significantly reduce code in vision file
BabyChouSr Nov 23, 2023
7faf3dc
separate get_model_list for text vs. vision language models
BabyChouSr Nov 24, 2023
d9bc4ac
Merge branch 'main' of https://github.com/BabyChouSr/FastChat into pr…
BabyChouSr Nov 25, 2023
825c0e2
add vision optional dependency
BabyChouSr Nov 25, 2023
bd9a5e2
Merge branch 'main' of https://github.com/BabyChouSr/FastChat into pr…
BabyChouSr Dec 1, 2023
57e692f
Update server for open_ai compatibility
BabyChouSr Dec 3, 2023
857132e
format
BabyChouSr Dec 3, 2023
96e00c5
Change burden for image token to the Conversation class
BabyChouSr Dec 4, 2023
1501a7c
Merge branch 'main' of https://github.com/BabyChouSr/FastChat into pr…
BabyChouSr Dec 17, 2023
dc22df2
Format
BabyChouSr Dec 17, 2023
0cc1bd8
Use huggingface llava
BabyChouSr Dec 18, 2023
a3f2011
Fix conversation and remove old LLaVA files
BabyChouSr Dec 19, 2023
487e987
Remove LLaVA specific stream_func
BabyChouSr Dec 22, 2023
3 changes: 3 additions & 0 deletions fastchat/constants.py
@@ -16,6 +16,9 @@
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
# Maximum input length
INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 3072))
INPUT_CHAR_LEN_LIMIT_WITH_IMAGE = int(
os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT_WITH_IMAGE", 2816)
)
# Maximum conversation turns
CONVERSATION_TURN_LIMIT = 50
# Session expiration time
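The lower limit leaves headroom in the request payload for the base64-encoded image that travels alongside the text. A minimal sketch of how a serving endpoint might pick between the two limits; the helper name truncate_user_input and its call site are illustrative and not part of this diff:

from fastchat.constants import INPUT_CHAR_LEN_LIMIT, INPUT_CHAR_LEN_LIMIT_WITH_IMAGE


def truncate_user_input(text: str, has_image: bool) -> str:
    # Hypothetical helper: apply the tighter limit when an image is attached.
    limit = INPUT_CHAR_LEN_LIMIT_WITH_IMAGE if has_image else INPUT_CHAR_LEN_LIMIT
    return text[:limit]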
112 changes: 111 additions & 1 deletion fastchat/conversation.py
@@ -130,6 +130,8 @@ def get_prompt(self) -> str:
            for i, (role, message) in enumerate(self.messages):
                tag = self.roles[i % 2]
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        ret += message + " "
                    else:
@@ -216,6 +218,65 @@ def get_prompt(self) -> str:
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image

                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(
                                    pil_img.mode, (width, width), background_color
                                )
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(
                                    pil_img.mode, (height, height), background_color
                                )
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(
                            f"Invalid image_process_mode: {image_process_mode}"
                        )
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message
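A sketch of how get_images() could be exercised once a user message carries a (text, image, image_process_mode) tuple; the "llava" template is registered further down in this diff, and the image path is a placeholder:

from PIL import Image

from fastchat.conversation import get_conv_template

conv = get_conv_template("llava")
img = Image.open("example.jpg")  # placeholder path
conv.append_message(conv.roles[0], ("Describe this image. <image>", img, "Default"))
conv.append_message(conv.roles[1], None)

pil_images = conv.get_images(return_pil=True)   # resized PIL images
b64_images = conv.get_images(return_pil=False)  # base64-encoded PNG strings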
@@ -237,7 +298,30 @@ def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace("<image>", "").strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret
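Continuing the same sketch, to_gradio_chatbot() renders the tuple message as an inline <img> tag followed by the user text with the "<image>" placeholder stripped:

history = conv.to_gradio_chatbot()
# history[0][0] starts with '<img src="data:image/jpeg;base64,' and ends with
# "Describe this image."; history[0][1] stays None until the assistant replies.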
@@ -270,6 +354,16 @@ def copy(self):
        )

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "template_name": self.name,
                "system_message": self.system_message,
                "roles": self.roles,
                "messages": [
                    [x, y[0] if type(y) is tuple else y] for x, y in self.messages
                ],
                "offset": self.offset,
            }
        return {
            "template_name": self.name,
            "system_message": self.system_message,
@@ -1153,6 +1247,22 @@ def get_conv_template(name: str) -> Conversation:
    )
)

# Llava template
# reference: conv_llava_llama_2 from https://github.com/haotian-liu/LLaVA/blob/main/llava/conversation.py
register_conv_template(
    Conversation(
        name="llava",
        system_template="You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language.",
        roles=("USER", "ASSISTANT"),
        sep_style=SeparatorStyle.LLAMA2,
        sep="<s>",
        sep2="</s>",
        stop_str="</s>",
    )
)


if __name__ == "__main__":
    from fastchat.conversation import get_conv_template
8 changes: 8 additions & 0 deletions fastchat/model/llava/constants.py
@@ -0,0 +1,8 @@
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
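These mirror the constants in the upstream LLaVA repository: prompts carry the literal DEFAULT_IMAGE_TOKEN, and each occurrence is mapped to the out-of-vocabulary sentinel IMAGE_TOKEN_INDEX so the model knows where to splice in image embeddings. A simplified, illustrative stand-in for LLaVA's tokenizer_image_token helper (that helper is not part of this diff):

from fastchat.model.llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX


def insert_image_tokens(prompt, tokenizer):
    # Tokenize the text around each "<image>" occurrence separately.
    chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
    input_ids = chunks[0]
    for chunk in chunks[1:]:
        # Mark the image slot, dropping the BOS token the tokenizer prepends to each
        # chunk (assumes a Llama-style tokenizer that adds BOS).
        input_ids = input_ids + [IMAGE_TOKEN_INDEX] + chunk[1:]
    return input_ids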
169 changes: 169 additions & 0 deletions fastchat/model/llava/language_model/llava_llama.py
@@ -0,0 +1,169 @@
# Adapted from haotian-liu/LLaVA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import (
AutoConfig,
AutoModelForCausalLM,
LlamaConfig,
LlamaModel,
LlamaForCausalLM,
)

from transformers.modeling_outputs import CausalLMOutputWithPast

from fastchat.model.llava.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
    model_type = "llava"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        (
            input_ids,
            attention_mask,
            past_key_values,
            inputs_embeds_prepared,
            labels,
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images
        )
        if inputs_embeds is None:
            inputs_embeds = inputs_embeds_prepared

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs


AutoConfig.register("llava", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
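Because the config and model classes are registered with the Auto classes at import time, a checkpoint whose config declares model_type "llava" can then be loaded through the usual transformers entry points. A sketch, assuming the module above has been imported so the registration has run and using a placeholder checkpoint id:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import fastchat.model.llava.language_model.llava_llama  # noqa: F401  (registers "llava")

model = AutoModelForCausalLM.from_pretrained(
    "liuhaotian/llava-v1.5-7b",  # placeholder checkpoint id
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.5-7b")
# model.generate() can then be given an extra `images` tensor, which
# prepare_inputs_for_generation forwards to forward() above.
# output_ids = model.generate(input_ids, images=image_tensor, max_new_tokens=128)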