Skip to content

Commit

Permalink
Assert that the correct tokenizer is loaded
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Oct 24, 2024
1 parent 1554834 commit 8407ac7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
17 changes: 14 additions & 3 deletions tests/model_executor/test_model_load_with_params.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os

from vllm.model_executor.layers.pooler import PoolingType

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")


def test_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading with tp>1.
Expand All @@ -18,10 +19,20 @@ def test_model_loading_with_params(vllm_runner):

model_config = model.model.llm_engine.model_config

print(model.model.llm_engine.__dict__.items())
model_tokenizer = model.model.llm_engine.tokenizer

# asserts on the bert model config file
assert model_config.bert_config["max_seq_length"] == 512
assert model_config.bert_config["do_lower_case"]
assert model_config.pooling_config.pooling_type == 2

# asserts on the pooling config files
assert model_config.pooling_config.pooling_type == PoolingType.CLS
assert model_config.pooling_config.normalize

# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
assert model_tokenizer.tokenizer_config["do_lower_case"]
assert model_tokenizer.tokenizer.model_max_length == 512

# assert output
assert output
4 changes: 3 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def __init__(self,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.pooling_config = self.get_pooling_config()
self.bert_config, self.do_lower_case = self._get_bert_config()
self.bert_config = self._get_bert_config()
self.do_lower_case = self._get_bert_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
Expand Down Expand Up @@ -1756,6 +1757,7 @@ def _get_and_verify_max_len(
"max_seq_length",
"seq_len",
]

# Choose the smallest "max_length" from the possible keys.
max_len_key = None
for key in possible_keys:
Expand Down
3 changes: 1 addition & 2 deletions vllm/transformers_utils/tokenizer_group/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def init_tokenizer_from_configs(model_config: ModelConfig,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)

if model_config.bert_config is not None and "do_lower_case"
in model_config.bert_config:
if model_config.bert_config is not None and "do_lower_case" in model_config.bert_config:
init_kwargs["do_lower_case"] = model_config.bert_config["do_lower_case"]

return get_tokenizer_group(parallel_config.tokenizer_pool_config,
Expand Down

0 comments on commit 8407ac7

Please sign in to comment.