
Commit

Merge branch 'vllm-project:main' into support-idefics3
jeejeelee authored Nov 1, 2024
2 parents 367f31e + 93a76dd commit 63265c4
Showing 15 changed files with 467 additions and 99 deletions.
20 changes: 16 additions & 4 deletions tests/compile/piecewise/test_simple.py
@@ -6,18 +6,22 @@

import torch
from torch import nn
+from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
+from vllm.utils import direct_register_custom_op

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

global_counter = 0

+# create a library to hold the custom op
+silly_lib = Library("silly", "FRAGMENT")  # noqa


-@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    global global_counter
@@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    out[0] += 1


-@silly_attention.register_fake
-def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-      out: torch.Tensor) -> None:
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
    return


+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    target_lib=silly_lib,
+)


@support_torch_compile
class SillyModel(nn.Module):
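
Note on the pattern above (illustrative context, not part of the diff): with either registration path the operator ends up callable through the torch.ops namespace, which is presumably how the SillyModel class below the visible hunk consumes it. A minimal, self-contained sketch of that register-then-dispatch pattern, using the plain torch.library API the deleted lines relied on and a made-up "silly_demo" namespace so it cannot collide with the test's own op:

# Sketch only: register a mutating custom op and call it via torch.ops.
import torch


@torch.library.custom_op("silly_demo::attention", mutates_args=["out"])
def demo_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   out: torch.Tensor) -> None:
    # toy "attention": combine the inputs and write the result into `out`
    out.copy_(q + k + v)


@demo_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
      out: torch.Tensor) -> None:
    # fake (meta) impl: no computation, output shape is carried by `out`
    return


q = torch.randn(4, 8)
k = torch.randn(4, 8)
v = torch.randn(4, 8)
out = torch.empty(4, 8)
torch.ops.silly_demo.attention(q, k, v, out)  # dispatches to demo_attention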
20 changes: 16 additions & 4 deletions tests/compile/piecewise/test_toy_llama.py
@@ -8,29 +8,41 @@

import torch
from torch import nn
+from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_compilation_config
+from vllm.utils import direct_register_custom_op

+# create a library to hold the custom op
+silly_lib = Library("silly", "FRAGMENT")  # noqa


-@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


-@silly_attention.register_fake
-def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-      out: torch.Tensor) -> None:
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
    return


+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    target_lib=silly_lib,
+)


@dataclass
class LlamaConfig:
    hidden_size: int = 128
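
For readers wondering what the new helper wraps: a direct-registration helper of this shape can be assembled from the low-level torch.library calls, namely defining a schema on a Library, binding the real kernel for a dispatch key, and registering a fake implementation for tracing. The code below is a rough, assumption-laden sketch; the real vllm.utils.direct_register_custom_op may differ (for instance in how it derives the schema and which dispatch key it targets), and the "silly_demo2" namespace and helper name are made up for the example.

import torch
from torch.library import Library

demo_lib = Library("silly_demo2", "FRAGMENT")  # separate namespace for the demo


def register_custom_op_sketch(namespace: str, op_name: str, op_func, fake_impl,
                              schema: str, target_lib: Library,
                              dispatch_key: str = "CUDA") -> None:
    # declare the operator signature inside the library's namespace
    target_lib.define(op_name + schema)
    # bind the real kernel for one dispatch key
    target_lib.impl(op_name, op_func, dispatch_key)
    # register a shape-only fake impl so torch.compile can trace the op
    torch.library.register_fake(f"{namespace}::{op_name}", fake_impl)


def demo_add_one(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 1)


def demo_add_one_fake(x: torch.Tensor, out: torch.Tensor) -> None:
    return


register_custom_op_sketch("silly_demo2", "add_one", demo_add_one,
                          demo_add_one_fake,
                          "(Tensor x, Tensor(a!) out) -> ()",
                          demo_lib, dispatch_key="CPU")
torch.ops.silly_demo2.add_one(torch.zeros(3), torch.empty(3))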
113 changes: 95 additions & 18 deletions tests/compile/test_basic_correctness.py
@@ -1,3 +1,4 @@
+import dataclasses
from typing import Dict, List, Optional

import pytest
@@ -8,33 +9,109 @@
from ..utils import compare_all_settings


+@dataclasses.dataclass
+class TestSetting:
+    model: str
+    model_args: List[str]
+    pp_size: int
+    tp_size: int
+    attn_backend: str
+    method: str
+    fullgraph: bool


+# representative settings for testing
+test_settings = [
+    # basic llama model
+    TestSetting(
+        model="meta-llama/Llama-3.2-1B",
+        model_args=[],
+        pp_size=2,
+        tp_size=2,
+        attn_backend="FLASHINFER",
+        method="generate",
+        fullgraph=True,
+    ),
+    # llama model with quantization
+    TestSetting(
+        model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+        model_args=["--quantization", "gptq"],
+        pp_size=1,
+        tp_size=1,
+        attn_backend="FLASH_ATTN",
+        method="generate",
+        fullgraph=True,
+    ),
+    # MoE model
+    TestSetting(
+        model="ibm/PowerMoE-3b",
+        model_args=[],
+        pp_size=1,
+        tp_size=2,
+        attn_backend="FLASH_ATTN",
+        method="generate",
+        fullgraph=True,
+    ),
+    # embedding model
+    TestSetting(
+        model="BAAI/bge-multilingual-gemma2",
+        model_args=["--task", "embedding"],
+        pp_size=1,
+        tp_size=1,
+        attn_backend="FLASHINFER",
+        method="encode",
+        fullgraph=True,
+    ),
+    # vision language model
+    TestSetting(
+        model="microsoft/Phi-3.5-vision-instruct",
+        model_args=["--trust-remote-code", "--max-model-len", "2048"],
+        pp_size=2,
+        tp_size=1,
+        attn_backend="FLASH_ATTN",
+        method="generate_with_image",
+        fullgraph=False,
+    ),
+]


# we cannot afford testing the full Cartesian product
# of all models and all levels
-@pytest.mark.parametrize(
-    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
-    [
-        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
-        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
-         ["--quantization", "compressed-tensors"
-          ], 1, 1, "FLASH_ATTN", "generate", True),
-        ("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
-        # TODO: add multi-modality test for llava
-        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
-    ])
-def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
-                             method, fullgraph):
+@pytest.mark.parametrize("test_setting", test_settings)
+def test_compile_correctness(test_setting: TestSetting):
    # this test is run under multiple suites, with different GPUs.
    # make sure we only run the test with correct CUDA devices.
    # don't use "<", as it will duplicate the tests.
+    model = test_setting.model
+    model_args = test_setting.model_args
+    pp_size = test_setting.pp_size
+    tp_size = test_setting.tp_size
+    attn_backend = test_setting.attn_backend
+    method = test_setting.method
+    fullgraph = test_setting.fullgraph
    if cuda_device_count_stateless() != pp_size * tp_size:
        pytest.skip("Not correct CUDA devices for the test.")
    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
-    all_args = [["--enforce-eager"] + model_args + ["-pp", str(pp_size)] +
-                ["-tp", str(tp_size)]] * 3
-    # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
-    # inductor will change the output, so we cannot compare them.
+    final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
+        ["-tp", str(tp_size)]

    all_envs: List[Optional[Dict[str, str]]] = []

+    for level in [
+            CompilationLevel.NO_COMPILATION,
+            CompilationLevel.PIECEWISE,
+    ]:
+        all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})

+    # inductor will change the output, so we only compare if the output
+    # is close, not exactly the same.
+    compare_all_settings(
+        model, [final_args] * 2,
+        all_envs,
+        method=method if method != "generate" else "generate_close")
+    all_envs.clear()

    for level in [
            CompilationLevel.NO_COMPILATION,
            CompilationLevel.DYNAMO_AS_IS,
@@ -46,4 +123,4 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
            all_envs[-1][
                "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"  # type: ignore

-    compare_all_settings(model, all_args, all_envs, method=method)
+    compare_all_settings(model, [final_args] * 3, all_envs, method=method)
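
The "generate_close" method name in the hunk above signals a tolerance-based comparison: Inductor-compiled kernels can change floating-point results slightly, so outputs are compared approximately rather than exactly. The snippet below is a purely conceptual sketch of that idea, not vLLM's compare_all_settings (imported from tests/utils.py); the data layout and tolerances are made up for illustration.

import math
from typing import List, Tuple

# A run is summarised here as (generated token ids, per-token logprobs).
RunOutput = Tuple[List[int], List[float]]


def outputs_are_close(a: RunOutput, b: RunOutput,
                      rel_tol: float = 1e-2, abs_tol: float = 1e-3) -> bool:
    tokens_a, logprobs_a = a
    tokens_b, logprobs_b = b
    if tokens_a != tokens_b:  # token-level disagreement: not "close"
        return False
    return all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol)
        for x, y in zip(logprobs_a, logprobs_b))


# Same tokens, logprobs differing only by compilation-induced noise.
run_eager: RunOutput = ([1, 5, 9], [-0.301, -1.204, -2.102])
run_compiled: RunOutput = ([1, 5, 9], [-0.300, -1.205, -2.101])
assert outputs_are_close(run_eager, run_compiled)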