[torch.compile] support moe models (vllm-project#9632)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
youkaichao authored and cooleel committed Oct 28, 2024
1 parent 5038583 commit e703269
Showing 12 changed files with 217 additions and 78 deletions.
33 changes: 17 additions & 16 deletions benchmarks/kernels/benchmark_moe.py
@@ -88,22 +88,23 @@ def prepare(i: int):
input_gating.copy_(gating_output[i])

def run():
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
override_config=config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)

# JIT compilation & warmup
run()
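Note: the hunk above replaces the old override_config=config keyword argument to fused_moe with a context manager imported from vllm.model_executor.layers.fused_moe. A minimal sketch of the new calling pattern, assuming CUDA is available; the tensor names, shapes, and the example config values are illustrative placeholders, not taken from the benchmark:

    # Sketch only: run fused_moe under a temporary kernel config override.
    import torch
    from vllm.model_executor.layers.fused_moe import fused_moe, override_config

    hidden = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    w_up = torch.randn(8, 2 * 2048, 1024, device="cuda", dtype=torch.float16)   # [E, 2*I, H]
    w_down = torch.randn(8, 1024, 2048, device="cuda", dtype=torch.float16)     # [E, H, I]
    router_logits = torch.randn(4, 8, device="cuda", dtype=torch.float16)       # [T, E]
    config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
              "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2}

    # The config applies only inside the context; the previous value is
    # restored on exit (see the __init__.py hunk further down).
    with override_config(config):
        out = fused_moe(hidden, w_up, w_down, router_logits, topk=2,
                        renormalize=True, inplace=True)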
4 changes: 2 additions & 2 deletions tests/compile/test_basic_correctness.py
@@ -13,11 +13,11 @@
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])
21 changes: 10 additions & 11 deletions tests/kernels/test_awq_marlin.py
@@ -5,11 +5,10 @@
import pytest
import torch

import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
@@ -81,7 +80,7 @@ def test_fused_marlin_moe_awq(
score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)
marlin_output = fused_marlin_moe(
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
@@ -150,14 +149,14 @@ def test_single_marlin_moe_multiply_awq(

score = torch.randn((m, e), device="cuda", dtype=dtype)

marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
marlin_output = torch.ops.vllm.single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)

torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

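Note: both Marlin MoE tests now call the kernels through the torch.ops.vllm namespace instead of the plain Python functions, and the added "import vllm.model_executor.layers.fused_moe  # noqa" line exists only to trigger the custom-op registration at import time. A hedged sketch of the idea, with invented placeholder arguments (a, qweight, scales, score, topk are stand-ins for the tensors the tests build):

    # Sketch: once the module import has registered the ops, they are reachable
    # through torch.ops and appear to torch.compile as single opaque nodes.
    import torch
    import vllm.model_executor.layers.fused_moe  # noqa: F401  (registers vllm::* ops)

    print(torch.ops.vllm.fused_marlin_moe)  # handle to the registered custom op

    @torch.compile(fullgraph=True)
    def routed_expert(a, qweight, scales, score, topk):
        # torch.compile does not trace into the op's Python body; it records one
        # vllm::single_marlin_moe call and uses the registered fake impl for shapes.
        return torch.ops.vllm.single_marlin_moe(
            a, qweight, scales, score, topk, renormalize=False)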
7 changes: 3 additions & 4 deletions tests/kernels/test_moe.py
@@ -7,12 +7,11 @@
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -193,7 +192,7 @@ def test_fused_marlin_moe(
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
@@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(
marlin_output = torch.ops.vllm.single_marlin_moe(
a,
qweight,
scales,
28 changes: 24 additions & 4 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,23 +1,43 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional

from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON

_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
global _config
old_config = _config
_config = config
yield
_config = old_config


def get_config() -> Optional[Dict[str, Any]]:
return _config


__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",
]

if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)

__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
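Note: the module-level _config plus override_config()/get_config() pair added here replaces the override_config= keyword that the Marlin paths used to accept (its removal is visible in the fused_marlin_moe.py hunks below). The consumer side is not shown in this diff; the following is only a plausible sketch of how a config-resolution helper could honour the global, with an invented function name, not a copy of the real code:

    # Hypothetical consumer, for illustration only.
    from typing import Any, Dict, Optional

    from vllm.model_executor.layers.fused_moe import get_config


    def resolve_moe_config(default_config: Dict[str, Any]) -> Dict[str, Any]:
        override: Optional[Dict[str, Any]] = get_config()
        if override is not None:
            # An override_config(...) context is active; prefer it verbatim.
            return override
        return default_config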
51 changes: 42 additions & 9 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -1,6 +1,6 @@
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Any, Dict, Optional
from typing import Optional

import torch

@@ -18,6 +18,7 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
@@ -28,7 +29,6 @@ def single_marlin_moe(
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@@ -49,8 +49,6 @@ def single_marlin_moe(
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
@@ -79,7 +77,6 @@ def single_marlin_moe(
w.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

@@ -122,6 +119,24 @@ def single_marlin_moe(
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


@single_marlin_moe.register_fake
def _(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)


@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -137,7 +152,6 @@ def fused_marlin_moe(
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@@ -161,8 +175,6 @@ def fused_marlin_moe(
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
@@ -209,7 +221,6 @@ def fused_marlin_moe(
w2.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True,
)
config = get_config_func(M)
@@ -311,3 +322,25 @@ def fused_marlin_moe(

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)


@fused_marlin_moe.register_fake
def _(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
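Note: the decorators in this file use the torch.library.custom_op API (PyTorch 2.4+): the real implementation is registered under the vllm:: namespace, and register_fake supplies a shape-only implementation so torch.compile and meta-tensor tracing can propagate shapes without running the Marlin kernels. A standalone toy example of the same pattern, independent of vLLM (the mylib::scaled_silu op and its body are invented for illustration):

    # Minimal sketch of the custom_op + register_fake pattern used above.
    import torch


    @torch.library.custom_op("mylib::scaled_silu", mutates_args=[])
    def scaled_silu(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Real (eager) implementation; could just as well launch a Triton kernel.
        return torch.nn.functional.silu(x) * scale


    @scaled_silu.register_fake
    def _(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Shape/dtype-only implementation used during tracing, mirroring the
        # torch.empty_like(...) fakes registered for the MoE ops above.
        return torch.empty_like(x)


    @torch.compile(fullgraph=True)
    def f(x):
        return torch.ops.mylib.scaled_silu(x, scale=2.0)


    print(f(torch.randn(4, 8)))  # compiled graph contains one opaque mylib::scaled_silu node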