Skip to content

Commit

Permalink
remove unused args and create utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyChouSr committed Nov 11, 2023
1 parent 8173ee3 commit 86039cf
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 49 deletions.
70 changes: 29 additions & 41 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,35 +219,39 @@ def get_prompt(self) -> str:
else:
raise ValueError(f"Invalid style: {self.sep_style}")

def extract_base64encoded_image_from_message(self, message):
    """Return the base64-encoded PNG for *message*, a (text, PIL.Image) tuple.

    Before encoding, the image is downscaled (aspect ratio preserved) so its
    longest edge is at most 800px and its shortest edge at most 400px; images
    already within bounds are encoded unchanged.
    """
    import base64
    from io import BytesIO

    # BUG FIX: the original did `msg, image = msg`, unpacking the undefined
    # local `msg` instead of the `message` parameter — an UnboundLocalError
    # on every call. The text part of the tuple is unused here.
    _, image = message

    max_hw, min_hw = max(image.size), min(image.size)
    aspect_ratio = max_hw / min_hw
    max_len, min_len = 800, 400
    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
    longest_edge = int(shortest_edge * aspect_ratio)
    W, H = image.size
    if longest_edge != max(image.size):
        # Resize only when the image exceeds the bounds; keep the original
        # orientation (portrait stays portrait, landscape stays landscape).
        if H > W:
            H, W = longest_edge, shortest_edge
        else:
            H, W = shortest_edge, longest_edge
        image = image.resize((W, H))

    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_b64_str = base64.b64encode(buffered.getvalue()).decode()

    return img_b64_str

def get_images(self):
    """Return base64-encoded PNG strings for all images in this conversation.

    Scans messages starting at ``self.offset``; even-indexed turns (the
    user's) may carry an image as a ``(text, PIL.Image)`` tuple, which is
    converted via ``extract_base64encoded_image_from_message``.
    """
    # DEAD CODE FIX: the span previously contained the pre-refactor inline
    # resize/encode logic interleaved with this implementation (leftover
    # diff residue); only the delegating version below is kept.
    images = []
    for i, (role, msg) in enumerate(self.messages[self.offset :]):
        if i % 2 == 0:
            if type(msg) is tuple:
                images.append(self.extract_base64encoded_image_from_message(msg))
    return images

def set_system_message(self, system_message: str):
Expand All @@ -272,24 +276,8 @@ def to_gradio_chatbot(self):
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO

msg, image = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_b64_str = self.extract_base64encoded_image_from_message(msg)
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>", "").strip()
ret.append([msg, None])
Expand Down
2 changes: 1 addition & 1 deletion fastchat/serve/gradio_web_server_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def add_text(state, model_selector, text, image, request: gr.Request):
no_change_btn,
) * 5

if image is not None and len(state.conv.get_images(return_pil=True)) > 0:
if image is not None and len(state.conv.get_images()) > 0:
# reset convo with new image
state.conv = get_conversation_template(state.model_name)

Expand Down
7 changes: 1 addition & 6 deletions fastchat/serve/multimodal_model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(
device: str,
num_gpus: int,
max_gpu_memory: str,
multimodal: bool,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
load_4bit: bool = False,
Expand All @@ -75,7 +74,7 @@ def __init__(
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
self.multimodal = multimodal
self.multimodal = True

logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")

Expand Down Expand Up @@ -157,14 +156,11 @@ def create_multimodal_model_worker():
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
# FOR PEFT (not supported yet): parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--embed-in-truncate", action="store_true")
parser.add_argument(
"--model-names",
type=lambda s: s.split(","),
help="Optional display comma separated names",
)
parser.add_argument("--multimodal", action="store_true", default=True)
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
Expand Down Expand Up @@ -196,7 +192,6 @@ def create_multimodal_model_worker():
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
multimodal=args.multimodal,
dtype=str_to_torch_dtype(args.dtype),
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
Expand Down
3 changes: 2 additions & 1 deletion fastchat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import warnings

import requests
from PIL import Image

from fastchat.constants import LOGDIR

Expand Down Expand Up @@ -336,6 +335,8 @@ def str_to_torch_dtype(dtype: str):


def load_image(image_file):
from PIL import Image

if image_file.startswith("http://") or image_file.startswith("https://"):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert("RGB")
Expand Down

0 comments on commit 86039cf

Please sign in to comment.