
Commit

Changes variable name to encoder_config
Signed-off-by: Flavia Beo <[email protected]>
flaviabeo committed Nov 6, 2024
1 parent cc6da45 commit e342f4e
Showing 5 changed files with 17 additions and 17 deletions.
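
For context: encoder_config (formerly bert_config) holds the Sentence Transformers tokenizer settings that vLLM reads from the model repository, typically a file such as sentence_bert_config.json. A minimal sketch of the dict shape this commit renames, with illustrative values rather than anything taken from the diff below:

# Illustrative only: the kind of dict ModelConfig.encoder_config ends up holding
# when the model repository ships a sentence_bert_config.json (values are examples).
encoder_config = {
    "max_seq_length": 512,   # used to cap the derived max model length
    "do_lower_case": True,   # forwarded to the tokenizer init kwargs
}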
tests/model_executor/test_model_load_with_params.py (2 additions, 2 deletions)

@@ -29,8 +29,8 @@ def test_model_loading_with_params(vllm_runner):
         model_tokenizer = model.model.llm_engine.tokenizer

         # asserts on the bert model config file
-        assert model_config.bert_config["max_seq_length"] == 512
-        assert model_config.bert_config["do_lower_case"]
+        assert model_config.encoder_config["max_seq_length"] == 512
+        assert model_config.encoder_config["do_lower_case"]

         # asserts on the pooling config files
         assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
tests/test_config.py (1 addition, 1 deletion)

@@ -168,7 +168,7 @@ def test_get_bert_tokenization_sentence_transformer_config():
         revision=None,
     )

-    bert_bge_model_config = bge_model_config._get_bert_config()
+    bert_bge_model_config = bge_model_config._get_encoder_config()

    assert bert_bge_model_config["max_seq_length"] == 512
    assert bert_bge_model_config["do_lower_case"]
vllm/config.py (6 additions, 6 deletions)

@@ -197,7 +197,7 @@ def __init__(
                                     code_revision, rope_scaling, rope_theta,
                                     config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
-        self.bert_config = self._get_bert_config()
+        self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -231,7 +231,7 @@ def __init__(
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
             spec_target_max_model_len=spec_target_max_model_len,
-            bert_config=self.bert_config)
+            encoder_config=self.encoder_config)
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
@@ -275,7 +275,7 @@ def _init_multimodal_config(

         return None

-    def _get_bert_config(self):
+    def _get_encoder_config(self):
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)

@@ -1808,7 +1808,7 @@ def _get_and_verify_max_len(
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, List[Optional[int]]]],
     spec_target_max_model_len: Optional[int] = None,
-    bert_config: Optional[Any] = None,
+    encoder_config: Optional[Any] = None,
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1891,8 +1891,8 @@
                 "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor

-    if bert_config and "max_seq_length" in bert_config:
-        derived_max_model_len = bert_config["max_seq_length"]
+    if encoder_config and "max_seq_length" in encoder_config:
+        derived_max_model_len = encoder_config["max_seq_length"]

     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
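
The _get_and_verify_max_len change above means that, when the encoder config defines max_seq_length, it replaces whatever maximum length was derived from the Hugging Face config. A standalone sketch of that override, using hypothetical values:

# Hypothetical values; mirrors the override added to _get_and_verify_max_len.
derived_max_model_len = 8192.0  # e.g. a value derived from max_position_embeddings
encoder_config = {"max_seq_length": 512, "do_lower_case": True}

if encoder_config and "max_seq_length" in encoder_config:
    derived_max_model_len = encoder_config["max_seq_length"]

assert derived_max_model_len == 512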
vllm/transformers_utils/config.py (5 additions, 5 deletions)

@@ -377,15 +377,15 @@ def get_sentence_transformer_tokenizer_config(model: str,
            "sentence_xlm-roberta_config.json",
            "sentence_xlnet_config.json",
     ]:
-        bert_dict = get_hf_file_to_dict(config_name, model, revision, token)
-        if bert_dict:
+        encoder_dict = get_hf_file_to_dict(config_name, model, revision, token)
+        if encoder_dict:
             break

-    if not bert_dict:
+    if not encoder_dict:
         return None

-    if all(k in bert_dict for k in ("max_seq_length", "do_lower_case")):
-        return bert_dict
+    if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
+        return encoder_dict
     return None

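
The helper above scans a list of known Sentence Transformers config filenames, keeps the first one it can load as a dict, and returns it only when both max_seq_length and do_lower_case are present. A hedged usage sketch; the model id below is just an example and is not taken from this commit:

# Hypothetical usage of the helper renamed in this commit.
from vllm.transformers_utils.config import (
    get_sentence_transformer_tokenizer_config)

# "BAAI/bge-base-en-v1.5" is only an example embedding model id.
encoder_dict = get_sentence_transformer_tokenizer_config(
    "BAAI/bge-base-en-v1.5", None)
if encoder_dict is not None:
    print(encoder_dict["max_seq_length"], encoder_dict["do_lower_case"])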
vllm/transformers_utils/tokenizer_group/__init__.py (3 additions, 3 deletions)

@@ -25,9 +25,9 @@ def init_tokenizer_from_configs(model_config: ModelConfig,
                       trust_remote_code=model_config.trust_remote_code,
                       revision=model_config.tokenizer_revision)

-    if (model_config.bert_config is not None
-            and "do_lower_case" in model_config.bert_config):
-        init_kwargs["do_lower_case"] = model_config.bert_config[
+    if (model_config.encoder_config is not None
+            and "do_lower_case" in model_config.encoder_config):
+        init_kwargs["do_lower_case"] = model_config.encoder_config[
             "do_lower_case"]

     return get_tokenizer_group(parallel_config.tokenizer_pool_config,
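
The tokenizer-group change simply forwards do_lower_case from the encoder config into the tokenizer init kwargs. A minimal standalone sketch of that forwarding, with illustrative values:

# Illustrative sketch of the forwarding shown above; the kwargs are examples.
init_kwargs = {"tokenizer_mode": "auto", "revision": None}
encoder_config = {"max_seq_length": 512, "do_lower_case": True}

if encoder_config is not None and "do_lower_case" in encoder_config:
    init_kwargs["do_lower_case"] = encoder_config["do_lower_case"]

# init_kwargs now carries do_lower_case=True into get_tokenizer_group(...).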
