Skip to content

Commit

Permalink
[ci][test] add correctness test for cpu offloading (vllm-project#6549)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
youkaichao authored and Alvant committed Oct 26, 2024
1 parent 17ba77a commit 1c9b557
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 85 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ steps:
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
Expand Down
8 changes: 8 additions & 0 deletions tests/basic_correctness/test_cpu_offload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ..utils import compare_two_settings


def test_cpu_offload():
compare_two_settings("meta-llama/Llama-2-7b-hf", [],
["--cpu-offload-gb", "4"])
compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
[], ["--cpu-offload-gb", "1"])
87 changes: 2 additions & 85 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from transformers import AutoTokenizer

from ..utils import RemoteOpenAIServer
from ..utils import compare_two_settings


@pytest.mark.parametrize(
Expand All @@ -13,7 +12,6 @@
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

pp_args = [
# use half precision for speed and memory savings in CI environment
Expand Down Expand Up @@ -48,85 +46,4 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager")

prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"]
results = []
for args in (pp_args, tp_args):
with RemoteOpenAIServer(MODEL_NAME, args) as server:
client = server.get_client()

# test models list
models = client.models.list()
models = models.data
served_model = models[0]
results.append({
"test": "models_list",
"id": served_model.id,
"root": served_model.root,
})

# test with text prompt
completion = client.completions.create(model=MODEL_NAME,
prompt=prompt,
max_tokens=5,
temperature=0.0)

results.append({
"test": "single_completion",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})

# test using token IDs
completion = client.completions.create(
model=MODEL_NAME,
prompt=token_ids,
max_tokens=5,
temperature=0.0,
)

results.append({
"test": "token_ids",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})

# test simple list
batch = client.completions.create(
model=MODEL_NAME,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
)

results.append({
"test": "simple_list",
"text0": batch.choices[0].text,
"text1": batch.choices[1].text,
})

# test streaming
batch = client.completions.create(
model=MODEL_NAME,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
results.append({
"test": "streaming",
"texts": texts,
})

n = len(results) // 2
pp_results = results[:n]
tp_results = results[n:]
for pp, tp in zip(pp_results, tp_results):
assert pp == tp
compare_two_settings(MODEL_NAME, pp_args, tp_args)
94 changes: 94 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import openai
import ray
import requests
from transformers import AutoTokenizer

from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
Expand Down Expand Up @@ -124,6 +125,99 @@ def get_async_client(self):
)


def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"""
Launch API server with two different sets of arguments and compare the
results of the API calls. The arguments are after the model name.
"""

tokenizer = AutoTokenizer.from_pretrained(model)

prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"]
results = []
for args in (arg1, arg2):
with RemoteOpenAIServer(model, args) as server:
client = server.get_client()

# test models list
models = client.models.list()
models = models.data
served_model = models[0]
results.append({
"test": "models_list",
"id": served_model.id,
"root": served_model.root,
})

# test with text prompt
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
temperature=0.0)

results.append({
"test": "single_completion",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})

# test using token IDs
completion = client.completions.create(
model=model,
prompt=token_ids,
max_tokens=5,
temperature=0.0,
)

results.append({
"test": "token_ids",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})

# test simple list
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
)

results.append({
"test": "simple_list",
"text0": batch.choices[0].text,
"text1": batch.choices[1].text,
})

# test streaming
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
results.append({
"test": "streaming",
"texts": texts,
})

n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, \
f"Results for {model=} are not the same with {arg1=} and {arg2=}"


def init_test_distributed_environment(
tp_size: int,
pp_size: int,
Expand Down

0 comments on commit 1c9b557

Please sign in to comment.