Skip to content

Commit

Permalink
[Kernel] Fullgraph and opcheck tests (vllm-project#8479)
Browse files Browse the repository at this point in the history
  • Loading branch information
bnellnm authored Sep 25, 2024
1 parent 1c04644 commit 300da09
Show file tree
Hide file tree
Showing 26 changed files with 744 additions and 116 deletions.
19 changes: 17 additions & 2 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ steps:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py

- label: Core Test # 10min
mirror_hardwares: [amd]
fast_check: true
Expand Down Expand Up @@ -210,6 +210,21 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4

- label: "PyTorch Fullgraph Smoke Test"
fast_check: true
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph_smoke.py

- label: "PyTorch Fullgraph Test"
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 30min each
mirror_hardwares: [amd]
source_file_dependencies:
Expand Down Expand Up @@ -355,7 +370,7 @@ steps:
- tests/distributed/
- vllm/compilation
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
Expand Down
2 changes: 1 addition & 1 deletion csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
std::vector<at::Tensor> result = {out, x.value()};
std::vector<at::Tensor> result = {out};
if (has_z) { result.push_back(out_z); }
return result;
}
Expand Down
4 changes: 2 additions & 2 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
"Tensor? index_, Tensor!? x) -> Tensor[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
Expand All @@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor? final_states_out_,"
"Tensor!? final_states_out_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
Expand Down
45 changes: 8 additions & 37 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,13 @@
import os

import pytest

from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):

# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")

# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
from vllm.compilation.backends import vllm_backend

from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True)
from .utils import TEST_MODELS, check_full_graph_support

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
22 changes: 22 additions & 0 deletions tests/compile/test_full_graph_multi_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
model = model_info[0]
model_kwargs = model_info[1]

# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")

check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)
13 changes: 13 additions & 0 deletions tests/compile/test_full_graph_smoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from vllm.compilation.backends import vllm_backend

from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
104 changes: 104 additions & 0 deletions tests/compile/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os

import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.utils import is_hip

TEST_MODELS_SMOKE = [
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]

TEST_MODELS = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
"dtype": torch.float16,
"quantization": "fp8"
}),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]

# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("aqlm"): # noqa: SIM223
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm"
}))

# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))

if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))

if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))

if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "gptq_marlin_24"
}))

if is_quant_method_supported("marlin"):
TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
"quantization": "marlin"
}))

if not is_hip() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))


def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

# Inductor doesn't support fp8/gptq_marlin_24 yet.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24") and backend != "eager":
return

set_torch_compile_backend(backend)

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True,
**model_kwargs)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup()


@pytest.fixture(autouse=True)
def dynamo_reset():
yield
torch._dynamo.reset()


@pytest.fixture
def example_prompts() -> List[str]:
prompts = []
Expand Down
37 changes: 37 additions & 0 deletions tests/kernels/test_aqlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401


def test_aqlm_dequant_opcheck():
codes = torch.randint(-32768,
32767, (22016, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((2, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
codebook_partition_sizes = [11008, 11008]

opcheck(torch.ops._C.aqlm_dequant,
(codes, codebooks, codebook_partition_sizes))


def test_aqlm_gemm_opcheck():
input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
codes = torch.randint(-32768,
32767, (12288, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((3, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
codebook_partition_sizes = [4096, 4096, 4096]
bias = None

opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, None))
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, bias))
9 changes: 6 additions & 3 deletions tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def test_paged_attention(
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

elif version in ("v2", "rocm"):
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
Expand Down Expand Up @@ -246,7 +247,8 @@ def test_paged_attention(
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

else:
ops.paged_attention_rocm(
Expand Down Expand Up @@ -274,7 +276,8 @@ def test_paged_attention(
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

else:
raise AssertionError(f"Unknown version: {version}")
Expand Down
38 changes: 38 additions & 0 deletions tests/kernels/test_awq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os

import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401


def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
split_k_iters = 0
thx = 0
thy = 0
opcheck(torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy))


def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.randint(-2000000000,
2000000000, (64, 256),
device='cuda',
dtype=torch.int32)
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
split_k_iters = 8
opcheck(torch.ops._C.awq_gemm,
(input, qweight, qzeros, scales, split_k_iters))
Loading

0 comments on commit 300da09

Please sign in to comment.