Skip to content

Commit

Permalink
[torch.compile] rework test plans (vllm-project#9866)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
youkaichao authored and mfournioux committed Nov 20, 2024
1 parent 6c43229 commit 0787a83
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 31 deletions.
113 changes: 95 additions & 18 deletions tests/compile/test_basic_correctness.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
from typing import Dict, List, Optional

import pytest
Expand All @@ -8,33 +9,109 @@
from ..utils import compare_all_settings


@dataclasses.dataclass
class TestSetting:
model: str
model_args: List[str]
pp_size: int
tp_size: int
attn_backend: str
method: str
fullgraph: bool


# representative settings for testing
test_settings = [
# basic llama model
TestSetting(
model="meta-llama/Llama-3.2-1B",
model_args=[],
pp_size=2,
tp_size=2,
attn_backend="FLASHINFER",
method="generate",
fullgraph=True,
),
# llama model with quantization
TestSetting(
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
model_args=["--quantization", "gptq"],
pp_size=1,
tp_size=1,
attn_backend="FLASH_ATTN",
method="generate",
fullgraph=True,
),
# MoE model
TestSetting(
model="ibm/PowerMoE-3b",
model_args=[],
pp_size=1,
tp_size=2,
attn_backend="FLASH_ATTN",
method="generate",
fullgraph=True,
),
# embedding model
TestSetting(
model="BAAI/bge-multilingual-gemma2",
model_args=["--task", "embedding"],
pp_size=1,
tp_size=1,
attn_backend="FLASHINFER",
method="encode",
fullgraph=True,
),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
model_args=["--trust-remote-code", "--max-model-len", "2048"],
pp_size=2,
tp_size=1,
attn_backend="FLASH_ATTN",
method="generate_with_image",
fullgraph=False,
),
]


# we cannot afford testing the full Catesian product
# of all models and all levels
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
method, fullgraph):
@pytest.mark.parametrize("test_setting", test_settings)
def test_compile_correctness(test_setting: TestSetting):
# this test is run under multiple suits, with different GPUs.
# make sure we only run the test with correct CUDA devices.
# don't use "<", as it will duplicate the tests.
model = test_setting.model
model_args = test_setting.model_args
pp_size = test_setting.pp_size
tp_size = test_setting.tp_size
attn_backend = test_setting.attn_backend
method = test_setting.method
fullgraph = test_setting.fullgraph
if cuda_device_count_stateless() != pp_size * tp_size:
pytest.skip("Not correct CUDA devices for the test.")
import os
os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
all_args = [["--enforce-eager"] + model_args + ["-pp", str(pp_size)] +
["-tp", str(tp_size)]] * 3
# don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
# inductor will change the output, so we cannot compare them.
final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
["-tp", str(tp_size)]

all_envs: List[Optional[Dict[str, str]]] = []

for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
]:
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})

# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model, [final_args] * 2,
all_envs,
method=method if method != "generate" else "generate_close")
all_envs.clear()

for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
Expand All @@ -46,4 +123,4 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
all_envs[-1][
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore

compare_all_settings(model, all_args, all_envs, method=method)
compare_all_settings(model, [final_args] * 3, all_envs, method=method)
124 changes: 119 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import functools
import os
import signal
Expand All @@ -8,13 +9,14 @@
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

import openai
import pytest
import requests
import torch
from openai.types.completion import Completion
from typing_extensions import ParamSpec, assert_never
from typing_extensions import ParamSpec

import vllm.envs as envs
from tests.models.utils import TextTextLogprobs
Expand Down Expand Up @@ -272,6 +274,31 @@ def _test_completion(
return results


def _test_completion_close(
client: openai.OpenAI,
model: str,
prompt: str,
):
results = []

# test with text prompt
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=1,
logprobs=5,
temperature=0.0)

logporbs = completion.choices[0].logprobs.top_logprobs[0]
logporbs = {k: round(v, 2) for k, v in logporbs.items()}

results.append({
"test": "completion_close",
"logprobs": logporbs,
})

return results


def _test_embeddings(
client: openai.OpenAI,
model: str,
Expand All @@ -295,13 +322,81 @@ def _test_embeddings(
return results


def _test_image_text(
client: openai.OpenAI,
model_name: str,
image_url: str,
):
results = []

# test pure text input
messages = [{
"role":
"user",
"content": [
{
"type": "text",
"text": "How do you feel today?"
},
],
}]

chat_completion = client.chat.completions.create(model=model_name,
messages=messages,
temperature=0.0,
max_tokens=1,
logprobs=True,
top_logprobs=5)
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

for x in top_logprobs:
x.logprob = round(x.logprob, 2)

results.append({
"test": "pure_text",
"logprobs": top_logprobs,
})

messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]

chat_completion = client.chat.completions.create(model=model_name,
messages=messages,
temperature=0.0,
max_tokens=1,
logprobs=True,
top_logprobs=5)
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

results.append({
"test": "text_image",
"logprobs": top_logprobs,
})

return results


def compare_two_settings(model: str,
arg1: List[str],
arg2: List[str],
env1: Optional[Dict[str, str]] = None,
env2: Optional[Dict[str, str]] = None,
*,
method: Literal["generate", "encode"] = "generate",
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with two different sets of arguments/environments
Expand All @@ -328,7 +423,7 @@ def compare_all_settings(model: str,
all_args: List[List[str]],
all_envs: List[Optional[Dict[str, str]]],
*,
method: Literal["generate", "encode"] = "generate",
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with several different sets of arguments/environments
Expand Down Expand Up @@ -397,10 +492,17 @@ def compare_all_settings(model: str,

if method == "generate":
results += _test_completion(client, model, prompt, token_ids)
elif method == "generate_close":
results += _test_completion_close(client, model, prompt)
elif method == "generate_with_image":
results += _test_image_text(
client, model,
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
)
elif method == "encode":
results += _test_embeddings(client, model, prompt)
else:
assert_never(method)
raise ValueError(f"Unknown method: {method}")

if i > 0:
# if any setting fails, raise an error early
Expand All @@ -410,6 +512,18 @@ def compare_all_settings(model: str,
compare_envs = all_envs[i]
for ref_result, compare_result in zip(ref_results,
compare_results):
ref_result = copy.deepcopy(ref_result)
compare_result = copy.deepcopy(compare_result)
if "embedding" in ref_result and method == "encode":
ref_embedding = torch.tensor(ref_result["embedding"])
compare_embedding = torch.tensor(
compare_result["embedding"])
mse = ((ref_embedding - compare_embedding)**2).mean()
assert mse < 1e-6, (
f"Embedding for {model=} are not the same.\n"
f"mse={mse}\n")
del ref_result["embedding"]
del compare_result["embedding"]
assert ref_result == compare_result, (
f"Results for {model=} are not the same.\n"
f"{ref_args=} {ref_envs=}\n"
Expand Down
10 changes: 5 additions & 5 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,13 +493,9 @@ def forward(
:class:`LlavaImageInputs`
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
Expand All @@ -511,7 +507,11 @@ def forward(
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
input_ids = None

# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None

hidden_states = self.language_model.model(input_ids,
positions,
Expand Down
10 changes: 7 additions & 3 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,6 @@ def forward(self,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand All @@ -690,9 +689,14 @@ def forward(self,
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None
else:
inputs_embeds = None
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)

# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None

hidden_states = self.language_model.model(input_ids,
positions,
Expand Down

0 comments on commit 0787a83

Please sign in to comment.