Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to image handling, added arguments and eval_model function. #1205

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions llava/eval/run_llava.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import torch
from io import BytesIO

from llava.constants import (
IMAGE_TOKEN_INDEX,
Expand All @@ -26,16 +27,23 @@


def image_parser(args):
    """Split the image-file argument into a list of image sources.

    Args:
        args: Namespace with ``image_file`` (a separator-joined string of
            paths/URLs, or a non-string object such as raw image bytes)
            and ``sep`` (the separator used when ``image_file`` is a str).

    Returns:
        list: Individual image sources. A string is split on ``args.sep``;
        any non-string value is wrapped in a single-element list as-is.
    """
    # isinstance (not `type(...) is`) is the idiomatic type check.
    if isinstance(args.image_file, str):
        return args.image_file.split(args.sep)
    return [args.image_file]


def load_image(image_file):
    """Load a single image as an RGB ``Image``.

    Args:
        image_file: One of
            - a str beginning with "http" (treated as a URL and fetched
              with ``requests``),
            - a str local file path,
            - raw image bytes (decoded in-memory via ``BytesIO``).

    Returns:
        The opened image converted to RGB.
    """
    if isinstance(image_file, str):
        # NOTE: the original also tested startswith("https"), which is
        # redundant — any string starting with "https" starts with "http".
        if image_file.startswith("http"):
            response = requests.get(image_file)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            image = Image.open(image_file).convert("RGB")
    else:
        # Non-string input is assumed to be raw image bytes — TODO confirm
        # callers only pass bytes-like objects here.
        image = Image.open(BytesIO(image_file)).convert("RGB")
    return image


Expand All @@ -51,12 +59,19 @@ def eval_model(args):
# Model
disable_torch_init()

model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
args.model_path, args.model_base, model_name
)

tokenizer = args.tokenizer
model = args.model
image_processor = args.image_processor
model_path = args.model_path
qs = args.query

model_name = get_model_name_from_path(model_path)

if model is None or tokenizer is None or image_processor is None:
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path, args.model_base, args.model_name
)

image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
if IMAGE_PLACEHOLDER in qs:
if model.config.mm_use_im_start_end:
Expand Down Expand Up @@ -99,11 +114,9 @@ def eval_model(args):
image_files = image_parser(args)
images = load_images(image_files)
image_sizes = [x.size for x in images]
images_tensor = process_images(
images,
image_processor,
model.config
).to(model.device, dtype=torch.float16)
images_tensor = process_images(images, image_processor, model.config).to(
model.device, dtype=torch.float16
)

input_ids = (
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
Expand All @@ -126,13 +139,17 @@ def eval_model(args):

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(outputs)
return outputs


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--model", default=None)
parser.add_argument("--tokenizer", default=None)
parser.add_argument("--image_processor", default=None)
parser.add_argument("--image-file", required=True)
parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--sep", type=str, default=",")
Expand Down