fix loading of non-bert models and fix tests
Signed-off-by: Max de Bayser <[email protected]>
maxdebayser authored and flaviabeo committed Oct 25, 2024
1 parent 0999971 commit eea5984
Showing 6 changed files with 45 additions and 45 deletions.
6 changes: 6 additions & 0 deletions tests/model_executor/test_model_load_with_params.py
@@ -1,6 +1,7 @@
import os

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
@@ -35,5 +36,10 @@ def test_model_loading_with_params(vllm_runner):
assert model_tokenizer.tokenizer_config["do_lower_case"]
assert model_tokenizer.tokenizer.model_max_length == 512

model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, BertEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.CLS
assert model._pooler.normalize
# assert output
assert output
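The new assertions reach through the engine wrapper to the underlying BertEmbeddingModel and check that CLS pooling and normalization were picked up from the checkpoint's sentence-transformers configuration. For a bge-style model such as BAAI/bge-base-en-v1.5, that information typically lives in the pooling module config shipped with the checkpoint; a rough illustration of such a 1_Pooling/config.json, written as a Python dict and assumed rather than taken from this commit:

    # Illustrative sentence-transformers pooling config for a CLS-pooling model
    # (assumed layout, shown as a Python dict; not part of this commit):
    pooling_json = {
        "word_embedding_dimension": 768,
        "pooling_mode_cls_token": True,
        "pooling_mode_mean_tokens": False,
        "pooling_mode_max_tokens": False,
        "pooling_mode_mean_sqrt_len_tokens": False,
    }

The normalize flag usually comes from a separate sentence_transformers.models.Normalize entry in the model's modules.json rather than from the pooling config itself.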
6 changes: 4 additions & 2 deletions tests/test_config.py
@@ -104,9 +104,11 @@ def test_get_sliding_window():


def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(
"sentence-transformers/all-MiniLM-L12-v2",
"sentence-transformers/all-MiniLM-L12-v2",
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
14 changes: 6 additions & 8 deletions vllm/config.py
@@ -181,7 +181,6 @@ def __init__(self,
self.hf_text_config = get_hf_text_config(self.hf_config)
self.pooling_config = self.get_pooling_config()
self.bert_config = self._get_bert_config()
self.do_lower_case = self._get_bert_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -253,11 +252,8 @@ def _init_multimodal_config(
return None

def _get_bert_config(self):
bert_config = get_sentence_transformer_tokenizer_config(
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
if bert_config is not None:
return bert_config
return None

def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
@@ -422,10 +418,12 @@ def _verify_bnb_config(self) -> None:
"fallback to the eager mode.")
self.enforce_eager = True

def get_pooling_config(self) -> PoolingConfig:
def get_pooling_config(self) -> Optional[PoolingConfig]:
pooling_config = get_pooling_config(self.model, self.revision)
return PoolingConfig(pooling_config["pooling_type"],
pooling_config["normalize"])
if pooling_config is not None:
return PoolingConfig(pooling_config["pooling_type"],
pooling_config["normalize"])
return None

def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
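The net effect of these changes is that ModelConfig.get_pooling_config() now returns Optional[PoolingConfig]: None for models that ship no sentence-transformers pooling configuration, which is what previously broke loading of non-BERT models. A minimal sketch of the caller-side pattern this implies (apply_pooling is a hypothetical stand-in, not part of the commit):

    # Hypothetical consumer of the now-optional pooling configuration.
    if model_config.pooling_config is not None:
        # Only sentence-transformers-style embedding models take this branch.
        apply_pooling(model_config.pooling_config.pooling_type,
                      model_config.pooling_config.normalize)
    # Decoder-only models see pooling_config is None and skip pooling setup.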
47 changes: 17 additions & 30 deletions vllm/engine/llm_engine.py
@@ -256,42 +256,29 @@ def __init__(
"use_async_output_proc=%s, use_cached_outputs=%s, "
"pooling_config_type=%s, normalize=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
model_config.tokenizer,
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
model_config.max_model_len,
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
VLLM_VERSION, model_config.model, speculative_config,
model_config.tokenizer, model_config.skip_tokenizer_init,
model_config.tokenizer_mode, model_config.revision,
model_config.override_neuron_config, model_config.rope_scaling,
model_config.rope_theta, model_config.tokenizer_revision,
model_config.trust_remote_code, model_config.dtype,
model_config.max_model_len, load_config.download_dir,
load_config.load_format, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
observability_config,
model_config.seed,
model_config.served_model_name,
model_config.quantization, model_config.enforce_eager,
cache_config.cache_dtype, model_config.quantization_param_path,
device_config.device, decoding_config, observability_config,
model_config.seed, model_config.served_model_name,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.pooling_config.pooling_type,
model_config.pooling_config.normalize,
model_config.use_async_output_proc, use_cached_outputs,
model_config.pooling_config.pooling_type
if model_config.pooling_config is not None else None,
model_config.pooling_config.normalize
if model_config.pooling_config is not None else None,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs)
# TODO(woosuk): Print more configs in debug mode.
10 changes: 7 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -123,7 +123,8 @@ def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
scheduler_config: Optional[SchedulerConfig] = None,
pooling_config: Optional[PoolingConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}

@@ -145,6 +146,9 @@ def _get_model_initialization_kwargs(
if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config

if pooling_config is not None:
extra_kwargs["pooling_config"] = pooling_config

return extra_kwargs


@@ -159,12 +163,12 @@ def build_model(model_class: Type[nn.Module],
pooling_config: Optional[PoolingConfig] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config)
scheduler_config,
pooling_config)

return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
pooling_config=pooling_config,
**extra_kwargs)
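Moving pooling_config out of the unconditional build_model() keyword arguments and into _get_model_initialization_kwargs() means the keyword is only forwarded when a pooling configuration actually exists, so model classes whose constructors do not accept pooling_config no longer receive an unexpected argument. A simplified sketch of the gating (signatures abbreviated, not the exact vLLM code):

    # Simplified sketch of the conditional kwargs forwarding shown above
    # (model_class and hf_config stand in for the real loader arguments).
    def _init_kwargs(pooling_config=None):
        extra_kwargs = {}
        if pooling_config is not None:
            # Non-embedding models pass pooling_config=None and add no extra keyword.
            extra_kwargs["pooling_config"] = pooling_config
        return extra_kwargs

    # model = model_class(config=hf_config, **_init_kwargs(pooling_config))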


7 changes: 5 additions & 2 deletions vllm/transformers_utils/config.py
@@ -9,13 +9,13 @@
from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError)
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME

from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
# yapf conflicts with isort for this block
@@ -256,7 +256,7 @@ def get_hf_file_to_dict(file_name, model, revision):
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
except (RepositoryNotFoundError, RevisionNotFoundError,
EntryNotFoundError, LocalEntryNotFoundError) as e:
logger.info("File or repository not found in hf_hub_download", e)
logger.debug("File or repository not found in hf_hub_download", e)
return None
file_path = Path(hf_hub_file)

@@ -285,6 +285,9 @@ def get_pooling_config(model, revision='main'):
modules_file_name = "modules.json"
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)

if modules_dict is None:
return None

pooling = next((item for item in modules_dict
if item["type"] == "sentence_transformers.models.Pooling"),
None)
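get_pooling_config() reads the sentence-transformers modules.json; the new early return makes a missing file (i.e. a plain decoder-only checkpoint) yield None instead of crashing on the subsequent next(...) lookup. For context, a typical modules.json for an embedding model looks roughly like the list below (assumed example, not taken from this repository):

    # Assumed shape of a sentence-transformers modules.json, as a Python list:
    modules_dict = [
        {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
        {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
        {"idx": 2, "name": "2", "path": "2_Normalize", "type": "sentence_transformers.models.Normalize"},
    ]
    pooling = next((item for item in modules_dict
                    if item["type"] == "sentence_transformers.models.Pooling"), None)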
