Skip to content

Commit

Permalink
modified auto_cls logic, and lint check
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin-Yang <[email protected]>
  • Loading branch information
jason9693 committed Oct 26, 2024
1 parent b1d1afc commit e50fd79
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 19 deletions.
15 changes: 2 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoModelForCausalLM,
AutoModelForSequenceClassification, AutoTokenizer,
AutoConfig, BatchEncoding, BatchFeature)
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from tests.models.utils import (TokensTextLogprobs,
Expand Down Expand Up @@ -272,16 +271,6 @@ def __init__(
).to(dtype=torch_dtype))
else:
model_kwargs = model_kwargs if model_kwargs is not None else {}
config = AutoConfig.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
arch = config.architectures
if len(arch) > 0:
cls_type = arch[0].split("For")[-1]
auto_cls = eval(f"AutoModelFor{cls_type}")

self.model = self.wrap_device(
auto_cls.from_pretrained(
model_name,
Expand Down
9 changes: 5 additions & 4 deletions tests/models/decoder_only/language/test_cls_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
"""
import pytest
import torch
from transformers import AutoModelForSequenceClassification

CLASSIFICATION_MODELS = [
"jason9693/Qwen2.5-1.5B-apeach"
]
CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]


@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
Expand All @@ -22,7 +21,9 @@ def test_classification_models(
model: str,
dtype: str,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

with vllm_runner(model, dtype=dtype) as vllm_model:
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _add_seq_group(
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]

print(f"prefix cache hit: {prefix_cache_hit}")
print(f"chunked prefill enabled: {chunked_prefill_enabled}")
print(f"prompt: {is_prompt}")
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),
"Qwen2ForSequenceClassification": (
"qwen2_cls", "Qwen2ForSequenceClassification"),
# [Multimodal]
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}
Expand Down

0 comments on commit e50fd79

Please sign in to comment.