diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index c2ad98b7e2656..4f88e8e6eb1a6 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -88,22 +88,23 @@ def prepare(i: int):
         input_gating.copy_(gating_output[i])
 
     def run():
-        fused_moe(
-            x,
-            w1,
-            w2,
-            input_gating,
-            topk,
-            renormalize=True,
-            inplace=True,
-            override_config=config,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-        )
+        from vllm.model_executor.layers.fused_moe import override_config
+        with override_config(config):
+            fused_moe(
+                x,
+                w1,
+                w2,
+                input_gating,
+                topk,
+                renormalize=True,
+                inplace=True,
+                use_fp8_w8a8=use_fp8_w8a8,
+                use_int8_w8a16=use_int8_w8a16,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale,
+            )
 
     # JIT compilation & warmup
     run()
diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py
index 77c56d91d0a8b..6aa27b24b4a6e 100644
--- a/tests/compile/test_basic_correctness.py
+++ b/tests/compile/test_basic_correctness.py
@@ -13,11 +13,11 @@
 @pytest.mark.parametrize(
     "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
     [
-        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
+        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
         ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
          ["--quantization", "compressed-tensors"
           ], 1, 1, "FLASH_ATTN", "generate", True),
-        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
+        ("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
         # TODO: add multi-modality test for llava
         ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
     ])
diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py
index 0f0a2b24563fd..59917dd2c58ad 100644
--- a/tests/kernels/test_awq_marlin.py
+++ b/tests/kernels/test_awq_marlin.py
@@ -5,11 +5,10 @@
 import pytest
 import torch
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                  torch_moe_single)
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     awq_marlin_quantize)
@@ -81,7 +80,7 @@ def test_fused_marlin_moe_awq(
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
     topk_weights, topk_ids = fused_topk(a, score, topk, False)
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -150,14 +149,14 @@ def test_single_marlin_moe_multiply_awq(
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
-    marlin_output = single_marlin_moe(a,
-                                      qweight,
-                                      scales,
-                                      score,
-                                      topk,
-                                      renormalize=False,
-                                      w_zeros=zp,
-                                      num_bits=num_bits)
+    marlin_output = torch.ops.vllm.single_marlin_moe(a,
+                                                     qweight,
+                                                     scales,
+                                                     score,
+                                                     topk,
+                                                     renormalize=False,
+                                                     w_zeros=zp,
+                                                     num_bits=num_bits)
 
     torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 4bfc089c82179..70906ab2187bc 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -7,12 +7,11 @@
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                  torch_moe, torch_moe_single)
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -193,7 +192,7 @@ def test_fused_marlin_moe(
         topk,
         renormalize=False,
     )
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
     sort_indices = stack_and_dev(sort_indices_l)
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
-    marlin_output = single_marlin_moe(
+    marlin_output = torch.ops.vllm.single_marlin_moe(
         a,
         qweight,
         scales,
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index e9b5703ca28be..c4223d12600ac 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,23 +1,43 @@
+from contextlib import contextmanager
+from typing import Any, Dict, Optional
+
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON
 
+_config: Optional[Dict[str, Any]] = None
+
+
+@contextmanager
+def override_config(config):
+    global _config
+    old_config = _config
+    _config = config
+    yield
+    _config = old_config
+
+
+def get_config() -> Optional[Dict[str, Any]]:
+    return _config
+
+
 __all__ = [
     "FusedMoE",
     "FusedMoEMethodBase",
     "FusedMoeWeightScaleSupported",
+    "override_config",
+    "get_config",
 ]
 
 if HAS_TRITON:
-    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-        fused_marlin_moe, single_marlin_moe)
+    # import to register the custom ops
+    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
+    import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         fused_experts, fused_moe, fused_topk, get_config_file_name,
         grouped_topk)
 
     __all__ += [
-        "fused_marlin_moe",
-        "single_marlin_moe",
         "fused_moe",
         "fused_topk",
         "fused_experts",
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 5ae40a2af5a2b..93019d0d0abb6 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -1,6 +1,6 @@
 """Fused MoE utilities for GPTQ."""
 import functools
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import torch
 
@@ -18,6 +18,7 @@ def get_scalar_type(num_bits: int, has_zp: bool):
     return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
 
 
+@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
 def single_marlin_moe(
     hidden_states: torch.Tensor,
     w: torch.Tensor,
@@ -28,7 +29,6 @@ def single_marlin_moe(
     g_idx: Optional[torch.Tensor] = None,
     sort_indices: Optional[torch.Tensor] = None,
     w_zeros: Optional[torch.Tensor] = None,
-    override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
     is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -49,8 +49,6 @@ def single_marlin_moe(
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - num_bits (bool): The number of bits in expert weights quantization.
 
     Returns:
@@ -79,7 +77,6 @@ def single_marlin_moe(
         w.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True)
     config = get_config_func(M)
 
@@ -122,6 +119,24 @@ def single_marlin_moe(
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape),
                      dim=1)
 
 
+@single_marlin_moe.register_fake
+def _(
+    hidden_states: torch.Tensor,
+    w: torch.Tensor,
+    scales: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    g_idx: Optional[torch.Tensor] = None,
+    sort_indices: Optional[torch.Tensor] = None,
+    w_zeros: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
+    is_k_full: bool = True,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
 def fused_marlin_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -137,7 +152,6 @@ def fused_marlin_moe(
     sort_indices2: Optional[torch.Tensor] = None,
     w1_zeros: Optional[torch.Tensor] = None,
     w2_zeros: Optional[torch.Tensor] = None,
-    override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
     is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -161,8 +175,6 @@ def fused_marlin_moe(
         permutation.
     - topk_weights (torch.Tensor): Top-k weights.
     - topk_ids (torch.Tensor): Indices of topk-k elements.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
     - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
     - num_bits (bool): The number of bits in expert weights quantization.
@@ -209,7 +221,6 @@ def fused_marlin_moe(
         w2.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True,
     )
     config = get_config_func(M)
@@ -311,3 +322,25 @@ def fused_marlin_moe(
 
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)
+
+
+@fused_marlin_moe.register_fake
+def _(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    g_idx1: Optional[torch.Tensor] = None,
+    g_idx2: Optional[torch.Tensor] = None,
+    sort_indices1: Optional[torch.Tensor] = None,
+    sort_indices2: Optional[torch.Tensor] = None,
+    w1_zeros: Optional[torch.Tensor] = None,
+    w2_zeros: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
+    is_k_full: bool = True,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 90a4209b5bce5..1cf5c2253ca0b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -358,9 +358,10 @@ def try_get_optimal_moe_config(
     top_k: int,
     dtype: Optional[str],
     M: int,
-    override_config: Optional[Dict[str, Any]] = None,
     is_marlin: bool = False,
 ):
+    from vllm.model_executor.layers.fused_moe import get_config
+    override_config = get_config()
     if override_config:
         config = override_config
     else:
@@ -465,19 +466,109 @@ def get_config_dtype_str(dtype: torch.dtype,
     return None
 
 
+@torch.library.custom_op("vllm::inplace_fused_experts",
+                         mutates_args=["hidden_states"])
+def inplace_fused_experts(hidden_states: torch.Tensor,
+                          w1: torch.Tensor,
+                          w2: torch.Tensor,
+                          topk_weights: torch.Tensor,
+                          topk_ids: torch.Tensor,
+                          use_fp8_w8a8: bool = False,
+                          use_int8_w8a16: bool = False,
+                          w1_scale: Optional[torch.Tensor] = None,
+                          w2_scale: Optional[torch.Tensor] = None,
+                          a1_scale: Optional[torch.Tensor] = None,
+                          a2_scale: Optional[torch.Tensor] = None) -> None:
+    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
+                       use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
+                       a1_scale, a2_scale)
+
+
+@inplace_fused_experts.register_fake
+def _(hidden_states: torch.Tensor,
+      w1: torch.Tensor,
+      w2: torch.Tensor,
+      topk_weights: torch.Tensor,
+      topk_ids: torch.Tensor,
+      use_fp8_w8a8: bool = False,
+      use_int8_w8a16: bool = False,
+      w1_scale: Optional[torch.Tensor] = None,
+      w2_scale: Optional[torch.Tensor] = None,
+      a1_scale: Optional[torch.Tensor] = None,
+      a2_scale: Optional[torch.Tensor] = None) -> None:
+    pass
+
+
+@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
+def outplace_fused_experts(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
+                              False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
+                              w2_scale, a1_scale, a2_scale)
+
+
+@outplace_fused_experts.register_fake
+def _(hidden_states: torch.Tensor,
+      w1: torch.Tensor,
+      w2: torch.Tensor,
+      topk_weights: torch.Tensor,
+      topk_ids: torch.Tensor,
+      use_fp8_w8a8: bool = False,
+      use_int8_w8a16: bool = False,
+      w1_scale: Optional[torch.Tensor] = None,
+      w2_scale: Optional[torch.Tensor] = None,
+      a1_scale: Optional[torch.Tensor] = None,
+      a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
                   topk_weights: torch.Tensor,
                   topk_ids: torch.Tensor,
                   inplace: bool = False,
-                  override_config: Optional[Dict[str, Any]] = None,
                   use_fp8_w8a8: bool = False,
                   use_int8_w8a16: bool = False,
                   w1_scale: Optional[torch.Tensor] = None,
                   w2_scale: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None):
+    if inplace:
+        torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
+                                             topk_weights, topk_ids,
+                                             use_fp8_w8a8, use_int8_w8a16,
+                                             w1_scale, w2_scale, a1_scale,
+                                             a2_scale)
+        return hidden_states
+    else:
+        return torch.ops.vllm.outplace_fused_experts(
+            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
+            use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale)
+
+
+def fused_experts_impl(hidden_states: torch.Tensor,
+                       w1: torch.Tensor,
+                       w2: torch.Tensor,
+                       topk_weights: torch.Tensor,
+                       topk_ids: torch.Tensor,
+                       inplace: bool = False,
+                       use_fp8_w8a8: bool = False,
+                       use_int8_w8a16: bool = False,
+                       w1_scale: Optional[torch.Tensor] = None,
+                       w2_scale: Optional[torch.Tensor] = None,
+                       a1_scale: Optional[torch.Tensor] = None,
+                       a2_scale: Optional[torch.Tensor] = None):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
@@ -504,7 +595,6 @@ def fused_experts(hidden_states: torch.Tensor,
         w2.shape,
         topk_ids.shape[1],
         config_dtype,
-        override_config=override_config,
     )
 
     config = get_config_func(M)
@@ -602,7 +692,6 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -628,8 +717,6 @@ def fused_moe(
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
@@ -667,7 +754,6 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
-        override_config=override_config,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 8dd36620e3fa0..5570771ac917b 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -12,7 +12,16 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
-
+from vllm.platforms import current_platform
+
+if current_platform.is_cuda_alike():
+    from .fused_moe import fused_experts
+else:
+    fused_experts = None  # type: ignore
+if current_platform.is_tpu():
+    from .moe_pallas import fused_moe as fused_moe_pallas
+else:
+    fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
@@ -96,9 +105,6 @@ def forward_cuda(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
-            fused_experts)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -132,17 +138,18 @@ def forward_tpu(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
         assert custom_routing_function is None
-        return fused_moe(hidden_states=x,
-                         w1=layer.w13_weight,
-                         w2=layer.w2_weight,
-                         topk=top_k,
-                         gating_output=router_logits,
-                         renormalize=renormalize)
+        return fused_moe_pallas(hidden_states=x,
+                                w1=layer.w13_weight,
+                                w2=layer.w2_weight,
+                                topk=top_k,
+                                gating_output=router_logits,
+                                renormalize=renormalize)
+
+    forward_native = forward_cuda
 
 
 class FusedMoE(torch.nn.Module):
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index b3d93b285769c..95ec12daeeeb5 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -3,6 +3,7 @@
 import torch
 from torch.nn import Parameter
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
@@ -435,10 +436,6 @@ def apply(
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-
-        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-            fused_marlin_moe)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -449,7 +446,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function)
 
-        return fused_marlin_moe(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index be3d3985a74ad..dad04017d3212 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -6,6 +6,7 @@
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
@@ -481,10 +482,6 @@ def apply(
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-
-        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-            fused_marlin_moe)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -495,7 +492,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function)
 
-        return fused_marlin_moe(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index e77191796bd7e..b97dd108d6785 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -2,6 +2,7 @@
 
 import torch
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
@@ -536,9 +537,6 @@ def apply(
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-            fused_marlin_moe)
-
         # The input must currently be float16
         orig_dtype = x.dtype
         x = x.half()
@@ -553,7 +551,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=None)
 
-        return fused_marlin_moe(
+        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index fd0d4c89a28fe..5307bb21adb96 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -28,6 +28,7 @@
 from transformers.models.granitemoe import GraniteMoeConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -244,6 +245,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GraniteMoeModel(nn.Module):
 
     def __init__(
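
The pattern the patch applies throughout: wrap each eager MoE kernel in torch.library.custom_op, register a shape-only fake implementation for tracing, and dispatch through the torch.ops.vllm namespace so torch.compile treats the kernel as an opaque op instead of graph-breaking on it. The override_config()/get_config() context manager added in fused_moe/__init__.py replaces the removed override_config= keyword, so the tuned kernel config no longer travels through the op signature. Below is a minimal, self-contained sketch of that registration pattern; the demo::scaled_silu op is invented purely for illustration and is not a vLLM API, and it assumes a PyTorch recent enough to provide torch.library.custom_op and register_fake (2.4+).

# Minimal sketch of the custom-op registration pattern used by this patch.
# "demo::scaled_silu" is a made-up op name for illustration only.
from typing import Optional

import torch


@torch.library.custom_op("demo::scaled_silu", mutates_args=[])
def scaled_silu(x: torch.Tensor,
                scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Real eager implementation, analogous to fused_experts_impl above.
    out = torch.nn.functional.silu(x)
    return out if scale is None else out * scale


@scaled_silu.register_fake
def _(x: torch.Tensor,
      scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Shape/dtype-only stub used during tracing, analogous to the
    # register_fake stubs added for single_marlin_moe and fused_marlin_moe.
    return torch.empty_like(x)


@torch.compile(fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    # Call through the torch.ops namespace, mirroring the new
    # torch.ops.vllm.fused_marlin_moe / torch.ops.vllm.*_fused_experts calls.
    return torch.ops.demo.scaled_silu(x)


print(f(torch.randn(4, 8)))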