From c7cb5c333564cb00fc4f6a99d32c35e9ebc0f1ed Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 9 Sep 2024 16:27:26 -0400
Subject: [PATCH 01/81] [Misc] GPTQ Activation Ordering (#8135)

---
 tests/weight_loading/models.txt               |  1 +
 .../compressed_tensors/compressed_tensors.py  |  3 +-
 .../schemes/compressed_tensors_wNa16.py       | 45 ++++++++++++++-----
 .../quantization/compressed_tensors/utils.py  | 30 ++++++++++++-
 4 files changed, 64 insertions(+), 15 deletions(-)

diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index 1dc529037a98e..c708e6d5eb897 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 0768b37044aac..1170d55f31993 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -232,7 +232,8 @@ def _get_scheme_from_parts(
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
-                group_size=weight_quant.group_size)
+                group_size=weight_quant.group_size,
+                actorder=weight_quant.actorder)

         # Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7ca8eecb9283e..8897737c1c55a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -5,14 +5,18 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + ActivationOrdering) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, - PackedvLLMParameter) + PackedvLLMParameter, + RowvLLMParameter) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] @@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): def __init__(self, strategy: str, num_bits: int, - group_size: Optional[int] = None): + group_size: Optional[int] = None, + actorder: Optional[ActivationOrdering] = None): self.pack_factor = 32 // num_bits self.strategy = strategy self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP if self.group_size == -1 and self.strategy != "channel": raise ValueError("Marlin kernels require group quantization or " @@ -64,12 +70,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, output_size_per_partition = sum(output_partition_sizes) # If group_size is -1, we are in channelwise case. - channelwise = (self.group_size == -1) group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) - # In the case of channelwise quantization, we need to replicate the - # scales across all gpus. - partition_scales = (row_parallel and not channelwise) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) verify_marlin_supports_shape( output_size_per_partition=output_size_per_partition, @@ -123,6 +127,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size @@ -137,9 +151,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.workspace = marlin_make_workspace( layer.output_size_per_partition, device) - # Act-order not supported in compressed-tensors yet, so set to empty. 
- layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + # Handle sorting for activation reordering if needed. + if self.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + replace_tensor(layer, "weight_g_idx", g_idx) + else: + layer.weight_g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) # No zero-point layer.weight_zp = marlin_make_empty_g_idx(device) @@ -159,9 +178,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_tensor(layer, "weight_packed", marlin_qweight) # Permute scales from compressed-tensors format to marlin format. + # scale is required on all partitions if activation reordering marlin_scales = marlin_permute_scales( layer.weight_scale, - size_k=layer.input_size_per_partition, + size_k=(layer.input_size + if self.has_g_idx else layer.input_size_per_partition), size_n=layer.output_size_per_partition, group_size=layer.group_size) replace_tensor(layer, "weight_scale", marlin_scales) @@ -174,7 +195,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, weight=layer.weight_packed, weight_scale=layer.weight_scale, weight_zp=layer.weight_zp, - g_idx=layer.g_idx, + g_idx=layer.weight_g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, wtype=self.quant_type, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 7912cbde5721f..fc531b9d666e3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,8 +1,8 @@ import re from enum import Enum -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from torch.nn import Module from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum): TOKEN = "token" +class ActivationOrdering(str, Enum): + """ + Enum storing strategies for activation ordering + + Group: reorder groups and weight\n + Weight: only reorder weight, not groups. Slightly lower latency and + accuracy compared to group actorder\n + """ + + GROUP = "group" + WEIGHT = "weight" + + class QuantizationArgs(BaseModel): """ User facing arguments used to define a quantization config @@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel): observed with every sample. Defaults to False for static quantization. Note that enabling dynamic quantization will change the default observer to a memoryless one + :param actorder: whether to apply group quantization in decreasing order of + activation. 
Defaults to None for arbitrary ordering """ num_bits: int = 8 @@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel): strategy: Optional[QuantizationStrategy] = None block_structure: Optional[str] = None dynamic: bool = False + actorder: Union[ActivationOrdering, bool, None] = None observer: str = Field( default="minmax", description=("The class to use to compute the quantization param - " @@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel): "Observers constructor excluding quantization range or symmetry"), ) + @field_validator("actorder", mode="before") + def validate_actorder(cls, value) -> Optional[ActivationOrdering]: + if isinstance(value, bool): + return ActivationOrdering.GROUP if value else None + + if isinstance(value, str): + return ActivationOrdering(value.lower()) + + return value + def is_activation_quantization_format(format: str) -> bool: _ACTIVATION_QUANTIZATION_FORMATS = [ From 6cd5e5b07e4415d064d93b8a66331a097bd9287e Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 9 Sep 2024 23:02:52 -0400 Subject: [PATCH 02/81] [Misc] Fused MoE Marlin support for GPTQ (#8217) --- .buildkite/test-pipeline.yaml | 13 +- csrc/moe/marlin_moe_ops.cu | 2 +- csrc/moe/marlin_moe_ops.h | 2 +- csrc/moe/torch_bindings.cpp | 1 - tests/kernels/test_moe.py | 221 ++++++++++++- tests/weight_loading/models-large.txt | 3 + tests/weight_loading/models.txt | 2 - .../layers/fused_moe/__init__.py | 14 +- .../layers/fused_moe/fused_marlin_moe.py | 219 ++++++++++++ .../layers/fused_moe/fused_moe.py | 138 ++------ vllm/model_executor/layers/fused_moe/layer.py | 75 +++-- .../compressed_tensors_moe.py | 48 +-- .../schemes/compressed_tensors_wNa16.py | 2 +- .../layers/quantization/gptq_marlin.py | 312 +++++++++++++++++- .../layers/quantization/utils/marlin_utils.py | 17 + .../quantization/utils/marlin_utils_test.py | 11 +- .../layers/quantization/utils/quant_utils.py | 19 +- vllm/model_executor/model_loader/utils.py | 8 + vllm/model_executor/models/mixtral.py | 9 +- 19 files changed, 912 insertions(+), 204 deletions(-) create mode 100644 tests/weight_loading/models-large.txt create mode 100644 vllm/model_executor/layers/fused_moe/fused_marlin_moe.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d0317b2fc48c9..a0c7b7442b3b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -386,7 +386,18 @@ steps: - vllm/ - tests/weight_loading commands: - - bash weight_loading/run_model_weight_loading_test.sh + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 1e170e80d2f70..92184f43c9eb0 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe( moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; -} \ No newline at end of file +} diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 01ba8ff69850d..43d264e0770d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& 
g_idx, const torch::Tensor& perm, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); \ No newline at end of file + bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d4d43e2c601b5..8a0e625b43fa1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " "bool replicate_input, bool apply_weights) -> Tensor"); - m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f526c381b3339..2250cf1598b8b 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,6 +2,8 @@ Run `pytest tests/kernels/test_moe.py`. """ +from typing import List + import pytest import torch from transformers import MixtralConfig @@ -9,7 +11,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.scalar_type import scalar_types def torch_moe(a, w1, w2, score, topk): @@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) + + @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -43,11 +65,11 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0) @@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def compute_max_diff(output, output_ref): + return 
torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +def test_fused_marlin_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, +): + torch.manual_seed(7) + + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size in (k, n): + return + + quant_type = scalar_types.uint4b8 + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + for i in range(w2.shape[0]): + w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) + + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref1_l.append(w_ref1) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + g_idx1 = stack_and_dev(g_idx1_l) + sort_indices1 = stack_and_dev(sort_indices1_l) + + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref2_l.append(w_ref2) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + g_idx2 = stack_and_dev(g_idx2_l) + sort_indices2 = stack_and_dev(sort_indices2_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, False) + + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_marlin_moe( + a, + qweight1, + qweight2, + score, + g_idx1, + g_idx2, + sort_indices1, + sort_indices2, + topk_weights, + topk_ids, + w1_scale=scales1, + w2_scale=scales2, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 + + +@pytest.mark.skip("This test is here for the sake of debugging, " + "don't run it in automated tests.") +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +def test_marlin_moe_mmm( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, +): + if topk > e: + return + + # Filter 
act_order + if act_order: + if group_size == -1: + return + if group_size == k: + return + + quant_type = scalar_types.uint4b8 + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + + w_ref_l = [] + qweights_l = [] + scales_l = [] + g_idx_l = [] + sort_indices_l = [] + + for i in range(w.shape[0]): + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) + w_ref_l.append(w_ref) + qweights_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweights_l).contiguous() + scales = stack_and_dev(scales_l) + g_idx = stack_and_dev(g_idx_l) + sort_indices = stack_and_dev(sort_indices_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + marlin_output = single_marlin_moe(a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False) + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt new file mode 100644 index 0000000000000..fe76705746766 --- /dev/null +++ b/tests/weight_loading/models-large.txt @@ -0,0 +1,3 @@ +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index c708e6d5eb897..a90b352a39bca 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index fd6f41b90042e..e9b5703ca28be 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -2,16 +2,22 @@ FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", +] if HAS_TRITON: - + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_marlin_moe, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ += [ "fused_marlin_moe", + "single_marlin_moe", "fused_moe", "fused_topk", 
"fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py new file mode 100644 index 0000000000000..200a6148978aa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -0,0 +1,219 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Any, Dict, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size, try_get_optimal_moe_config) + + +def single_marlin_moe( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: + """ + This function computes the multiplication of hidden_states with expert + weights used in Marlin MoE, using weights w and top-k gating mechanism. + Its purpose is testing and debugging the fused MoE kernel. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. + - w (torch.Tensor): The set of expert weights. + - scales (torch.Tensor): The quantization scales. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx (torch.Tensor): The act_order indices. + - perm (torch.Tensor): The act_order input permutation. + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
+ assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype == torch.float16 + + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // 2 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, + False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + perm1: torch.Tensor, + perm2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + override_config: Optional[Dict[str, Any]] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx1 (torch.Tensor): The first set of act_order indices. + - g_idx2 (torch.Tensor): The second set of act_order indices. + - perm1 (torch.Tensor): The first act_order input permutation. + - perm2 (torch.Tensor): The second act_order input permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of topk-k elements. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
+ assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype == torch.float16 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + topk = topk_ids.shape[1] + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + g_idx1, + perm1, + workspace, + M, + 2 * N, + K, + True, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + g_idx2, + perm2, + workspace, + M, + K, + N, + True, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 05169eaddb256..bd13d8fecbb96 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config(M: int, E: int, N: int, K: int, topk: int, - dtype: Optional[str], - is_marlin: bool) -> Dict[str, int]: +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } + # A heuristic: fused marlin works faster with this config for small M if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, @@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int, return config -def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, - Any]] = None, - is_marlin: bool = False): +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): if override_config: config = override_config else: @@ -391,6 +399,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) + ops.topk_softmax( topk_weights, 
topk_ids, @@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - custom_routing_function: Optional[Callable] = None, - renormalize: bool = True, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - - #TODO fp8 is not implemented yet - assert not use_fp8 - - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - if custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - 
block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3df0b61a9ebe4..f6c6f5f529408 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -306,10 +306,28 @@ def _load_single_value(self, param: torch.nn.Parameter, # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + # compressed-tensors represents weights on disk which are flipped + loaded_weight = loaded_weight.t().contiguous() if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsMoEMethod") else loaded_weight + if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") @@ -325,19 +343,41 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # is_transposed: whether or not the parameter is transposed on disk - # If transposed, the loaded weight will be transposed and the dim - # to shard the loaded weight will be flipped. + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size is is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: - loaded_weight = loaded_weight.t().contiguous() shard_dim = ~shard_dim - # Case weight_scales - if "weight_scale" in weight_name: - # load the weight scaling based on the quantization scheme - # supported weight scales can be found in + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. 
{loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + # Case weight scales and zero_points + if ("scale" in weight_name or "zero" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case @@ -366,22 +406,9 @@ def weight_loader(self, param: torch.nn.Parameter, f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") return + # Case weight_shape if "weight_shape" in weight_name: - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - return - - # Case input scale - if "input_scale" in weight_name: - # Note: input_scale loading is only supported for fp8 - if param.data[expert_id] != 1 and (param.data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param.data[expert_id]} " - f"vs. {loaded_weight}") - + # only required by compressed-tensors self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) @@ -498,4 +525,4 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight \ No newline at end of file + param_data[expert_id] = loaded_weight diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 36323493d601e..49c29c2775cb6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -5,9 +5,7 @@ import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -40,11 +38,10 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): + and self.num_bits == 4): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") + "is supported for 4 bits") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -269,19 +266,30 @@ def apply( custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe import ( + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) - return fused_marlin_moe(x, - layer.w13_weight_packed, - layer.w2_weight_packed, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - custom_routing_function, - 
renormalize=renormalize, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + topk_weights, + topk_ids, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 8897737c1c55a..3cade3d3fbcd0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -22,7 +22,7 @@ __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128, + 8: scalar_types.uint8b128 } WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b06ff7bd2bace..3617a32f80fc1 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,22 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig): (8, True): scalar_types.uint8b128, } - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ 
-105,11 +115,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,7 +192,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, @@ -299,7 +313,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -308,7 +323,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=(layer.input_size if self.quant_config.desc_act else layer.input_size_per_partition), size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "scales", marlin_scales) def apply( @@ -329,4 +345,270 @@ def apply( output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, - bias=bias) + bias=bias, + ) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + strategy = FusedMoeWeightScaleSupported.GROUP.value + else: + scales_size13 = 1 + scales_size2 = 1 + strategy = FusedMoeWeightScaleSupported.CHANNEL.value + + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": True + }) + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + 
layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) + replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + 
requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + # The input must currently be float16 + orig_dtype = x.dtype + x = x.half() + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=None) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + topk_weights, + topk_ids, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + ).to(orig_dtype) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f21..699d5f1844146 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> 
torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 7d08ac6f87469..4a06c5d63d52d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -1,6 +1,6 @@ """Utility functions used for tests and benchmarks""" -from typing import List +from typing import List, Optional import numpy as np import torch @@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, - act_order: bool): +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order) + w, quant_type, group_size, act_order, test_perm) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 33f24ff5d54d3..bdfda31de852b 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,5 +1,5 @@ """This file is used for /tests and /benchmarks""" -from typing import List +from typing import List, Optional import numpy import torch @@ -53,7 +53,10 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): assert q_w.shape == w_ref.shape orig_device = q_w.device @@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K - rand_perm = torch.randperm(k_size) + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() @@ -164,8 +167,11 @@ def reshape_w(w): ) -def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, - group_size: int, act_order: bool): +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, _ = w.shape assert w.is_floating_point(), "w must be float" @@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) return w_ref, w_q, w_s, g_idx, rand_perm diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 4bb943ab3afe4..0052489d99dc4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,10 +24,18 @@ def get_model_architecture( # Special handling for quantized Mixtral. 
# FIXME(woosuk): This is a temporary hack. mixtral_supported = ["fp8", "compressed-tensors"] + # for gptq_marlin, only run fused MoE for int4 + if model_config.quantization == "gptq_marlin": + hf_quant_config = getattr(model_config.hf_config, + "quantization_config", None) + if hf_quant_config and hf_quant_config.get("bits") == 4: + mixtral_supported.append("gptq_marlin") + if (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] + return ModelRegistry.resolve_model_cls(architectures) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e744e36ac08bf..10cbfcf6432b3 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -435,7 +435,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -454,6 +455,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -464,7 +468,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): From a1d874224d9c29ae84f3850474b4816f0ed9574b Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 9 Sep 2024 23:21:00 -0700 Subject: [PATCH 03/81] Add NVIDIA Meetup slides, announce AMD meetup, and add contact info (#8319) --- README.md | 16 ++++++++++++---- docs/source/community/meetups.rst | 1 + 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 9ae30f8d2de55..53749cb36b972 100644 --- a/README.md +++ b/README.md @@ -17,15 +17,16 @@ Easy, fast, and cheap LLM serving for everyone --- -**vLLM & NVIDIA Triton User Meetup (Monday, September 9, 5pm-9pm PT) at Fort Mason, San Francisco** +**vLLM, AMD, Anyscale Meet & Greet at [Ray Summit 2024](http://raysummit.anyscale.com) (Monday, Sept 30th, 5-7pm PT) at Marriott Marquis San Francisco** -We are excited to announce our sixth vLLM Meetup, in collaboration with NVIDIA Triton Team. -Join us to hear the vLLM's recent update about performance. -Register now [here](https://lu.ma/87q3nvnh) and be part of the event! +We are excited to announce our special vLLM event in collaboration with AMD and Anyscale. +Join us to learn more about recent advancements of vLLM on MI300X. +Register [here](https://lu.ma/db5ld9n5) and be a part of the event! --- *Latest News* 🔥 +- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! 
Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). @@ -130,3 +131,10 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs year={2023} } ``` + +## Contact Us + +* For technical questions and feature requests, please use Github issues or discussions. +* For discussing with fellow users, please use Discord. +* For security disclosures, please use Github's security advisory feature. +* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. \ No newline at end of file diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst index 3b01b109ebf2c..a3962e96e7913 100644 --- a/docs/source/community/meetups.rst +++ b/docs/source/community/meetups.rst @@ -5,6 +5,7 @@ vLLM Meetups We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- `The sixth vLLM meetup `__, with NVIDIA, September 9th 2024. `[Slides] `__ - `The fifth vLLM meetup `__, with AWS, July 24th 2024. `[Slides] `__ - `The fourth vLLM meetup `__, with Cloudflare and BentoML, June 11th 2024. `[Slides] `__ - `The third vLLM meetup `__, with Roblox, April 2nd 2024. `[Slides] `__ From da1a844e61366b473cef6b3f7437ea5dc41876a1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 10 Sep 2024 16:22:50 +0800 Subject: [PATCH 04/81] [Bugfix] Fix missing `post_layernorm` in CLIP (#8155) --- vllm/model_executor/models/clip.py | 29 +++++++++++++++++++++---- vllm/model_executor/models/siglip.py | 32 +++++++++++++++------------- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 70f1522ae2524..078928f281c26 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -355,6 +355,19 @@ def __init__(self, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override) + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {config.num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + elif len(self.encoder.layers) == config.num_hidden_layers: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + # post_layernorm is unused when we extract intermediate features + # In this case, we can skip it to conserve memory + self.post_layernorm = None + def forward( self, pixel_values: torch.Tensor, @@ -364,7 +377,10 @@ def forward( hidden_states = self.pre_layrnorm(hidden_states) hidden_states = self.encoder(inputs_embeds=hidden_states) - return hidden_states + if self.post_layernorm is None: + return hidden_states + + return self.post_layernorm(hidden_states) class CLIPVisionModel(nn.Module): @@ -386,9 +402,12 @@ def __init__(self, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override) - def forward(self, pixel_values: Optional[torch.Tensor] = None): + @property + def _require_post_layernorm(self) -> bool: + return self.vision_model.post_layernorm is not None - return self.vision_model(pixel_values=pixel_values) + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + return self.vision_model(pixel_values) @property def device(self): @@ -408,8 +427,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: + if ("vision_model.post_layernorm" in name + and not self._require_post_layernorm): continue + # omit layers when num_hidden_layers_override is set if "vision_model.encoder.layers." in name: layer_idx = int(name.split(".")[3]) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 13d09e4cd4c23..f7976eba7420b 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -443,27 +443,26 @@ def __init__( self.config = config embed_dim = config.hidden_size - if (num_hidden_layers_override is None - or num_hidden_layers_override == config.num_hidden_layers): - self.need_post_layernorm = True - elif num_hidden_layers_override > config.num_hidden_layers: - raise ValueError( - "num_hidden_layers_override cannot be greater than " - "num_hidden_layers") - else: - self.need_post_layernorm = False - self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, ) - if self.need_post_layernorm: + + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {config.num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + elif len(self.encoder.layers) == config.num_hidden_layers: self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) else: - self.post_layernorm = nn.Identity() + # post_layernorm is unused when we extract intermediate features + # In this case, we can skip it to conserve memory + self.post_layernorm = None + self.use_head = (True if not hasattr(config, "vision_use_head") else config.vision_use_head) if self.use_head: @@ -482,6 +481,9 @@ def forward( encoder_outputs = self.encoder(inputs_embeds=hidden_states) + if self.post_layernorm is None: + return encoder_outputs + last_hidden_state = self.post_layernorm(encoder_outputs) # TODO: add this back when pooled_output is used in inference # if self.use_head: @@ -512,8 +514,8 @@ def __init__( ) @property - def need_post_layernorm(self): - return self.vision_model.need_post_layernorm + def _require_post_layernorm(self) -> bool: + return self.vision_model.post_layernorm is not None def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @@ -541,7 +543,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: # post_layernorm is optional in SiglipVisionModel if ("vision_model.post_layernorm" in name - and not self.need_post_layernorm): + and not self._require_post_layernorm): continue # omit layers when num_hidden_layers_override is set From 6234385f4a826edd5c4e0ca7dbdea480be215c5e Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:55:08 +0200 Subject: [PATCH 05/81] [CI/Build] enable ccache/scccache for HIP builds (#8327) --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1e08a5bd70cd3..994920ede349d 100644 --- a/setup.py +++ b/setup.py @@ -170,14 +170,17 @@ def configure(self, ext: CMakeExtension) -> None: if is_sccache_available(): cmake_args += [ + '-DCMAKE_C_COMPILER_LAUNCHER=sccache', '-DCMAKE_CXX_COMPILER_LAUNCHER=sccache', '-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache', - '-DCMAKE_C_COMPILER_LAUNCHER=sccache', + '-DCMAKE_HIP_COMPILER_LAUNCHER=sccache', ] elif is_ccache_available(): cmake_args += [ + '-DCMAKE_C_COMPILER_LAUNCHER=ccache', '-DCMAKE_CXX_COMPILER_LAUNCHER=ccache', '-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache', + '-DCMAKE_HIP_COMPILER_LAUNCHER=ccache', ] # Pass the python executable to cmake so it can find an exact From 8c054b7a6290551c868451dfd449d40cf37d8b62 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 11 Sep 2024 00:49:11 +0800 Subject: [PATCH 06/81] [Frontend] Clean up type annotations for mistral tokenizer (#8314) --- tests/async_engine/test_chat_template.py | 5 +- vllm/entrypoints/chat_utils.py | 61 +++++++++++++------ vllm/entrypoints/llm.py | 26 +++++--- vllm/entrypoints/openai/serving_chat.py | 48 +++++++++------ .../openai/serving_tokenization.py | 25 +++++--- vllm/transformers_utils/tokenizers/mistral.py | 8 +-- 6 files changed, 114 insertions(+), 59 deletions(-) diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 4df6c02973284..61a6d77cd8756 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -1,6 +1,7 @@ import pytest -from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template +from vllm.entrypoints.chat_utils import (apply_hf_chat_template, + load_chat_template) from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer 
import get_tokenizer @@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt, add_generation_prompt=add_generation_prompt) # Call the function and get the result - result = apply_chat_template( + result = apply_hf_chat_template( tokenizer, conversation=mock_request.messages, chat_template=mock_request.chat_template or template_content, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f9f9536a7c160..a42ad81b3eef4 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -23,6 +23,7 @@ # yapf: enable # pydantic needs the TypedDict from typing_extensions from pydantic import ConfigDict +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -31,7 +32,7 @@ from vllm.multimodal.utils import (async_get_and_parse_audio, async_get_and_parse_image, get_and_parse_audio, get_and_parse_image) -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer logger = init_logger(__name__) @@ -379,6 +380,9 @@ def _parse_chat_message_content_parts( audio_url = _AudioParser(part)["audio_url"] mm_parser.parse_audio(audio_url["url"]) + elif part_type == "refusal": + text = _RefusalParser(part)["refusal"] + texts.append(text) else: raise NotImplementedError(f"Unknown part type: {part_type}") @@ -433,6 +437,21 @@ def _parse_chat_message_content( return result +def _postprocess_messages(messages: List[ConversationMessage]) -> None: + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in messages: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for item in message["tool_calls"]: + item["function"]["arguments"] = json.loads( + item["function"]["arguments"]) + + def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, @@ -446,6 +465,8 @@ def parse_chat_messages( conversation.extend(sub_messages) + _postprocess_messages(conversation) + return conversation, mm_tracker.all_mm_data() @@ -462,41 +483,41 @@ def parse_chat_messages_futures( conversation.extend(sub_messages) + _postprocess_messages(conversation) + return conversation, mm_tracker.all_mm_data() -def apply_chat_template( - tokenizer: AnyTokenizer, +def apply_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: List[ConversationMessage], chat_template: Optional[str], *, tokenize: bool = False, # Different from HF's default **kwargs: Any, -) -> Union[str, List[int]]: +) -> str: if chat_template is None and tokenizer.chat_template is None: raise ValueError( "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " "does not define one.") - # per the Transformers docs & maintainers, tool call arguments in - # assistant-role messages with tool_calls need to be dicts not JSON str - - # this is how tool-use chat templates will expect them moving forwards - # so, for messages that have tool_calls, parse the string (which we get - # from openAI format) to dict - for message in conversation: - if 
(message["role"] == "assistant" and "tool_calls" in message - and isinstance(message["tool_calls"], list)): + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + chat_template=chat_template, + tokenize=tokenize, + **kwargs, + ) - for i in range(len(message["tool_calls"])): - args: str = message["tool_calls"][i]["function"]["arguments"] - parsed_args: Dict = json.loads(args) - message["tool_calls"][i]["function"]["arguments"] = parsed_args - prompt = tokenizer.apply_chat_template( - conversation=conversation, +def apply_mistral_chat_template( + tokenizer: MistralTokenizer, + messages: List[ChatCompletionMessageParam], + chat_template: Optional[str], + **kwargs: Any, +) -> List[int]: + return tokenizer.apply_chat_template( + messages=messages, chat_template=chat_template, - tokenize=tokenize, **kwargs, ) - return prompt diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1e4432eaaa665..b1d9f386b6c3e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,7 +6,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - apply_chat_template, + apply_hf_chat_template, + apply_mistral_chat_template, parse_chat_messages) from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt @@ -19,7 +20,7 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import (AnyTokenizer, +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext @@ -393,12 +394,21 @@ def chat( conversation, mm_data = parse_chat_messages(messages, model_config, tokenizer) - prompt = apply_chat_template( - tokenizer, - conversation, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - ) + prompt: Union[str, List[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + ) + else: + prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + ) inputs: PromptInputs if is_list_of(prompt, int): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8ed81e9c88cb2..a81d2aa989aaf 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -11,7 +11,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, - apply_chat_template, + apply_hf_chat_template, + apply_mistral_chat_template, load_chat_template, parse_chat_messages_futures) from vllm.entrypoints.logger import RequestLogger @@ -35,7 +36,7 @@ from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import iterate_with_cancellation, random_uuid logger = init_logger(__name__) @@ -121,15 
+122,27 @@ async def create_chat_completion( tool.model_dump() for tool in request.tools ] - prompt = apply_chat_template( - tokenizer, - conversation=conversation, - chat_template=request.chat_template or self.chat_template, - add_generation_prompt=request.add_generation_prompt, - tools=tool_dicts, - documents=request.documents, - **(request.chat_template_kwargs or {}), - ) + prompt: Union[str, List[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt = apply_mistral_chat_template( + tokenizer, + messages=request.messages, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + tools=tool_dicts, + documents=request.documents, + **(request.chat_template_kwargs or {}), + ) + else: + prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + tools=tool_dicts, + documents=request.documents, + **(request.chat_template_kwargs or {}), + ) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) @@ -307,11 +320,10 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message if request.echo: - last_msg_content: Optional[str] = "" - if conversation and conversation[-1].get( - "content") and conversation[-1].get( - "role") == role: - last_msg_content = conversation[-1]["content"] + last_msg_content: str = "" + if conversation and "content" in conversation[ + -1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" if last_msg_content: for i in range(num_choices): @@ -659,8 +671,8 @@ async def chat_completion_full_generator( if request.echo: last_msg_content = "" - if conversation and conversation[-1].get( - "content") and conversation[-1].get("role") == role: + if conversation and "content" in conversation[-1] and conversation[ + -1].get("role") == role: last_msg_content = conversation[-1]["content"] or "" for choice in choices: diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 69a5ad5b62cfa..6e802b71ae2b4 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -2,7 +2,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient -from vllm.entrypoints.chat_utils import (apply_chat_template, +from vllm.entrypoints.chat_utils import (apply_hf_chat_template, + apply_mistral_chat_template, load_chat_template, parse_chat_messages_futures) from vllm.entrypoints.logger import RequestLogger @@ -18,6 +19,7 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.utils import random_uuid logger = init_logger(__name__) @@ -66,6 +68,7 @@ async def create_tokenize( tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + prompt: Union[str, List[int]] if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -77,12 +80,20 @@ async def create_tokenize( logger.warning( "Multi-modal inputs are ignored during tokenization") - prompt = apply_chat_template( - tokenizer, - conversation=conversation, - chat_template=self.chat_template, - add_generation_prompt=request.add_generation_prompt, - ) + if isinstance(tokenizer, 
MistralTokenizer): + prompt = apply_mistral_chat_template( + tokenizer, + messages=request.messages, + chat_template=self.chat_template, + add_generation_prompt=request.add_generation_prompt, + ) + else: + prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=self.chat_template, + add_generation_prompt=request.add_generation_prompt, + ) else: prompt = request.prompt diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 533a86b787325..17e318cb5e047 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -16,7 +16,7 @@ Tekkenizer) if TYPE_CHECKING: - from vllm.entrypoints.chat_utils import ConversationMessage + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @dataclass @@ -122,19 +122,19 @@ def get_added_vocab(self) -> List[str]: return [] def encode(self, prompt: str) -> List[int]: - # `encode ` should only be used for prompt completion + # `encode` should only be used for prompt completion # it should never be used for chat_completion. # For chat completion use `apply_chat_template` return self.tokenizer.encode(prompt, bos=True, eos=False) def apply_chat_template(self, - conversation: List["ConversationMessage"], + messages: List["ChatCompletionMessageParam"], tools: Optional[Dict[str, Any]] = None, **kwargs) -> List[int]: assert tools is None, "`tools` are not yet supported." request = ChatCompletionRequest( - messages=conversation) # type: ignore[type-var] + messages=messages) # type: ignore[type-var] encoded = self.mistral.encode_chat_completion(request) # encode-decode to get clean prompt From f421f3cefb58d968767536d745fcc6e9ac342df5 Mon Sep 17 00:00:00 2001 From: "Alexey Kondratiev(AMD)" <143633163+alexeykondrat@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:51:15 -0400 Subject: [PATCH 07/81] [CI/Build] Enabling kernels tests for AMD, ignoring some of then that fail (#8130) --- .buildkite/run-amd-test.sh | 24 +++++++++++++++++++++++- .buildkite/test-pipeline.yaml | 1 + 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 972c62a091aea..c9b72a3264e82 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -71,13 +71,35 @@ mkdir -p ${HF_CACHE} HF_MOUNT="/root/.cache/huggingface" commands=$@ +echo "Commands:$commands" +#ignore certain kernels tests +if [[ $commands == *" kernels "* ]]; then + commands="${commands} \ + --ignore=kernels/test_attention.py \ + --ignore=kernels/test_attention_selector.py \ + --ignore=kernels/test_blocksparse_attention.py \ + --ignore=kernels/test_causal_conv1d.py \ + --ignore=kernels/test_cutlass.py \ + --ignore=kernels/test_encoder_decoder_attn.py \ + --ignore=kernels/test_flash_attn.py \ + --ignore=kernels/test_flashinfer.py \ + --ignore=kernels/test_int8_quant.py \ + --ignore=kernels/test_machete_gemm.py \ + --ignore=kernels/test_mamba_ssm.py \ + --ignore=kernels/test_marlin_gemm.py \ + --ignore=kernels/test_prefix_prefill.py \ + --ignore=kernels/test_rand.py \ + --ignore=kernels/test_sampler.py" +fi + PARALLEL_JOB_COUNT=8 # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. 
if [[ $commands == *"--shard-id="* ]]; then for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do #replace shard arguments - commands=${@//"--shard-id= "/"--shard-id=${GPU} "} + commands=${commands//"--shard-id= "/"--shard-id=${GPU} "} commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} + echo "Shard ${GPU} commands:$commands" docker run \ --device /dev/kfd --device /dev/dri \ --network host \ diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a0c7b7442b3b3..e4f70c5d4920a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -228,6 +228,7 @@ steps: parallelism: 4 - label: Kernels Test %N # 30min each + mirror_hardwares: [amd] source_file_dependencies: - csrc/ - vllm/attention From 02751a7a42c18454030ff35e350afab31e26f51d Mon Sep 17 00:00:00 2001 From: sumitd2 <91451282+sumitd2@users.noreply.github.com> Date: Wed, 11 Sep 2024 01:28:34 +0530 Subject: [PATCH 08/81] Fix ppc64le buildkite job (#8309) --- .buildkite/run-cpu-test-ppc64le.sh | 3 ++- Dockerfile.ppc64le | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh index a01cf3fe67489..49ae838cf0690 100755 --- a/.buildkite/run-cpu-test-ppc64le.sh +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -11,8 +11,9 @@ trap remove_docker_container EXIT remove_docker_container # Run the image, setting --shm-size=4g for tensor parallel. +source /etc/environment #docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --name cpu-test cpu-test +docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN=$HF_TOKEN --name cpu-test cpu-test # Run basic model test docker exec cpu-test bash -c " diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 16780f8ab950c..27d10e91342e4 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -4,7 +4,7 @@ USER root ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" -RUN apt-get update -y && apt-get install -y git wget vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential +RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential # Some packages in requirements-cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba @@ -16,7 +16,7 @@ COPY ./ /workspace/vllm WORKDIR /workspace/vllm # These packages will be in rocketce eventually -RUN pip install -v cmake torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing +RUN pip install -v cmake xformers torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install @@ -25,4 +25,3 @@ WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] - From 5faedf1b6224f6e7348e9223f3e3107ec03954d3 Mon Sep 17 00:00:00 2001 From: Kevin Lin <42618777+kevin314@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:18:14 -0500 Subject: [PATCH 09/81] [Spec 
Decode] Move ops.advance_step to flash attn advance_step (#8224) --- vllm/attention/backends/flash_attn.py | 21 +++++++++++++++------ vllm/spec_decode/draft_model_runner.py | 16 +++------------- vllm/worker/multi_step_model_runner.py | 19 +++++-------------- 3 files changed, 23 insertions(+), 33 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 30ce715d5d05a..06b178798dcd9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -16,7 +16,8 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache @@ -302,14 +303,12 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: ) return self._cached_decode_metadata - def advance_step(self, num_seqs: int, num_queries: int): + def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int): """ Update metadata in-place to advance one decode step. """ - # GPU in-place update is currently called separately through - # custom_ops.advance_step(). See draft_model_runner. TODO(will): Move - # this logic to the backend. - # When using cudagraph, the num_seqs is padded to the next captured # batch sized, but num_queries tracks the actual number of requests in # the batch. For --enforce-eager mode, num_seqs == num_queries @@ -347,6 +346,16 @@ def advance_step(self, num_seqs: int, num_queries: int): self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) + ops.advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 6e35e40294381..1e403637d2388 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,7 +2,6 @@ import torch -from vllm import _custom_ops as ops from vllm.model_executor.layers.sampler import SamplerOutput try: @@ -116,18 +115,9 @@ def _gpu_advance_step( # Update attn_metadata attn_metadata = model_input.attn_metadata assert isinstance(attn_metadata, FlashAttentionMetadata) - attn_metadata.advance_step(num_seqs, num_queries) - - # Update GPU tensors - ops.advance_step(num_seqs=num_seqs, - num_queries=num_queries, - block_size=self.block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=attn_metadata.seq_lens_tensor, - slot_mapping=attn_metadata.slot_mapping, - block_tables=attn_metadata.block_tables) + + attn_metadata.advance_step(model_input, sampled_token_ids, + self.block_size, num_seqs, num_queries) # Update sampling_metadata sampling_metadata = model_input.sampling_metadata diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 
b13cf39bd846e..9a196c3dfcd1f 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -13,7 +13,6 @@ import torch -from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, @@ -499,19 +498,11 @@ def _advance_step(self, model_input: StatefulModelInput, attn_metadata = frozen_model_input.attn_metadata assert isinstance(attn_metadata, FlashAttentionMetadata) - attn_metadata.advance_step(num_seqs, num_queries) - - # Update GPU tensors - ops.advance_step( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=self.block_size, - input_tokens=frozen_model_input.input_tokens, - sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids, - input_positions=frozen_model_input.input_positions, - seq_lens=attn_metadata.seq_lens_tensor, - slot_mapping=attn_metadata.slot_mapping, - block_tables=attn_metadata.block_tables) + + attn_metadata.advance_step( + frozen_model_input, + model_input.cached_outputs[-1].sampled_token_ids, self.block_size, + num_seqs, num_queries) if frozen_model_input.seq_lens is not None: for i in range(num_queries): From 04e7c4e77118159e0b892681acd04a1b50a7ea6e Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 10 Sep 2024 14:21:56 -0700 Subject: [PATCH 10/81] [Misc] remove peft as dependency for prompt models (#8162) --- vllm/config.py | 8 --- vllm/prompt_adapter/models.py | 2 +- vllm/prompt_adapter/utils.py | 93 +++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 9 deletions(-) create mode 100644 vllm/prompt_adapter/utils.py diff --git a/vllm/config.py b/vllm/config.py index 8f5e02e35f28d..9e7c107900aaf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1558,14 +1558,6 @@ class PromptAdapterConfig: prompt_adapter_dtype: Optional[torch.dtype] = None def __post_init__(self): - library_name = 'peft' - try: - __import__(library_name) - except ImportError as e: - raise ImportError( - f"'{library_name}' is not installed for prompt adapter support." - f"Please install it using 'pip install {library_name}'." 
- ) from e if self.max_prompt_adapters < 1: raise ValueError(f"max_prompt_adapters " diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 93eb3bde646ac..18a5f86c341a9 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -14,6 +14,7 @@ from vllm.prompt_adapter.layers import ( VocabParallelEmbeddingWithPromptAdapter) # yapf: disable from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.utils import load_peft_weights logger = logging.getLogger(__name__) @@ -90,7 +91,6 @@ def from_local_checkpoint( config: PromptAdapterConfig, device: str = "cuda", ) -> "PromptAdapterModel": - from peft.utils import load_peft_weights if num_virtual_tokens > config.max_prompt_adapter_token: raise ValueError( diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py new file mode 100644 index 0000000000000..989cc5a0f87c8 --- /dev/null +++ b/vllm/prompt_adapter/utils.py @@ -0,0 +1,93 @@ +# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420 + +import os +from typing import Optional + +import torch +from huggingface_hub import file_exists, hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from safetensors.torch import load_file as safe_load_file + +WEIGHTS_NAME = "adapter_model.bin" +SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" + + +# Get current device name based on available devices +def infer_device() -> str: + if torch.cuda.is_available(): + return "cuda" + return "cpu" + + +def load_peft_weights(model_id: str, + device: Optional[str] = None, + **hf_hub_download_kwargs) -> dict: + r""" + A helper method to load the PEFT weights from the HuggingFace Hub or locally + + Args: + model_id (`str`): + The local path to the adapter weights or the name of the adapter to + load from the HuggingFace Hub. + device (`str`): + The device to load the weights onto. + hf_hub_download_kwargs (`dict`): + Additional arguments to pass to the `hf_hub_download` method when + loading from the HuggingFace Hub. 
+ """ + path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) + if hf_hub_download_kwargs.get("subfolder", None) is not None else + model_id) + + if device is None: + device = infer_device() + + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + use_safetensors = True + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + use_safetensors = False + else: + token = hf_hub_download_kwargs.get("token", None) + if token is None: + token = hf_hub_download_kwargs.get("use_auth_token", None) + + hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"], + SAFETENSORS_WEIGHTS_NAME) + if hf_hub_download_kwargs.get("subfolder", None) + is not None else SAFETENSORS_WEIGHTS_NAME) + has_remote_safetensors_file = file_exists( + repo_id=model_id, + filename=hub_filename, + revision=hf_hub_download_kwargs.get("revision", None), + repo_type=hf_hub_download_kwargs.get("repo_type", None), + token=token, + ) + use_safetensors = has_remote_safetensors_file + + if has_remote_safetensors_file: + # Priority 1: load safetensors weights + filename = hf_hub_download( + model_id, + SAFETENSORS_WEIGHTS_NAME, + **hf_hub_download_kwargs, + ) + else: + try: + filename = hf_hub_download(model_id, WEIGHTS_NAME, + **hf_hub_download_kwargs) + except EntryNotFoundError: + raise ValueError( # noqa: B904 + f"Can't find weights for {model_id} in {model_id} or \ + in the Hugging Face Hub. " + f"Please check that the file {WEIGHTS_NAME} or \ + {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.") + + if use_safetensors: + adapters_weights = safe_load_file(filename, device=device) + else: + adapters_weights = torch.load(filename, + map_location=torch.device(device)) + + return adapters_weights From b1f3e189586dce42bb3dcda20169a9308c9a25fa Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 10 Sep 2024 15:28:28 -0700 Subject: [PATCH 11/81] [MISC] Keep chunked prefill enabled by default with long context when prefix caching is enabled (#8342) --- vllm/engine/arg_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9bc03948d3845..7748e11092040 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -878,7 +878,6 @@ def create_engine_config(self) -> EngineConfig: if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and not self.enable_prompt_adapter - and not self.enable_prefix_caching and not has_seqlen_agnostic_layers): self.enable_chunked_prefill = True logger.warning( From 22f3a4bc6c6801101728d97edd25ffcdd5a7fd8c Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Tue, 10 Sep 2024 19:00:35 -0400 Subject: [PATCH 12/81] [Bugfix] lookahead block table with cuda graph max capture (#8340) [Bugfix] Ensure multistep lookahead allocation is compatible with cuda graph max capture (#8340) --- vllm/attention/backends/flash_attn.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 06b178798dcd9..69faa6d343eda 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -471,9 +471,19 @@ def build(self, seq_lens: List[int], query_lens: List[int], # The shape of graph_block_tables is # [max batch size, max context len // block size]. 
input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] for i, block_table in enumerate(self.block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + block_tables = torch.from_numpy(input_block_tables).to( device=device, non_blocking=True) else: From 1d5e397aa4d94d0ccc1c9dbad533afa5cb60bb69 Mon Sep 17 00:00:00 2001 From: William Lin Date: Tue, 10 Sep 2024 16:46:08 -0700 Subject: [PATCH 13/81] [Core/Bugfix] pass VLLM_ATTENTION_BACKEND to ray workers (#8172) --- vllm/executor/ray_gpu_executor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 1359a0d310a70..b124fe2e08ea6 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -242,6 +242,9 @@ def sort_by_driver_then_worker_ip(worker): VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) }, ) for (node_id, _) in worker_node_and_gpu_ids] self._env_vars_for_all_workers = ( From 94144e726cfeeba0c1758751b7fd46a20b6bd3b4 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 10 Sep 2024 19:51:58 -0400 Subject: [PATCH 14/81] [CI/Build][Kernel] Update CUTLASS to 3.5.1 tag (#8043) --- CMakeLists.txt | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c88c31c83da1..f8d6a2be9feae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,9 +195,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.1 - GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 + GIT_TAG v3.5.1 GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. + # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(cutlass) @@ -231,6 +235,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "-gencode arch=compute_90a,code=sm_90a") endif() + # # Machete kernels @@ -289,6 +294,12 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) +# If CUTLASS is compiled on NVCC >= 12.5, it by default uses +# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the +# driver API. This causes problems when linking with earlier versions of CUDA. +# Setting this variable sidesteps the issue by calling the driver directly. 
+target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + # # _moe_C extension # From e497b8aeff5799d4ca2cfd6e01105194ebd39eac Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 11 Sep 2024 08:59:19 +0800 Subject: [PATCH 15/81] [Misc] Skip loading extra bias for Qwen2-MOE GPTQ models (#8329) --- vllm/model_executor/models/qwen2_moe.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 56129515ca8d1..d80064601d993 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -469,7 +469,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -490,6 +491,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -500,7 +505,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): From 1230263e161caa9fd698e109d33437950769ec09 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 11 Sep 2024 10:11:01 +0800 Subject: [PATCH 16/81] [Bugfix] Fix InternVL2 vision embeddings process with pipeline parallel (#8299) --- tests/distributed/test_pipeline_parallel.py | 10 ++++++++-- vllm/model_executor/models/internvl.py | 3 ++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 637d2b30f6b1f..d2219eed988e1 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -32,7 +32,9 @@ (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"), + (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"), ], ) @fork_new_process_for_each_test @@ -46,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, # use half precision for speed and memory savings in CI environment "--dtype", "float16", + "--max-model-len", + "8192", "--pipeline-parallel-size", str(PP_SIZE), "--tensor-parallel-size", @@ -62,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, tp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", - "bfloat16", + "float16", + "--max-model-len", + "8192", "--tensor-parallel-size", str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI. 
"--distributed-executor-backend", diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 0cf63d9e1fb22..81819578a4d8c 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,6 +17,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -480,7 +481,7 @@ def forward( **kwargs: object, ) -> SamplerOutput: image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: + if image_input is not None and get_pp_group().is_first_rank: inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) vision_embeddings = self._process_image_input(image_input) From efcf946a158f02a597086199890b5c7673ffe467 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Tue, 10 Sep 2024 21:38:40 -0700 Subject: [PATCH 17/81] [Hardware][NV] Add support for ModelOpt static scaling checkpoints. (#6112) --- examples/fp8/quantizer/README.md | 4 +- tests/models/test_modelopt.py | 79 +++++++++ vllm/config.py | 6 +- vllm/model_executor/layers/linear.py | 3 +- .../layers/quantization/__init__.py | 2 + .../layers/quantization/modelopt.py | 163 ++++++++++++++++++ .../model_loader/weight_utils.py | 7 + 7 files changed, 258 insertions(+), 6 deletions(-) create mode 100644 tests/models/test_modelopt.py create mode 100644 vllm/model_executor/layers/quantization/modelopt.py diff --git a/examples/fp8/quantizer/README.md b/examples/fp8/quantizer/README.md index 0b6944f688b49..d0895e97dc341 100644 --- a/examples/fp8/quantizer/README.md +++ b/examples/fp8/quantizer/README.md @@ -1,6 +1,6 @@ ### Quantizer Utilities -`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM: -`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py` +`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported +from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py) ### Prerequisite diff --git a/tests/models/test_modelopt.py b/tests/models/test_modelopt.py new file mode 100644 index 0000000000000..e643b115d0ea8 --- /dev/null +++ b/tests/models/test_modelopt.py @@ -0,0 +1,79 @@ +# flake8: noqa +"""Tests Model Optimizer fp8 models against ground truth generation +Note: these tests will only pass on H100 +""" +import os +from typing import List + +import pytest +from transformers import AutoTokenizer + +from tests.quantization.utils import is_quant_method_supported +from vllm import LLM, SamplingParams + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = ["nvidia/Llama-3.1-8B-Instruct-FP8"] + +EXPECTED_STRS_MAP = { + "nvidia/Llama-3.1-8B-Instruct-FP8": [ + "You're referring to VLLM, a high-performance Large Language Model (LLM) inference and", + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + '**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, 
whir', + 'The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる' + ] +} + + +# This test compares against golden strings for exact match since +# there is no baseline implementation to compare against +# and is unstable w.r.t specifics of the fp8 implementation or +# the hardware being run on. +# Disabled to prevent it from breaking the build +@pytest.mark.skip( + reason= + "Prevent unstable test based on golden strings from breaking the build.") +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +def test_models(example_prompts, model_name) -> None: + model = LLM( + model=model_name, + max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, + enforce_eager=True, + quantization="modelopt", + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + params = SamplingParams(max_tokens=20, temperature=0) + generations: List[str] = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. + for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(model_name, generations) + expected_strs = EXPECTED_STRS_MAP[model_name] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + assert expected_str == generated_str, ( + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") diff --git a/vllm/config.py b/vllm/config.py index 9e7c107900aaf..4d9310af79ed1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -282,9 +282,9 @@ def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["awq", "gptq", "fp8"] optimized_quantization_methods = [ - "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", - "fbgemm_fp8", "compressed_tensors", "compressed-tensors", - "experts_int8" + "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", + "awq_marlin", "fbgemm_fp8", "compressed_tensors", + "compressed-tensors", "experts_int8" ] tpu_supported_quantization = ["tpu_int8"] neuron_supported_quantization = ["neuron_quant"] diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b997507ea738d..cea768469aeb8 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -26,7 +26,8 @@ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", - "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod" + "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod" ] diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index aa5c288962d91..3c38f0a006070 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -22,6 +22,7 @@ from 
vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config from vllm.model_executor.layers.quantization.neuron_quant import ( NeuronQuantConfig) from vllm.model_executor.layers.quantization.qqq import QQQConfig @@ -34,6 +35,7 @@ "tpu_int8": Int8TpuConfig, "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, + "modelopt": ModelOptFp8Config, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py new file mode 100644 index 0000000000000..dc5f47eb9b0fb --- /dev/null +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -0,0 +1,163 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + +ACTIVATION_SCHEMES = ["static"] + + +class ModelOptFp8Config(QuantizationConfig): + """Config class for ModelOpt FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" + " the format is experimental and could change.") + + @classmethod + def get_name(cls) -> str: + return "modelopt" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + is_checkpoint_fp8_serialized = ("FP8" in quant_method) + if not is_checkpoint_fp8_serialized: + raise ValueError("ModelOpt currently only supports static FP8" + "quantization in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + return cls(is_checkpoint_fp8_serialized) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + return ModelOptFp8LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. 
+ """ + + def __init__(self, quant_config: ModelOptFp8Config): + super().__init__(quant_config) + + +class ModelOptFp8LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer static quantization. + Supports loading FP8 checkpoints with static weight scale and + activation scale. Future support might be added for dynamic + scales. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn datatype + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + # INPUT SCALE + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + + def process_weights_after_loading(self, layer: Module) -> None: + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 075451292a8e4..5051d45dd1154 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -192,6 +192,13 @@ def get_quant_config(model_config: ModelConfig, if model_config.quantization == "bitsandbytes": config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") return 
quant_cls.from_config(config) From 6a512a00dfa306762c2878bffc3a5664a758d105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yangshen=E2=9A=A1Deng?= Date: Wed, 11 Sep 2024 13:21:36 +0800 Subject: [PATCH 18/81] [model] Support for Llava-Next-Video model (#7559) Co-authored-by: Roger Wang Co-authored-by: Cyrus Leung Co-authored-by: Cyrus Leung --- Dockerfile | 1 + Dockerfile.cpu | 1 + Dockerfile.neuron | 4 +- Dockerfile.openvino | 3 +- Dockerfile.ppc64le | 2 +- Dockerfile.tpu | 3 + Dockerfile.xpu | 3 +- docs/source/conf.py | 1 + docs/source/models/supported_models.rst | 14 + examples/offline_inference_vision_language.py | 70 ++- requirements-test.txt | 1 + setup.py | 1 + tests/conftest.py | 56 ++- tests/models/test_llava_next_video.py | 236 +++++++++ vllm/assets/video.py | 85 ++++ vllm/model_executor/models/__init__.py | 6 +- .../model_executor/models/llava_next_video.py | 471 ++++++++++++++++++ vllm/multimodal/registry.py | 3 +- vllm/multimodal/utils.py | 42 ++ vllm/multimodal/video.py | 71 +++ vllm/transformers_utils/image_processor.py | 27 + 21 files changed, 1083 insertions(+), 18 deletions(-) create mode 100644 tests/models/test_llava_next_video.py create mode 100644 vllm/assets/video.py create mode 100644 vllm/model_executor/models/llava_next_video.py create mode 100644 vllm/multimodal/video.py diff --git a/Dockerfile b/Dockerfile index 0ec6655ed449e..5484be5bc5785 100644 --- a/Dockerfile +++ b/Dockerfile @@ -145,6 +145,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && add-apt-repository ppa:deadsnakes/ppa \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 9a570f988f3db..2b60835255cb4 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -5,6 +5,7 @@ FROM ubuntu:22.04 AS cpu-test-1 RUN --mount=type=cache,target=/var/cache/apt \ apt-get update -y \ && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html diff --git a/Dockerfile.neuron b/Dockerfile.neuron index caa1b1d6c4424..f0c3479625a70 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -6,7 +6,9 @@ FROM $BASE_IMAGE RUN echo "Base image is $BASE_IMAGE" # Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y +RUN apt-get update \ + && apt-get install python3 python3-pip -y \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### # When launching the container, mount the code directory to /app diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 06ca4638dfeb9..96b9593a2bfa8 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -4,7 +4,8 @@ FROM ubuntu:22.04 AS dev RUN apt-get update -y && \ - apt-get install -y python3-pip git + apt-get install -y python3-pip git && \ + apt-get install -y ffmpeg libsm6 libxext6 libgl1 WORKDIR /workspace # copy requirements diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le 
index 27d10e91342e4..3313162bf28e1 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -4,7 +4,7 @@ USER root ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" -RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential +RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 # Some packages in requirements-cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 3a11c6721ead9..04cd4d79f4045 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -4,6 +4,9 @@ ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:night FROM $BASE_IMAGE WORKDIR /workspace +# Install some basic utilities +RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1 + # Install the TPU and Pallas dependencies. RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html diff --git a/Dockerfile.xpu b/Dockerfile.xpu index f91baa11a3753..321da98cf6c89 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -9,8 +9,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO chmod 644 /usr/share/keyrings/intel-graphics.gpg RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip - +&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 COPY ./ /workspace/vllm WORKDIR /workspace/vllm diff --git a/docs/source/conf.py b/docs/source/conf.py index b4f5b4ab9d569..8435129e752e1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -99,6 +99,7 @@ def setup(app): "aiohttp", "compressed_tensors", "cpuinfo", + "cv2", "torch", "transformers", "psutil", diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 1bb3a448f2c92..29fa5d812deb2 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -227,6 +227,11 @@ Multimodal Language Models - Image\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - + * - :code:`LlavaNextVideoForConditionalGeneration` + - LLaVA-NeXT-Video + - Video + - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note) + - * - :code:`MiniCPMV` - MiniCPM-V - Image\ :sup:`+` @@ -260,6 +265,15 @@ Multimodal Language Models For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 + For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. + This can be installed by running the following command: + + + .. code-block:: bash + + pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830 + + ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. 
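The example script updated in the next file diff adds a ``run_llava_next_video`` runner for the new video pathway. The snippet below is a minimal, self-contained sketch of that flow, not part of the diff: the ``USER: <video> ... ASSISTANT:`` prompt template, the ``VideoAsset`` constructor arguments and ``np_ndarrays`` attribute, and the ``"video"`` key of ``multi_modal_data`` are assumptions inferred from the files added in this patch (``vllm/assets/video.py``, ``vllm/multimodal/video.py``, ``examples/offline_inference_vision_language.py``) rather than a verified API.

.. code-block:: python

    # Minimal sketch of offline video inference with LLaVA-NeXT-Video.
    # Assumed (not verified against this patch): the prompt template,
    # the VideoAsset constructor/attributes, and the "video" key of
    # multi_modal_data. See examples/offline_inference_vision_language.py
    # in the diff below for the canonical usage.
    from vllm import LLM, SamplingParams
    from vllm.assets.video import VideoAsset

    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)

    # Decode a bundled sample clip into a stack of frames (numpy arrays).
    video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays

    question = "Why is this video funny?"
    prompt = f"USER: <video>\n{question} ASSISTANT:"

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"video": video},
        },
        SamplingParams(temperature=0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)

As noted in the documentation hunk above, running this model currently requires the development build of ``transformers`` installed from the referenced commit.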
diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index aa1580343aee7..2ec691608df6d 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -9,12 +9,9 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset from vllm.utils import FlexibleArgumentParser -# Input image and question -image = ImageAsset("cherry_blossom").pil_image.convert("RGB") -question = "What is the content of this image?" - # LLaVA-1.5 def run_llava(question): @@ -30,7 +27,16 @@ def run_llava(question): def run_llava_next(question): prompt = f"[INST] \n{question} [/INST]" - llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf") + llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +# LlaVA-NeXT-Video +# Currently only support for video input +def run_llava_next_video(question): + prompt = f"USER: