
Commit

add test code
kakao-kevin-us committed Oct 25, 2024
1 parent bf18a77 commit eabb801
Showing 2 changed files with 85 additions and 1 deletion.
32 changes: 31 additions & 1 deletion tests/conftest.py
@@ -14,7 +14,7 @@
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
+from transformers import (AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, BatchEncoding,
                           BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

@@ -271,6 +271,16 @@ def __init__(
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            config = AutoConfig.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
            arch = config.architectures
            if len(arch) > 0:
                cls_type = arch[0].split("For")[-1]
                auto_cls = eval(f"AutoModelFor{cls_type}")

            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
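Note on the hunk above (not part of the diff): it resolves the HF Auto class from the checkpoint's config.architectures entry, e.g. an architecture string of "Qwen2ForSequenceClassification" maps to AutoModelForSequenceClassification, which is why that class and AutoConfig are newly imported. A minimal sketch of the same lookup done with getattr instead of eval, assuming transformers exposes the matching AutoModelFor* attribute (illustrative only, not part of this commit):

import transformers
from transformers import AutoConfig, AutoModelForCausalLM

def resolve_auto_cls(model_name: str):
    # e.g. config.architectures == ["Qwen2ForSequenceClassification"]
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    cls_type = config.architectures[0].split("For")[-1]  # "SequenceClassification"
    # getattr avoids eval(); fall back to the causal-LM class if no Auto class matches
    return getattr(transformers, f"AutoModelFor{cls_type}", AutoModelForCausalLM)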
@@ -334,6 +344,18 @@ def get_inputs(

        return all_inputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        # output is the per-class softmax over the final logits
        all_inputs = self.get_inputs(prompts)
        outputs = []
        print(f"model: {self.model}")
        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

    def generate(
        self,
        prompts: List[str],
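For reference (not part of the diff): HfRunner.classify above is roughly a batched version of the standard HF sequence-classification forward pass, so the reference values the new test compares against could be reproduced standalone as sketched here (checkpoint name taken from the test file below; dtype and device handling omitted):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "jason9693/Qwen2.5-1.5B-apeach"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)

inputs = tokenizer("example prompt", return_tensors="pt")
with torch.no_grad():
    # per-class probabilities, matching what HfRunner.classify returns per prompt
    probs = model(**inputs).logits.softmax(dim=-1)[0].tolist()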
@@ -668,6 +690,14 @@ def get_inputs(
inputs[i]["multi_modal_data"] = {"audio": audio}

return inputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.encode(prompts)
        outputs = []
        for req_output in req_outputs:
            embedding = req_output.outputs.embedding
            outputs.append(embedding)
        return outputs

    def generate(
        self,
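Context for the VllmRunner.classify added above (not part of the diff): it reuses vLLM's encode/pooling path, so for a sequence-classification model the returned "embedding" is expected to hold the per-class scores rather than a text embedding; that is what the new test file below compares, element-wise, against the HF runner's softmaxed logits. A small sketch of consuming either output, with label names left out as they are not part of this commit:

import torch

def predicted_class(scores: list) -> int:
    # scores: per-class vector from either classify() implementation above
    return int(torch.tensor(scores).argmax().item())

# e.g. predicted_class(outputs[0]) gives the index of the winning label;
# label names, if needed, would come from the checkpoint's config.id2label.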
54 changes: 54 additions & 0 deletions tests/models/decoder_only/language/test_cls_models.py
@@ -0,0 +1,54 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
Run `pytest tests/models/test_models.py`.
"""
import pytest
import torch

from ...utils import check_logprobs_close, check_outputs_equal

CLASSIFICATION_MODELS = [
    "jason9693/Qwen2.5-1.5B-apeach",
]


@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_classification_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)

print(hf_outputs, vllm_outputs)

# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)

assert torch.allclose(hf_output, vllm_output, 1e-3)


@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_classification_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
