From 7cbd9ec7a9bfd4952ad522355b6bbb8e82b54fc9 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 29 Jul 2024 18:16:30 +0800 Subject: [PATCH 01/79] [Model] Initialize support for InternVL2 series models (#6514) Co-authored-by: Roger Wang --- docs/source/models/supported_models.rst | 4 + examples/offline_inference_vision_language.py | 15 + examples/openai_vision_api_client.py | 2 + requirements-test.txt | 1 + tests/models/test_internvl.py | 201 ++++++++ vllm/entrypoints/chat_utils.py | 2 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/intern_vit.py | 270 ++++++++++ vllm/model_executor/models/internlm2.py | 10 +- vllm/model_executor/models/internvl.py | 471 ++++++++++++++++++ vllm/model_executor/models/qwen2.py | 10 +- vllm/transformers_utils/config.py | 8 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/internvl.py | 51 ++ 14 files changed, 1042 insertions(+), 6 deletions(-) create mode 100644 tests/models/test_internvl.py create mode 100644 vllm/model_executor/models/intern_vit.py create mode 100644 vllm/model_executor/models/internvl.py create mode 100644 vllm/transformers_utils/configs/internvl.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 83c1b9c8bce86..4fe33e5ab5d80 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -200,6 +200,10 @@ Vision Language Models - Fuyu - :code:`adept/fuyu-8b` etc. - + * - :code:`InternVLChatModel` + - InternVL2 + - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. + - * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 04ba1a96314c9..846246a2062a6 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -106,6 +106,20 @@ def run_minicpmv(question): return llm, prompt +# InternVL +def run_internvl(question): + # Generally, InternVL can use chatml template for conversation + TEMPLATE = "<|im_start|>User\n{prompt}<|im_end|>\n<|im_start|>Assistant\n" + prompt = f"<image>\n{question}\n" + prompt = TEMPLATE.format(prompt=prompt) + llm = LLM( + model="OpenGVLab/InternVL2-4B", + trust_remote_code=True, + max_num_seqs=5, + ) + return llm, prompt + + # BLIP-2 def run_blip2(question): @@ -125,6 +139,7 @@ def run_blip2(question): "chameleon": run_chameleon, "minicpmv": run_minicpmv, "blip-2": run_blip2, + "internvl_chat": run_internvl, } diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 2082c378e267c..be90394511f89 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -42,6 +42,7 @@ ], }], model=model, + max_tokens=64, ) result = chat_completion_from_url.choices[0].message.content @@ -78,6 +79,7 @@ def encode_image_base64_from_url(image_url: str) -> str: ], }], model=model, + max_tokens=64, ) result = chat_completion_from_base64.choices[0].message.content diff --git a/requirements-test.txt b/requirements-test.txt index a7604d2e1015e..9b88fcce3e842 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -16,6 +16,7 @@ ray sentence-transformers # required for embedding sparseml==1.8.0 # required for compressed-tensors compressed-tensors==0.4.0 # required for compressed-tensors +timm # required for internvl test # Benchmarking aiohttp 
diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py new file mode 100644 index 0000000000000..66cb8dda248db --- /dev/null +++ b/tests/models/test_internvl.py @@ -0,0 +1,201 @@ +import types +from typing import List, Optional, Type + +import pytest +import torch +from huggingface_hub import snapshot_download +from PIL.Image import Image + +from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END, + IMG_START, + image_to_pixel_values) +from vllm.multimodal.utils import rescale_image_size +from vllm.utils import is_cpu + +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from .utils import check_logprobs_close + +pytestmark = pytest.mark.vlm + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + "cherry_blossom": + "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 +}) + +# we use snapshot_download to prevent conflicts between +# dynamic_module and trust_remote_code for hf_runner +models = [ + snapshot_download("OpenGVLab/InternVL2-1B"), + snapshot_download("OpenGVLab/InternVL2-2B"), + # snapshot_download("OpenGVLab/InternVL2-4B"), # broken +] + + +class InternVLProcessor: + """A simple processor for InternVL2 HF model which misses a processor.""" + + def __init__(self, hf_runner: HfRunner): + self.num_image_token = hf_runner.model.num_image_token + self.tokenizer = hf_runner.tokenizer + self.dtype = hf_runner.model.dtype + + def __call__(self, text: str, images: Image, **kwargs): + pixel_values = image_to_pixel_values(images).to(self.dtype) + num_patches_list = [pixel_values.shape[0]] + for num_patches in num_patches_list: + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches + image_tokens = IMG_START + context_tokens + IMG_END + text = text.replace('<image>', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") + prompt.update({"pixel_values": pixel_values}) + return prompt + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py +def generate( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.FloatTensor, + attention_mask: Optional[torch.LongTensor] = None, + **generate_kwargs, +) -> torch.LongTensor: + """Generate method for InternVL2 model without fixed use_cache.""" + assert self.img_context_token_id is not None + vit_embeds = self.extract_feature(pixel_values) + input_embeds = self.language_model.get_input_embeddings()(input_ids) + B, N, C = input_embeds.shape + input_embeds = input_embeds.reshape(B * N, C) + + input_ids = input_ids.reshape(B * N) + selected = (input_ids == self.img_context_token_id) + assert selected.sum() != 0 + input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) + + input_embeds = input_embeds.reshape(B, N, C) + + outputs = self.language_model.generate( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + **generate_kwargs, + ) + + return outputs + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. 
+ For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + max_model_len=4096, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + with hf_runner(model, dtype=dtype) as hf_model: + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( + "<IMG_CONTEXT>") + hf_model.model.img_context_token_id = img_context_token_id + hf_model.processor = InternVLProcessor(hf_model) + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.language_model.get_output_embeddings() + hf_model.model.generate = types.MethodType(generate, hf_model.model) + eos_token_id = hf_model.tokenizer.eos_token_id + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=hf_images, + eos_token_id=eos_token_id) + for prompts, hf_images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + # TODO: Check whether using original CLIPVisionModel can improve + # consistency against HF + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +target_dtype = "half" +if is_cpu(): + target_dtype = "bfloat16" + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +@torch.inference_mode() +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype: str, max_tokens: int, num_logprobs: int) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 1f6d77b828459..fbb7f70b55e16 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -107,7 +107,7 @@ def _image_token_str(model_config: ModelConfig, return None if model_type.startswith("llava"): return tokenizer.decode(model_config.hf_config.image_token_index) - if model_type == "chameleon": + if model_type in ("chameleon", "internvl_chat"): return "<image>" raise TypeError(f"Unknown model type: {model_type}") diff --git 
a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index fe04c6db5fbc2..94c3cea98be7b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -37,6 +37,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "InternVLChatModel": ("internvl", "InternVLChatModel"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlavaForConditionalGeneration": diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py new file mode 100644 index 0000000000000..86d0930d80126 --- /dev/null +++ b/vllm/model_executor/models/intern_vit.py @@ -0,0 +1,270 @@ +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig + +NORM2FN = { + 'rms_norm': RMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False) + pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + target_dtype)) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, + width) + ], + dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternAttention(nn.Module): + """Multi-headed attention from 
'Attention Is All You Need' paper""" + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + self.scale = self.head_dim**-0.5 + self.qkv = nn.Linear(self.embed_dim, + 3 * self.embed_dim, + bias=config.qkv_bias) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if self.qk_normalization: + B_, H_, N_, D_ = q.shape + q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view( + B_, N_, H_, D_).transpose(1, 2) + k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view( + B_, N_, H_, D_).transpose(1, 2) + + x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + x = x.transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + return x + + +class InternMLP(nn.Module): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = InternAttention(config) + self.mlp = InternMLP(config, quant_config=quant_config) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + + def forward( + self, + hidden_states: torch.Tensor, + ): + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states)) * self.ls1 + + hidden_states = hidden_states + self.mlp( + self.norm2(hidden_states)) * self.ls2 + + return hidden_states + + +class InternVisionEncoder(nn.Module): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList([ + 
InternVisionEncoderLayer(config=config, quant_config=quant_config) + for _ in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class InternVisionModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + self.config = config + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override) + + def resize_pos_embeddings(self, old_size, new_size, patch_size): + pos_emb = self.embeddings.position_embedding + _, num_positions, embed_dim = pos_emb.shape + cls_emb = pos_emb[:, :1, :] + pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, + old_size // patch_size, + -1).permute(0, 3, 1, 2) + pos_emb = F.interpolate(pos_emb.float(), + size=new_size // patch_size, + mode='bicubic', + align_corners=False) + pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, + -1).permute(0, 2, 1) + pos_emb = torch.cat([cls_emb, pos_emb], dim=1) + self.embeddings.position_embedding = nn.Parameter(pos_emb) + self.embeddings.image_size = new_size + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if pixel_values is None and pixel_embeds is None: + raise ValueError( + 'You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + elif pixel_values is not None: + if pixel_values.ndim == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError( + f'wrong pixel_values size: {pixel_values.shape}') + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + return encoder_outputs diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 22132f40fc5e6..745fbf99a902d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -219,14 +219,22 @@ def __init__( ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: IntermediateTensors = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.tok_embeddings(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.tok_embeddings(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py new file mode 100644 index 0000000000000..f64c78c15f8ee --- /dev/null +++ b/vllm/model_executor/models/internvl.py @@ -0,0 +1,471 @@ +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# 
-------------------------------------------------------- +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from transformers import PretrainedConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.intern_vit import InternVisionModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import cached_get_tokenizer +from vllm.sequence import IntermediateTensors, SamplerOutput + +from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, + get_clip_num_patches) +from .interfaces import SupportsVision +from .utils import merge_vision_embeddings + +IMG_START = '<img>' +IMG_END = '</img>' +IMG_CONTEXT = '<IMG_CONTEXT>' + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + +MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000 +MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500 + + +class InternVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: BatchedTensors + """ + Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different for each batch, in which case + the data is passed as a list instead of a batched tensor. + """ + + +# copied from https://huggingface.co/OpenGVLab/InternVL2-1B +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + return transform + + +# copied from https://huggingface.co/OpenGVLab/InternVL2-1B +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, + image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def calculate_num_blocks(orig_width: int, + orig_height: int, + min_num=1, + max_num=6, + image_size=448): + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for n in range(min_num, max_num + 1) + for i in range(1, n + 1) for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, + target_ratios, orig_width, + orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + return blocks, 
target_width, target_height + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def dynamic_preprocess(image, + min_num=1, + max_num=6, + image_size=448, + use_thumbnail=False): + orig_width, orig_height = image.size + + blocks, target_width, target_height = calculate_num_blocks( + orig_width, orig_height, min_num, max_num, image_size) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ((i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6): + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, + image_size=input_size, + use_thumbnail=True, + max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + +def get_internvl_num_patches(image_size: int, patch_size: int, + downsample_ratio: float): + return int( + get_clip_num_patches(image_size=image_size, patch_size=patch_size) * + (downsample_ratio**2)) + + +def get_max_internvl_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config(PretrainedConfig) + vision_config = hf_config.vision_config + image_size = vision_config.image_size + patch_size = vision_config.patch_size + downsample_ratio = hf_config.downsample_ratio + num_patches = get_internvl_num_patches(image_size, patch_size, + downsample_ratio) + return num_patches * 7 + + +def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(PretrainedConfig) + vision_config = hf_config.vision_config + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + width, height = image_data.size + num_blocks, _, _ = calculate_num_blocks(width, height) + elif isinstance(image_data, torch.Tensor): + raise NotImplementedError("Embeddings input is not supported yet") + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + image_size = vision_config.image_size + patch_size = vision_config.patch_size + downsample_ratio = hf_config.downsample_ratio + num_patches = get_internvl_num_patches(image_size, patch_size, + downsample_ratio) + + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + + prompt = llm_inputs["prompt"] + prompt_token_ids = llm_inputs["prompt_token_ids"] + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + image_prompt = IMG_START + IMG_CONTEXT * (num_blocks + + 1) * num_patches + IMG_END + new_prompt = prompt.replace('<image>', image_prompt, 1) + new_prompt_token_ids = tokenizer.encode(new_prompt) + + return LLMInputs(prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) + + +def input_mapper_for_internvl(ctx: InputContext, 
data: object): + if isinstance(data, Image.Image): + data = image_to_pixel_values(data) + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + image_token_id = tokenizer.encode(IMG_CONTEXT, + add_special_tokens=False, + return_tensors="pt")[0] + + return MultiModalInputs({ + "pixel_values": data, + "image_token_id": image_token_id + }) + + +def dummy_data_for_internvl(ctx: InputContext, seq_len: int): + + image_feature_size = get_max_internvl_image_tokens(ctx) + model_config = ctx.model_config + hf_config = ctx.get_hf_config(PretrainedConfig) + vision_config = hf_config.vision_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + image_token_id=tokenizer.encode(IMG_CONTEXT, + add_special_tokens=False)[0], + image_feature_size_override=image_feature_size, + ) + mm_data = dummy_image_for_clip( + vision_config, + image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) + + return seq_data, mm_data + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) +@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) +class InternVLChatModel(nn.Module, SupportsVision): + + def __init__(self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + image_size = config.force_image_size or config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.patch_size = patch_size + self.select_layer = config.select_layer + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + + vision_feature_layer = self.select_layer + if vision_feature_layer < 0: + num_hidden_layers = config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + self.vision_model = InternVisionModel( + config.vision_config, num_hidden_layers_override=num_hidden_layers) + + llm_class = ModelRegistry.load_model_cls( + config.text_config.architectures[0]) + self.language_model = llm_class(config.text_config, cache_config, + quant_config) + + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.text_config.hidden_size + + self.mlp1 = nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + llm_hidden_size), nn.GELU(), + nn.Linear(llm_hidden_size, llm_hidden_size)) + + self.img_context_token_id = None + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + pass + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def 
extract_feature(self, pixel_values): + vit_embeds = self.vision_model(pixel_values=pixel_values) + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + if list(data.shape[1:]) != [2]: + raise ValueError( + f"The expected image sizes shape is batch dimension plus " + f"{[2]}. You supplied {data.shape}.") + + return data + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values in each batch element " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[InternVLImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_token_id = kwargs.pop("image_token_id", None) + + if pixel_values is None: + return None + + self.img_context_token_id = image_token_id[0] + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return InternVLImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> SamplerOutput: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is not None: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + vit_embeds = self.extract_feature(image_input["data"]) + inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, + vit_embeds, + self.img_context_token_id) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".w1", 0), + (".gate_up_proj", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if 
self.config.text_config.tie_word_embeddings \ + and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # We only do sharding for language model + # and not vision model for now. + if "vision_embed_tokens" in name and self.vision_embed_tokens: + continue + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + if "wqkv" in name: + config = self.config.text_config + kv_groups = (config.num_attention_heads // + config.num_key_value_heads) + head_dim = config.hidden_size // config.num_attention_heads + loaded_weight = loaded_weight.view(-1, 2 + kv_groups, + head_dim, + loaded_weight.shape[-1]) + wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], + dim=1) + wq = wq.reshape(-1, wq.shape[-1]) + wk = wk.reshape(-1, wk.shape[-1]) + wv = wv.reshape(-1, wv.shape[-1]) + weight_loader = param.weight_loader + weight_loader(param, wq, 'q') + weight_loader(param, wk, 'k') + weight_loader(param, wv, 'v') + continue + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e9aa4416eded4..3deb3d8840cc4 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -243,14 +243,22 @@ def __init__( ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3ba2e01985598..3d13631b9b2b6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,9 +6,10 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, RWConfig) + InternVLChatConfig, JAISConfig, + MedusaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + RWConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -26,6 +27,7 @@ "jais": JAISConfig, "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, + "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 1750950b3c38b..5ccacd4a4c40a 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -4,6 +4,7 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig @@ -15,6 +16,7 @@ "DbrxConfig", "MPTConfig", "RWConfig", + "InternVLChatConfig", "JAISConfig", "MedusaConfig", "MLPSpeculatorConfig", diff --git a/vllm/transformers_utils/configs/internvl.py b/vllm/transformers_utils/configs/internvl.py new file mode 100644 index 0000000000000..ac2492317aa36 --- /dev/null +++ b/vllm/transformers_utils/configs/internvl.py @@ -0,0 +1,51 @@ +# Adapted from +# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from transformers.configuration_utils import PretrainedConfig + + +class InternVLChatConfig(PretrainedConfig): + model_type = 'internvl_chat' + is_composition = True + + def __init__(self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + select_layer=-1, + force_image_size=None, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + ps_version='v1', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + + if llm_config is None: + llm_config = {} + + self.vision_config = PretrainedConfig(**vision_config) + self.text_config = PretrainedConfig(**llm_config) + + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.select_layer = select_layer + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch From 766435e660a786933392eb8ef0a873bc38cf0c8b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 29 Jul 2024 11:42:35 -0400 Subject: [PATCH 02/79] [Kernel] Tuned FP8 Kernels for Ada Lovelace (#6677) Co-authored-by: Varun Sundar Rabindranath --- .../cutlass_benchmarks/w8a8_benchmarks.py | 2 +- .../cutlass_w8a8/scaled_mm_c2x.cu | 520 ++---------------- .../cutlass_w8a8/scaled_mm_c2x.cuh | 340 ++++++++++++ .../scaled_mm_c2x_sm80_dispatch.cuh | 139 +++++ .../scaled_mm_c2x_sm89_dispatch.cuh | 362 ++++++++++++ tests/kernels/test_cutlass.py | 4 +- 6 files changed, 877 insertions(+), 490 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_dispatch.cuh diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 234c2c8a1074c..70247e94e63cf 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -13,7 +13,7 @@ from vllm import _custom_ops as ops from vllm.utils import FlexibleArgumentParser -DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_BATCH_SIZES = [1, 
16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 6ce25c5ac897b..d26c43de522c9 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -1,470 +1,16 @@ #include #include - -#include - -// clang-format will break include orders -// clang-format off -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/util/device_memory.h" - #include "cutlass/cutlass.h" -#include "cutlass/gemm_coord.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/device/gemm.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" -#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" - -#include "broadcast_load_epilogue_c2x.hpp" -#include "common.hpp" -// clang-format on - -using namespace cute; +#include "scaled_mm_c2x.cuh" +#include "scaled_mm_c2x_sm80_dispatch.cuh" +#include "scaled_mm_c2x_sm89_dispatch.cuh" /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm80EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. */ -namespace { - -// Wrappers for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm75_to_sm80 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -template -struct enable_sm80_to_sm89 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -template -struct enable_sm89_to_sm90 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -/* - * This class provides the common ScaleA and ScaleB descriptors for the - * ScaledEpilogue and ScaledEpilogueBias classes. - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; - - using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. 
- A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::ScaleA; - using ScaleB = typename SUPER::ScaleB; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - using ScaleAArgs = typename ScaleA::Arguments; - using ScaleBArgs = typename ScaleB::Arguments; - - ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; - ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; - - typename EVTCompute0::Arguments evt0_compute_args{b_args}; - - typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args}; - return evt_compute_args; - } -}; - -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::ScaleA; - using ScaleB = typename SUPER::ScaleB; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, ElementD, Stride, Int<1>, Int<0>>>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - using ScaleAArgs = typename ScaleA::Arguments; - using ScaleBArgs = typename ScaleB::Arguments; - using BiasArgs = typename Bias::Arguments; - - ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; - ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; - BiasArgs bias_args{static_cast(bias.data_ptr()), {}}; - - typename EVTCompute0::Arguments evt0_compute_args{b_args}; - - typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args, - bias_args}; - return evt_compute_args; - } -}; - -template typename ArchGuard, - typename ElementAB_, typename ElementD_, - template typename Epilogue_, typename TileShape, - typename WarpShape, typename InstructionShape, int32_t MainLoopStages> -struct cutlass_2x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using 
Operator = - typename std::conditional, - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type; - - using OutputTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - TileShape, WarpShape, float, 4, 1 /* epilogue stages */ - >; - - using Epilogue = Epilogue_; - using EVTCompute = typename Epilogue::EVTCompute; - - using D = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, - Stride, Int<0>>>; - - using EVTD = cutlass::epilogue::threadblock::Sm80EVT; - - // clang-format off - using RowMajor = typename cutlass::layout::RowMajor; - using ColumnMajor = typename cutlass::layout::ColumnMajor; - using KernelType = - ArchGuard::GemmKernel>; - // clang-format on - - using Op = cutlass::gemm::device::GemmUniversalAdapter; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - cutlass::gemm::GemmCoord problem_size{m, n, k}; - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideC = Stride, Int<0>>; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto c_ptr = static_cast(out.data_ptr()); - - typename Gemm::D::Arguments d_args{c_ptr, c_stride}; - - using Epilogue = typename Gemm::Epilogue; - auto evt_args = - Epilogue::prepare_args(std::forward(epilogue_params)...); - - typename Gemm::EVTD::Arguments epilogue_args{ - evt_args, - d_args, - }; - - typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode - problem_size, // problem size - 1, // batch count - epilogue_args, - a_ptr, - b_ptr, - nullptr, - nullptr, - 0, - 0, - 0, - 0, - lda, - ldb, - ldc, - ldc}; - - // Launch the CUTLASS GEMM kernel. - typename Gemm::Op gemm_op; - size_t workspace_size = gemm_op.get_workspace_size(args); - cutlass::device_memory::allocation workspace(workspace_size); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - CUTLASS_CHECK(gemm_op.can_implement(args)); - cutlass::Status status = gemm_op(args, workspace.get(), stream); - CUTLASS_CHECK(status); -} - -template -void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - // In some cases, the GPU isn't able to accommodate the - // shared memory requirements of the Gemm. In such cases, use - // the FallbackGemm instead. 
- static const int max_shared_mem_per_block_opt_in = - get_cuda_max_shared_memory_per_block_opt_in(0); - - size_t const gemm_shared_mem_size = - sizeof(typename Gemm::KernelType::SharedStorage); - size_t const fallback_gemm_shared_mem_size = - sizeof(typename FallbackGemm::KernelType::SharedStorage); - - if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) { - return cutlass_gemm_caller(out, a, b, - std::forward(args)...); - } else { - TORCH_CHECK(fallback_gemm_shared_mem_size <= - max_shared_mem_per_block_opt_in); - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - -template typename Epilogue> -struct sm80_config_default { - // This config is used in 2 cases, - // - M in (128, inf) - // - M in (64, 128] and N >= 8192 - // Shared Memory required by this Gemm - 81920 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; - using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M64 { - // This config is used in 2 cases, - // - M in (32, 64] - // - M in (64, 128] and N < 8192 - // Shared Memory required by this Gemm - 122880 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M32 { - // M in (16, 32] - // Shared Memory required by this Gemm - 61440 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M16 { - // M in [1, 16] - // Shared Memory required by this Gemm - 51200 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -} // namespace - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(b.dtype() == torch::kInt8); - - using Cutlass2xGemmDefault = - typename sm80_config_default::Cutlass2xGemm; - using Cutlass2xGemmM128BigN = - typename sm80_config_default::Cutlass2xGemm; - using Cutlass2xGemmM128SmallN = - typename sm80_config_M64::Cutlass2xGemm; - using Cutlass2xGemmM64 = - typename sm80_config_M64::Cutlass2xGemm; - using Cutlass2xGemmM32 = - typename sm80_config_M32::Cutlass2xGemm; - using Cutlass2xGemmM16 = - typename sm80_config_M16::Cutlass2xGemm; - - // Due to shared memory requirements, some Gemms may fail to run on some - // GPUs. As the name indicates, the Fallback Gemm is used as an alternative - // in such cases. - // sm80_config_M16 has the least shared-memory requirement. 
However, - // based on some profiling, we select sm80_config_M32 as a better alternative - // performance wise. - using FallbackGemm = - typename sm80_config_M32::Cutlass2xGemm; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(16), next_pow_2(m)); // next power of 2 - if (mp2 <= 16) { - // M in [1, 16] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 32) { - // M in (16, 32] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 64) { - // M in (32, 64] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // M in (64, 128] - uint32_t const n = out.size(1); - bool const small_n = n < 8192; - if (small_n) { - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } - } else { - // M in (128, inf) - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - template