You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi,
Big fan of your work!
I am currently working on a project that needs to extract two kinds of latent representations of the input from the model.
The first one is the latent representation of the pure image input, without any text.
The second one is the latent representation of the pure text input, without any image.
I have already modified run_llava.py so that the model can generate an answer from a pure-image or a pure-text input. However, I need help extracting these representations of the input.
Below is my code:
import argparse
import torch
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
process_images,
tokenizer_image_token,
get_model_name_from_path,
)
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
import re
def image_parser(args):
    """Split ``args.image_file`` on ``args.sep`` and return the list of parts.

    Each part is one image path or URL, in the order given on the command line.
    """
    raw_value = args.image_file
    return raw_value.split(args.sep)
def load_image(image_file, timeout=30):
    """Open an image from an http(s) URL or a local path and return it as RGB.

    Args:
        image_file: ``http://``/``https://`` URL or a filesystem path.
        timeout: seconds to wait for the HTTP request (URL inputs only).

    Returns:
        A fully loaded ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        requests.HTTPError: if the URL responds with an error status.
    """
    # Match the full scheme prefix so a local path such as "httpdocs/a.jpg"
    # is not mistaken for a URL ("https" already starts with "http").
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # error page body.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return img.convert("RGB")
    # convert() returns a new, loaded image, so the source can be closed here.
    with Image.open(image_file) as img:
        return img.convert("RGB")
def load_images(image_files):
    """Load each path/URL in *image_files* with ``load_image``.

    Returns the resulting images in a list, preserving input order.
    """
    return [load_image(source) for source in image_files]
def _insert_image_token(query, mm_use_im_start_end):
    """Return *query* with the image token inserted (at the placeholder, else prepended)."""
    if mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in query:
        # Literal text replacement; re.sub would treat both strings as regex.
        return query.replace(IMAGE_PLACEHOLDER, image_token)
    return image_token + "\n" + query


def _infer_conv_mode(model_name):
    """Pick a conversation template name from substrings of *model_name* (first match wins)."""
    name = model_name.lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "mistral" in name:
        return "mistral_instruct"
    if "v1.6-34b" in name:
        return "chatml_direct"
    if "v1" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    return "llava_v0"


def eval_model(args):
    """Run one generation for ``args.query`` and print and return the decoded answer.

    The query gets the image token only when ``args.image_file`` is non-empty.
    ``None`` or ``""`` produce a text-only prompt with no image tensor.

    Args:
        args: namespace with ``model_path``, ``model_base``, ``query``,
            ``image_file``, ``sep``, ``conv_mode``, ``temperature``, ``top_p``,
            ``num_beams`` and ``max_new_tokens``. ``args.conv_mode`` is
            overwritten with the inferred template unless the caller set a
            different one (in which case a warning is printed and theirs wins).

    Returns:
        The decoded, stripped answer of the first sequence.

    Note:
        NOTE(review): to extract input representations instead of generating,
        a forward pass on the same ``input_ids``/``images_tensor``/``image_sizes``
        with ``output_hidden_states=True`` looks like the natural route. Confirm
        that LLaVA's ``forward`` accepts ``images``/``image_sizes`` here.
    """
    # Model
    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name
    )

    # Treat an empty string like None so text-only runs are possible from the CLI.
    has_image = bool(args.image_file)

    qs = args.query
    if has_image:
        qs = _insert_image_token(qs, model.config.mm_use_im_start_end)

    conv_mode = _infer_conv_mode(model_name)
    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print(
            "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                conv_mode, args.conv_mode, args.conv_mode
            )
        )
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    if has_image:
        images = load_images(image_parser(args))
        image_sizes = [img.size for img in images]
        images_tensor = process_images(images, image_processor, model.config).to(
            model.device, dtype=torch.float16
        )
    else:
        images_tensor = None
        image_sizes = None

    # Place the prompt on the same device as the model/images rather than
    # hard-coding the current CUDA device.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(outputs)
    return outputs
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="liuhaotian/llava-v1.5-7b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument(
        "--image-file",
        type=str,
        default="../../val2017/000000000139.jpg",
        help="Image path(s)/URL(s) joined by --sep; pass an empty string for a text-only prompt.",
    )
    parser.add_argument("--query", type=str, default="Describe this image")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--sep", type=str, default=",")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    args = parser.parse_args()
    # argparse cannot yield None for --image-file (it has a non-None default),
    # so normalize an empty value to None to select eval_model's text-only path.
    if not args.image_file:
        args.image_file = None
    eval_model(args)
Thank you so much for your help!
The text was updated successfully, but these errors were encountered:
Question
Hi,
Big fan of your work!
I am currently working on a project that needs to extract two kinds of latent representations of the input from the model.
I have already modified run_llava.py so that the model can generate an answer from a pure-image or a pure-text input. However, I need help extracting these representations of the input. Below is my code:
Thank you so much for your help!
The text was updated successfully, but these errors were encountered: