
[CVPR 2024 Oral] InternVL Family: A Pioneering Open-Source Alternative to GPT-4V. A commercially usable open-source model with performance closest to GPT-4V.


moewhale/InternVL

 
 


InternVL Family: Closing the Gap to Commercial Multimodal Models with Open-Source Suites (A Pioneering Open-Source Alternative to GPT-4V)

[Update Blog] [Paper] [InternVL 1.5 Technical Report] [Chat Demo] [Hugging Face Demo] [Quick Start] [Chinese Interpretation]

News🚀🚀🚀

  • 2024/04/28: We released the INT8 version of InternVL-Chat-V1-5; see here.
  • 2024/04/28: We achieved SOTA performance (75.74) on the InfographicVQA benchmark; see here.
  • 2024/04/18: InternVL-Chat-V1.5 has been released on Hugging Face (HF link). It approaches the performance of GPT-4V and Gemini Pro on benchmarks such as MMMU, DocVQA, ChartQA, and MathVista.
  • 2024/02/27: InternVL has been accepted by CVPR 2024! 🎉
  • 2024/02/24: InternVL-Chat models have been included in VLMEvalKit.
  • 2024/02/21: InternVL-Chat-V1.2-Plus achieves SOTA performance on MathVista (59.9), MMBench (83.8), and MMVP (58.7). See our blog for more details.
  • 2024/02/12: InternVL-Chat-V1.2 has been released. It achieves 51.6 on MMMU val and 82.3 on MMBench test. For more details, please refer to our blog and SFT data, or try our demo. The model is now available on Hugging Face, and both the training/evaluation data and scripts are open-sourced.
  • 2024/02/04: InternVL-Chat-V1.1 achieves 44.67% on MMVP, higher than GPT-4V!
  • 2024/01/27: We released a 448-resolution model, achieving 76.6 on MMBench dev; see here.
  • 2024/01/24: InternVL-Chat-V1.1 has been released. It supports Chinese and has stronger OCR capability; see here or try our demo.
  • 2024/01/16: We released our customized mmcv/mmsegmentation/mmdetection code, integrated with DeepSpeed, which can be used to train large-scale object detection and semantic segmentation models.

Documents

  • How to Evaluate InternVL-Chat-V1-5? [link]
  • How to Evaluate InternVL-Chat-V1-5 using VLMEvalKit? (Recommended) [link]

Compared with SOTA VLLMs


What is InternVL?

InternVL scales up the vision transformer (ViT) to 6B parameters and aligns it with a large language model (LLM).
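
As a rough sanity check on the "6B" figure, the standard transformer estimate of about 12·d² weights per block (4·d² for the Q/K/V and output projections, plus 8·d² for an MLP with 4x expansion) can be applied to InternViT-6B. The depth of 48 blocks and width of 3200 used below are taken from the InternVL paper and should be treated as assumptions here; biases, LayerNorms, and patch/position embeddings are ignored, so this is only an order-of-magnitude sketch.

```python
def estimate_vit_params(depth: int, width: int, mlp_ratio: int = 4) -> int:
    """Rough weight count of a plain ViT encoder (ignores biases, norms, embeddings)."""
    attention = 4 * width * width          # Q, K, V and output projections
    mlp = 2 * mlp_ratio * width * width    # up- and down-projection of the MLP
    return depth * (attention + mlp)


# Assumed InternViT-6B configuration: depth 48, width 3200, MLP width 12800 (= 4 x 3200).
params = estimate_vit_params(depth=48, width=3200)
print(f"{params / 1e9:.2f}B")  # 5.90B
```

The result is consistent with the 5.9B listed for InternViT-6B in the tables below.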

Model Zoo

Vision Large Language Model

Model Date Download Note
InternVL-Chat-V1.5-Int8 2024.04.28 🤗 HF link INT8 version of InternVL-Chat-V1-5
InternVL-Chat-V1.5 2024.04.18 🤗 HF link supports 4K images; very strong OCR; approaches the performance of GPT-4V and Gemini Pro on benchmarks such as MMMU, DocVQA, ChartQA, and MathVista (🔥new)
InternVL-Chat-V1.2-Plus 2024.02.21 🤗 HF link more SFT data; stronger performance
InternVL-Chat-V1.2 2024.02.11 🤗 HF link LLM scaled up to 34B
InternVL-Chat-V1.1 2024.01.24 🤗 HF link supports Chinese; stronger OCR
InternVL-Chat-19B-448px 2024.02.03 🤗 HF link 448 resolution
InternVL-Chat-19B 2023.12.25 🤗 HF link English multimodal dialogue
InternVL-Chat-13B 2023.12.25 🤗 HF link English multimodal dialogue

Vision-Language Foundation Model

Model Date Download Note
InternViT-6B-448px-V1.5 2024.04.20 🤗 HF link supports dynamic resolution; very strong OCR (🔥new)
InternViT-6B-448px-V1.2 2024.02.11 🤗 HF link 448 resolution
InternViT-6B-448px-V1.0 2024.01.30 🤗 HF link 448 resolution
InternViT-6B-224px 2023.12.22 🤗 HF link vision foundation model
InternVL-14B-224px 2023.12.22 🤗 HF link vision-language foundation model

What can InternVL do?

Visual Perception
  • Linear-Probe Image Classification

    ViT-22B uses the private JFT-3B dataset.

    method #param IN-1K IN-ReaL IN-V2 IN-A IN-R IN-Sketch
    OpenCLIP-G 1.8B 86.2 89.4 77.2 63.8 87.8 66.4
    DINOv2-g 1.1B 86.5 89.6 78.4 75.9 78.8 62.5
    EVA-01-CLIP-g 1.1B 86.5 89.3 77.4 70.5 87.7 63.1
    MAWS-ViT-6.5B 6.5B 87.8 - - - - -
    ViT-22B* 21.7B 89.5 90.9 83.2 83.8 87.4
    InternViT-6B (ours) 5.9B 88.2 90.4 79.9 77.5 89.8 69.1
  • Semantic Segmentation

    method decoder #param (train/total) crop size mIoU
    OpenCLIP-G (frozen) Linear 0.3M / 1.8B 512 39.3
    ViT-22B (frozen) Linear 0.9M / 21.7B 504 34.6
    InternViT-6B (frozen) Linear 0.5M / 5.9B 504 47.2 (+12.6)
    ViT-22B (frozen) UperNet 0.8B / 22.5B 504 52.7
    InternViT-6B (frozen) UperNet 0.4B / 6.3B 504 54.9 (+2.2)
    ViT-22B UperNet 22.5B / 22.5B 504 55.3
    InternViT-6B UperNet 6.3B / 6.3B 504 58.9 (+3.6)
  • Zero-Shot Image Classification

    method IN-1K IN-A IN-R IN-V2 IN-Sketch ObjectNet
    OpenCLIP-G 80.1 69.3 92.1 73.6 68.9 73.0
    EVA-02-CLIP-E+ 82.0 82.1 94.5 75.7 71.6 79.6
    ViT-22B* 85.9 90.1 96.0 80.9 87.6
    InternVL-C (ours) 83.2 83.8 95.5 77.3 73.9 80.6
  • Multilingual Zero-Shot Image Classification

    EN: English, ZH: Chinese, JP: Japanese, AR: Arabic, IT: Italian

    method IN-1K (EN) IN-1K (ZH) IN-1K (JP) IN-1K (AR) IN-1K (IT)
    Taiyi-CLIP-ViT-H - 54.4 - - -
    WuKong-ViT-L-G - 57.5 - - -
    CN-CLIP-ViT-H - 59.6 - - -
    AltCLIP-ViT-L 74.5 59.6 - - -
    EVA-02-CLIP-E+ 82.0 - - - 41.2
    OpenCLIP-XLM-R-H 77.0 55.7 53.1 37.0 56.8
    InternVL-C (ours) 83.2 64.5 61.5 44.9 65.7
  • Zero-Shot Video Classification

    method #frame K400 K600 K700
    OpenCLIP-G 1 65.9 66.1 59.2
    EVA-02-CLIP-E+ 1 69.8 69.3 63.4
    InternVL-C (ours) 1 71.0 71.3 65.7
    ViCLIP 8 75.7 73.5 66.4
    InternVL-C (ours) 8 79.4 78.8 71.5
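
Zero-shot classification with a contrastive model such as InternVL-C is usually done with the CLIP recipe: embed the image and one text prompt per class, take cosine similarities, scale them by a temperature-like factor, and apply a softmax over classes. This corresponds to the `logits_per_image.softmax(dim=-1)` call in the Quick Start below. The sketch assumes InternVL-C follows that standard recipe, and it uses hypothetical 3-dimensional embeddings and a logit scale of 100 (the value commonly used by CLIP); neither is InternVL's actual value.

```python
import math


def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def zero_shot_probs(image_emb, class_text_embs, logit_scale=100.0):
    """CLIP-style zero-shot classification: scaled cosine similarity, then softmax."""
    img = l2_normalize(image_emb)
    logits = [
        logit_scale * sum(a * b for a, b in zip(img, l2_normalize(txt)))
        for txt in class_text_embs
    ]
    peak = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical toy embeddings: one image and three class prompts.
image = [1.0, 0.0, 0.0]
prompts = [[0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
probs = zero_shot_probs(image, prompts)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted, [round(p, 4) for p in probs])
```

Normalizing both sides makes the score a pure cosine similarity, so the scale factor alone controls how peaked the resulting distribution is.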
Cross-Modal Retrieval
  • English Zero-Shot Image-Text Retrieval

    model Flickr30K COCO avg
    image-to-text text-to-image image-to-text text-to-image
    R@1 R@5 R@10 R@1 R@5 R@10 R@1 R@5 R@10 R@1 R@5 R@10
    OpenCLIP-G 92.9 99.3 99.8 79.5 95.0 97.1 67.3 86.9 92.6 51.4 74.9 83.0 85.0
    EVA-02-CLIP-E+ 93.9 99.4 99.8 78.8 94.2 96.8 68.8 87.8 92.8 51.1 75.0 82.7 85.1
    EVA-CLIP-8B 95.6 99.6 99.9 80.8 95.5 97.6 70.3 89.3 93.9 53.0 76.0 83.4 86.2
    InternVL-C (ours) 94.7 99.6 99.9 81.7 96.0 98.2 70.6 89.0 93.5 54.1 77.3 84.6 86.6
    InternVL-G (ours) 95.7 99.7 99.9 85.0 97.0 98.6 74.9 91.3 95.2 58.6 81.3 88.0 88.8
  • Chinese Zero-Shot Image-Text Retrieval

    model Flickr30K-CN COCO-CN avg
    image-to-text text-to-image image-to-text text-to-image
    R@1 R@5 R@10 R@1 R@5 R@10 R@1 R@5 R@10 R@1 R@5 R@10
    CN-CLIP-ViT-H 81.6 97.5 98.8 71.2 91.4 95.5 63.0 86.6 92.9 69.2 89.9 96.1 86.1
    OpenCLIP-XLM-R-H 86.1 97.5 99.2 71.0 90.5 94.9 70.0 91.5 97.0 66.1 90.8 96.0 87.6
    InternVL-C (ours) 90.3 98.8 99.7 75.1 92.9 96.4 68.8 92.0 96.7 68.9 91.9 96.5 89.0
    InternVL-G (ours) 92.9 99.4 99.8 77.7 94.8 97.3 71.4 93.9 97.7 73.8 94.4 98.1 90.9
  • Multilingual Zero-Shot Image-Text Retrieval on XTD

    method EN ES FR ZH IT KO RU JP average
    AltCLIP 95.4 94.1 92.9 95.1 94.2 94.4 91.8 91.7 93.7
    OpenCLIP-XLM-R-H 97.3 96.1 94.5 94.7 96.0 90.2 93.9 94.0 94.6
    InternVL-C (ours) 97.3 95.7 95.1 95.6 96.0 92.2 93.3 95.5 95.1
    InternVL-G (ours) 98.6 97.7 96.5 96.7 96.9 95.1 94.8 96.1 96.6
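
The retrieval tables report Recall@K (R@K): the percentage of queries whose correct match appears among the top K candidates ranked by similarity. The avg column is the mean of the 12 recall values in each row; for example, the InternVL-G row of the English table averages to 88.77, shown as 88.8. The sketch below simplifies to one ground-truth candidate per query (query i matches candidate i). The actual Flickr30K and COCO protocols pair each image with several captions, so an image-to-text query there counts as a hit if any of its captions is in the top K.

```python
def recall_at_k(similarity, k):
    """Recall@K in percent; similarity[q][c] scores query q against candidate c.

    Simplifying assumption: the correct candidate for query q is candidate q.
    """
    hits = 0
    for q, row in enumerate(similarity):
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if q in top_k:
            hits += 1
    return 100.0 * hits / len(similarity)


# Hypothetical 3x3 similarity matrix (queries x candidates).
sim = [
    [0.9, 0.1, 0.3],
    [0.2, 0.4, 0.8],  # query 1 ranks candidate 2 above its true match
    [0.1, 0.2, 0.7],
]
print(recall_at_k(sim, 1), recall_at_k(sim, 2))

# The "avg" column: mean of the 12 R@K values of a row (InternVL-G, English table).
internvl_g = [95.7, 99.7, 99.9, 85.0, 97.0, 98.6, 74.9, 91.3, 95.2, 58.6, 81.3, 88.0]
print(round(sum(internvl_g) / len(internvl_g), 1))  # 88.8
```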
Multimodal Dialogue (see "Compared with SOTA VLLMs")

Installation

See INSTALLATION.md

Quick Start with Hugging Face

Using InternViT-6B
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

model = AutoModel.from_pretrained(
    'OpenGVLab/InternViT-6B-224px',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).cuda().eval()

image = Image.open('./examples/image1.jpg').convert('RGB')

image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-224px')

pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

outputs = model(pixel_values)
Using InternVL-C (contrastive) and InternVL-G (generative)
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
from transformers import AutoTokenizer


model = AutoModel.from_pretrained(
    'OpenGVLab/InternVL-14B-224px',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).cuda().eval()

image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternVL-14B-224px')

tokenizer = AutoTokenizer.from_pretrained(
    'OpenGVLab/InternVL-14B-224px', use_fast=False, add_eos_token=True)
tokenizer.pad_token_id = 0  # set pad_token_id to 0

images = [
    Image.open('./examples/image1.jpg').convert('RGB'),
    Image.open('./examples/image2.jpg').convert('RGB'),
    Image.open('./examples/image3.jpg').convert('RGB')
]
prefix = 'summarize:'
texts = [
    prefix + 'a photo of a red panda',  # English
    prefix + '一张熊猫的照片',  # Chinese: "a photo of a panda"
    prefix + '二匹の猫の写真'  # Japanese: "a photo of two cats"
]

pixel_values = image_processor(images=images, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()
input_ids = tokenizer(texts, return_tensors='pt', max_length=80,
                      truncation=True, padding='max_length').input_ids.cuda()

# InternVL-C
logits_per_image, logits_per_text = model(
    image=pixel_values, text=input_ids, mode='InternVL-C')
probs = logits_per_image.softmax(dim=-1)
# tensor([[9.9609e-01, 5.2185e-03, 6.0070e-08],
#         [2.2949e-02, 9.7656e-01, 5.9903e-06],
#         [3.2932e-06, 7.4863e-05, 1.0000e+00]], device='cuda:0',
#        dtype=torch.bfloat16, grad_fn=<SoftmaxBackward0>)

# InternVL-G
logits_per_image, logits_per_text = model(
    image=pixel_values, text=input_ids, mode='InternVL-G')
probs = logits_per_image.softmax(dim=-1)
# tensor([[9.9609e-01, 3.1738e-03, 3.6322e-08],
#         [8.6060e-03, 9.9219e-01, 2.8759e-06],
#         [1.7583e-06, 3.1233e-05, 1.0000e+00]], device='cuda:0',
#        dtype=torch.bfloat16, grad_fn=<SoftmaxBackward0>)

# please set add_eos_token to False for generation
tokenizer.add_eos_token = False
image = Image.open('./examples/image1.jpg').convert('RGB')
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

tokenized = tokenizer("English caption:", return_tensors='pt')
pred = model.generate(
    pixel_values=pixel_values,
    input_ids=tokenized.input_ids.cuda(),
    attention_mask=tokenized.attention_mask.cuda(),
    num_beams=5,
    min_new_tokens=8,
)
caption = tokenizer.decode(pred[0].cpu(), skip_special_tokens=True).strip()
# English caption: a red panda sitting on top of a wooden platform
Using InternVL-Chat
from transformers import AutoTokenizer, AutoModel
import torch
import torchvision.transforms as T
from PIL import Image

from torchvision.transforms.functional import InterpolationMode


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate candidate tile grids (cols, rows) with min_num <= cols * rows <= max_num
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


path = "OpenGVLab/InternVL-Chat-V1-5"
# If you have an 80 GB A100 GPU, you can put the entire model on a single GPU.
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda()
# Otherwise, you need to set device_map='auto' to use multiple GPUs for inference.
# model = AutoModel.from_pretrained(
#     path,
#     torch_dtype=torch.bfloat16,
#     low_cpu_mem_usage=True,
#     trust_remote_code=True,
#     device_map='auto').eval()

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
# set the max number of tiles in `max_num`
pixel_values = load_image('./examples/image1.jpg', max_num=6).to(torch.bfloat16).cuda()

generation_config = dict(
    num_beams=1,
    max_new_tokens=512,
    do_sample=False,
)

# single-round single-image conversation
question = "请详细描述图片" # Please describe the picture in detail
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(question, response)

# multi-round single-image conversation
question = "请详细描述图片" # Please describe the picture in detail
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(question, response)

question = "请根据图片写一首诗" # Please write a poem according to the picture
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
print(question, response)

# multi-round multi-image conversation
pixel_values1 = load_image('./examples/image1.jpg', max_num=6).to(torch.bfloat16).cuda()
pixel_values2 = load_image('./examples/image2.jpg', max_num=6).to(torch.bfloat16).cuda()
pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)

question = "详细描述这两张图片" # Describe the two pictures in detail
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(question, response)

question = "这两张图片的相同点和区别分别是什么" # What are the similarities and differences between these two pictures
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
print(question, response)

# batch inference (single image per sample)
pixel_values1 = load_image('./examples/image1.jpg', max_num=6).to(torch.bfloat16).cuda()
pixel_values2 = load_image('./examples/image2.jpg', max_num=6).to(torch.bfloat16).cuda()
image_counts = [pixel_values1.size(0), pixel_values2.size(0)]
pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)

questions = ["Describe the image in detail."] * len(image_counts)
responses = model.batch_chat(tokenizer, pixel_values,
                             image_counts=image_counts,
                             questions=questions,
                             generation_config=generation_config)
for question, response in zip(questions, responses):
    print(question)
    print(response)
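
Two quantities in the example above can be predicted without running the model: how many 448x448 tiles `load_image` returns for an image (including the thumbnail, since it passes `use_thumbnail=True`), and which rows of the concatenated `pixel_values` belong to each sample listed in `image_counts`. The sketch below re-implements only the grid-selection arithmetic of `find_closest_aspect_ratio` and `dynamic_preprocess` in plain Python (no PIL or torch), plus the offset bookkeeping for a batch; all helper names are hypothetical. With PyTorch, the same per-sample split is `torch.split(pixel_values, image_counts, dim=0)`. How `batch_chat` itself consumes `image_counts` is defined by the model's remote code and is not reproduced here.

```python
from itertools import accumulate


def candidate_grids(min_num=1, max_num=6):
    """All (cols, rows) grids with min_num <= cols * rows <= max_num, fewest tiles first."""
    grids = {(i, j) for n in range(min_num, max_num + 1)
             for i in range(1, n + 1) for j in range(1, n + 1)
             if min_num <= i * j <= max_num}
    return sorted(grids, key=lambda g: (g[0] * g[1], g))  # deterministic order


def closest_grid(width, height, grids, image_size=448):
    """Same rule as find_closest_aspect_ratio: nearest aspect ratio wins; on a tie, the
    later (larger) grid wins if the image covers more than half of that grid's area."""
    aspect_ratio = width / height
    best_diff, best = float('inf'), (1, 1)
    for cols, rows in grids:
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best_diff, best = diff, (cols, rows)
        elif diff == best_diff and width * height > 0.5 * image_size * image_size * cols * rows:
            best = (cols, rows)
    return best


def count_tiles(width, height, max_num=6, image_size=448, use_thumbnail=True):
    """Number of tiles load_image() would stack for an image of this size."""
    cols, rows = closest_grid(width, height, candidate_grids(1, max_num), image_size)
    tiles = cols * rows
    if use_thumbnail and tiles != 1:
        tiles += 1  # global thumbnail appended after the local tiles
    return tiles


def tile_ranges(image_counts):
    """[start, end) rows of each sample inside the concatenated pixel_values batch."""
    ends = list(accumulate(image_counts))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


counts = [count_tiles(1920, 1080), count_tiles(1000, 1000), count_tiles(448, 448)]
print(counts)               # [3, 5, 1]
print(tile_ranges(counts))  # [(0, 3), (3, 8), (8, 9)]
```

A 1920x1080 image maps to a 2x1 grid plus the thumbnail, while a 448x448 image stays a single tile and gets no thumbnail, which is why `image_counts` can differ from sample to sample.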

Chat Web Demo

Launch a local chat demo

Launch a controller

# run the command in the `internvl_chat_llava` folder
python -m llava.serve.controller --host 0.0.0.0 --port 10000

Launch a gradio web server

# run the command in the `internvl_chat_llava` folder
python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload

Launch a model worker

# OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B
# run the command in the `internvl_chat_llava` folder
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B

# OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B
# run the command in the `internvl_chat_llava` folder
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model-path OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B

# OpenGVLab/InternVL-Chat-V1-1
# run the command in the `internvl_chat` folder
python -m internvl.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40002 --worker http://localhost:40002 --model-path OpenGVLab/InternVL-Chat-V1-1

# OpenGVLab/InternVL-Chat-V1-2
# run the command in the `internvl_chat` folder
python -m internvl.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40003 --worker http://localhost:40003 --model-path OpenGVLab/InternVL-Chat-V1-2

# OpenGVLab/InternVL-Chat-V1-2-Plus
# run the command in the `internvl_chat` folder
python -m internvl.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40004 --worker http://localhost:40004 --model-path OpenGVLab/InternVL-Chat-V1-2-Plus

# OpenGVLab/InternVL-Chat-V1-5
# run the command in the `internvl_chat` folder
python -m internvl.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40005 --worker http://localhost:40005 --model-path OpenGVLab/InternVL-Chat-V1-5
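
The three launch steps above can also be scripted. The sketch below starts the controller, the Gradio web server, and a single InternVL-Chat-V1-5 worker as subprocesses, using exactly the commands and flags listed above. It assumes it is run from the repository root (so that `internvl_chat_llava` and `internvl_chat` are subfolders) and that a short fixed delay is enough for each process to come up before the next connects; `build_commands` and `launch` are hypothetical helper names.

```python
import subprocess
import sys
import time
from pathlib import Path

CONTROLLER_URL = "http://localhost:10000"


def build_commands(model_path="OpenGVLab/InternVL-Chat-V1-5", worker_port=40005):
    """Return (working folder, argv) pairs for the controller, web server, and one worker."""
    worker_url = f"http://localhost:{worker_port}"
    py = sys.executable  # same interpreter as the caller, instead of a bare "python"
    return [
        ("internvl_chat_llava",
         [py, "-m", "llava.serve.controller", "--host", "0.0.0.0", "--port", "10000"]),
        ("internvl_chat_llava",
         [py, "-m", "llava.serve.gradio_web_server",
          "--controller", CONTROLLER_URL, "--model-list-mode", "reload"]),
        ("internvl_chat",
         [py, "-m", "internvl.serve.model_worker", "--host", "0.0.0.0",
          "--controller", CONTROLLER_URL, "--port", str(worker_port),
          "--worker", worker_url, "--model-path", model_path]),
    ]


def launch(repo_root=".", startup_delay=5.0):
    """Start each process in its folder; the fixed delay is a crude readiness wait (assumption)."""
    processes = []
    for folder, argv in build_commands():
        processes.append(subprocess.Popen(argv, cwd=Path(repo_root) / folder))
        time.sleep(startup_delay)
    return processes
```

Calling `launch()` from the repository root returns the `Popen` handles, which can later be stopped with `terminate()`. Keeping command construction separate from process launching makes it easy to swap in another worker (for example a different `--model-path` and port from the list above) without touching the launch logic.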

License

This project is released under the MIT license. Parts of this project contain code and models from other sources, which are subject to their respective licenses.

Citation

If you find this project useful in your research, please consider citing:

@article{chen2023internvl,
  title={InternVL: Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks},
  author={Chen, Zhe and Wu, Jiannan and Wang, Wenhai and Su, Weijie and Chen, Guo and Xing, Sen and Zhong, Muyan and Zhang, Qinglong and Zhu, Xizhou and Lu, Lewei and Li, Bin and Luo, Ping and Lu, Tong and Qiao, Yu and Dai, Jifeng},
  journal={arXiv preprint arXiv:2312.14238},
  year={2023}
}

@article{chen2024far,
  title={How Far Are We to GPT-4V? Closing the Gap to Commercial Multimodal Models with Open-Source Suites},
  author={Chen, Zhe and Wang, Weiyun and Tian, Hao and Ye, Shenglong and Gao, Zhangwei and Cui, Erfei and Tong, Wenwen and Hu, Kongzhi and Luo, Jiapeng and Ma, Zheng and others},
  journal={arXiv preprint arXiv:2404.16821},
  year={2024}
}

Acknowledgement

InternVL is built with reference to the code of the following projects: OpenAI CLIP, OpenCLIP, CLIP Benchmark, EVA, InternImage, ViT-Adapter, MMSegmentation, Transformers, DINOv2, BLIP-2, Qwen-VL, and LLaVA-1.5. Thanks for their awesome work!


If you want to join our WeChat group, please scan the following QR code to add our assistant as a WeChat friend:

