forked from vllm-project/vllm
Commit
[Kernel] Fullgraph and opcheck tests (vllm-project#8479)
Showing 26 changed files with 744 additions and 116 deletions.
@@ -1,42 +1,13 @@
-import os
-
-import pytest
-
-from vllm.utils import cuda_device_count_stateless
-
-from ..utils import fork_new_process_for_each_test
-
-
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-@pytest.mark.parametrize("tp_size", [1, 2])
-@fork_new_process_for_each_test
-def test_full_graph(model, tp_size):
-
-    # Skip the test if there are not enough CUDA devices.
-    if cuda_device_count_stateless() < tp_size:
-        pytest.skip("Not enough CUDA devices for the test.")
-
-    # make sure these models can be captured in full graph mode
-    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
-        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
-
-    from vllm import LLM, SamplingParams
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-    sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model=model,
-              enforce_eager=True,
-              tensor_parallel_size=tp_size,
-              disable_custom_all_reduce=True)
-
-    outputs = llm.generate(prompts, sampling_params)
-
-    # Print the outputs.
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+import pytest
+
+from vllm.compilation.backends import vllm_backend
+
+from .utils import TEST_MODELS, check_full_graph_support
+
+
+@pytest.mark.parametrize("model_info", TEST_MODELS)
+@pytest.mark.parametrize("backend", ["eager", vllm_backend])
+def test_full_graph(model_info, backend):
+    model = model_info[0]
+    model_kwargs = model_info[1]
+    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
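For context, "full graph" here refers to torch.compile/Dynamo capturing a model as a single graph with no graph breaks. The generic sketch below is not vLLM code and does not use the VLLM_TEST_DYNAMO_* flags set later in this commit; it only illustrates the property the new tests exercise: with fullgraph=True, a graph break becomes a hard error instead of a silent fallback to eager.

import torch

def f(x):
    # No data-dependent control flow, so Dynamo can trace this as one graph.
    return torch.relu(x) + 1

# fullgraph=True turns any graph break into an error, which is the behavior
# the full-graph tests rely on when compiling whole models.
compiled = torch.compile(f, fullgraph=True)
print(compiled(torch.randn(4)))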
@@ -0,0 +1,22 @@
+import pytest
+
+from vllm.compilation.backends import vllm_backend
+from vllm.utils import cuda_device_count_stateless
+
+from ..utils import fork_new_process_for_each_test
+from .utils import TEST_MODELS_SMOKE, check_full_graph_support
+
+
+@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("backend", ["eager", vllm_backend])
+@fork_new_process_for_each_test
+def test_full_graph_multi_gpu(model_info, tp_size, backend):
+    model = model_info[0]
+    model_kwargs = model_info[1]
+
+    # Skip the test if there are not enough CUDA devices.
+    if cuda_device_count_stateless() < tp_size:
+        pytest.skip("Not enough CUDA devices for the test.")
+
+    check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)
@@ -0,0 +1,13 @@
+import pytest
+
+from vllm.compilation.backends import vllm_backend
+
+from .utils import TEST_MODELS_SMOKE, check_full_graph_support
+
+
+@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
+@pytest.mark.parametrize("backend", ["eager", vllm_backend])
+def test_full_graph(model_info, backend):
+    model = model_info[0]
+    model_kwargs = model_info[1]
+    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
@@ -0,0 +1,104 @@
+import os
+
+import torch
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm import LLM, SamplingParams
+from vllm.plugins import set_torch_compile_backend
+from vllm.utils import is_hip
+
+TEST_MODELS_SMOKE = [
+    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
+        "quantization": "compressed-tensors"
+    }),
+    ("meta-llama/Meta-Llama-3-8B", {}),
+]
+
+TEST_MODELS = [
+    ("facebook/opt-125m", {}),
+    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
+        "dtype": torch.float16,
+        "quantization": "compressed-tensors"
+    }),
+    ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
+        "dtype": torch.float16,
+        "quantization": "fp8"
+    }),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
+        "quantization": "compressed-tensors"
+    }),
+    ("meta-llama/Meta-Llama-3-8B", {}),
+]
+
+# TODO: enable in pytorch 2.5
+if False and is_quant_method_supported("aqlm"):  # noqa: SIM223
+    TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
+        "quantization": "aqlm"
+    }))
+
+# TODO: enable in pytorch 2.5
+if False and is_quant_method_supported("gguf"):  # noqa: SIM223
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
+        "quantization": "gguf"
+    }))
+
+if is_quant_method_supported("gptq"):
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
+        "quantization": "gptq"
+    }))
+
+if is_quant_method_supported("gptq_marlin"):
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
+        "quantization": "gptq_marlin"
+    }))
+
+if is_quant_method_supported("gptq_marlin_24"):
+    TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
+        "quantization": "gptq_marlin_24"
+    }))
+
+if is_quant_method_supported("marlin"):
+    TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
+        "quantization": "marlin"
+    }))
+
+if not is_hip() and is_quant_method_supported("awq"):
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
+        "quantization": "AWQ"
+    }))
+
+
+def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
+    # make sure these models can be captured in full graph mode
+    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
+        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
+
+    # Inductor doesn't support fp8/gptq_marlin_24 yet.
+    quantization = model_kwargs.get("quantization")
+    if (quantization == "fp8" or quantization == "gptq_marlin"
+            or quantization == "gptq_marlin_24") and backend != "eager":
+        return
+
+    set_torch_compile_backend(backend)
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(model=model,
+              enforce_eager=True,
+              tensor_parallel_size=tp_size,
+              disable_custom_all_reduce=True,
+              **model_kwargs)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -0,0 +1,37 @@
+import torch
+
+from tests.kernels.utils import opcheck
+from vllm import _custom_ops as ops  # noqa: F401
+
+
+def test_aqlm_dequant_opcheck():
+    codes = torch.randint(-32768,
+                          32767, (22016, 512, 1),
+                          device='cuda',
+                          dtype=torch.int16)
+    codebooks = torch.rand((2, 65536, 1, 8),
+                           device='cuda',
+                           dtype=torch.float16)
+    codebook_partition_sizes = [11008, 11008]
+
+    opcheck(torch.ops._C.aqlm_dequant,
+            (codes, codebooks, codebook_partition_sizes))
+
+
+def test_aqlm_gemm_opcheck():
+    input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
+    codes = torch.randint(-32768,
+                          32767, (12288, 512, 1),
+                          device='cuda',
+                          dtype=torch.int16)
+    codebooks = torch.rand((3, 65536, 1, 8),
+                           device='cuda',
+                           dtype=torch.float16)
+    scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
+    codebook_partition_sizes = [4096, 4096, 4096]
+    bias = None
+
+    opcheck(torch.ops._C.aqlm_gemm,
+            (input, codes, codebooks, scales, codebook_partition_sizes, None))
+    opcheck(torch.ops._C.aqlm_gemm,
+            (input, codes, codebooks, scales, codebook_partition_sizes, bias))
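For background on what these opcheck calls verify: the opcheck helper imported from tests/kernels/utils.py is assumed to delegate to torch.library.opcheck, which runs a registered operator through schema, fake-tensor, and dispatch checks against sample inputs. The toy operator below is purely illustrative (it is not a vLLM op) and assumes PyTorch 2.4+ for torch.library.custom_op and torch.library.opcheck.

import torch
from torch.library import custom_op, opcheck


# Toy op used only to show the kind of checks opcheck performs; the real
# tests above exercise the compiled torch.ops._C kernels instead.
@custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


@scale.register_fake
def _(x, factor):
    # The fake kernel only describes output metadata (shape/dtype/device).
    return torch.empty_like(x)


# Runs the operator against the sample inputs under a battery of checks.
opcheck(scale, (torch.rand(4), 2.0))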
@@ -0,0 +1,38 @@
+import os
+
+import torch
+
+from tests.kernels.utils import opcheck
+from vllm import _custom_ops as ops  # noqa: F401
+
+
+def test_awq_dequantize_opcheck():
+    os.environ["VLLM_USE_TRITON_AWQ"] = "0"
+    qweight = torch.randint(-2000000000,
+                            2000000000, (8192, 256),
+                            device='cuda',
+                            dtype=torch.int32)
+    scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
+    zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
+    split_k_iters = 0
+    thx = 0
+    thy = 0
+    opcheck(torch.ops._C.awq_dequantize,
+            (qweight, scales, zeros, split_k_iters, thx, thy))
+
+
+def test_awq_gemm_opcheck():
+    os.environ["VLLM_USE_TRITON_AWQ"] = "0"
+    input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
+    qweight = torch.randint(-2000000000,
+                            2000000000, (8192, 256),
+                            device='cuda',
+                            dtype=torch.int32)
+    scales = torch.randint(-2000000000,
+                           2000000000, (64, 256),
+                           device='cuda',
+                           dtype=torch.int32)
+    qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
+    split_k_iters = 8
+    opcheck(torch.ops._C.awq_gemm,
+            (input, qweight, qzeros, scales, split_k_iters))
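A note on the seemingly unused import in these kernel test files: from vllm import _custom_ops as ops is kept (with noqa: F401) presumably for its side effect. The sketch below states that assumption explicitly; it is an illustration, not vLLM documentation.

import torch

# Assumption: importing vllm._custom_ops loads the compiled extension and
# registers its kernels, so names such as torch.ops._C.awq_gemm resolve.
from vllm import _custom_ops as ops  # noqa: F401

assert hasattr(torch.ops._C, "awq_gemm")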