Skip to content

Commit

Permalink
Linting fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Oct 24, 2024
1 parent 8407ac7 commit 8dc6ed3
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 24 deletions.
1 change: 1 addition & 0 deletions tests/model_executor/test_model_load_with_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")


def test_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading with tp>1.
Expand Down
6 changes: 3 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
pooling_config: pooling and normalize config from the model - only applies
to sentence-transformers models.
pooling_config: pooling and normalize config from the model -
only applies to sentence-transformers models.
bert_config: tokenization configuration dictionary for a given
Sentence Transformer BERT model.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
Expand Down Expand Up @@ -180,7 +180,7 @@ def __init__(self,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.pooling_config = self.get_pooling_config()
self.bert_config = self._get_bert_config()
self.bert_config = self._get_bert_config()
self.do_lower_case = self._get_bert_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __init__(
model_config.pooling_config.pooling_type,
model_config.pooling_config.normalize,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,)
model_config.mm_processor_kwargs)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
Expand Down
28 changes: 14 additions & 14 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,18 @@ def _get_model_initialization_kwargs(
return extra_kwargs


def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
def build_model(model_class: Type[nn.Module],
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig], *,
quant_config: Optional[QuantizationConfig],
*,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig],
pooling_config: Optional[PoolingConfig] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config
)
scheduler_config)

return model_class(config=hf_config,
cache_config=cache_config,
Expand All @@ -176,16 +177,15 @@ def _initialize_model(
"""Initialize a model with the given configurations."""
model_class, _ = get_model_architecture(model_config)

return build_model(
model_class,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config),
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
pooling_config=model_config.pooling_config
)
return build_model(model_class,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(
model_config, load_config),
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
pooling_config=model_config.pooling_config)


class BaseModelLoader(ABC):
Expand Down
6 changes: 2 additions & 4 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,8 @@ def get_sentence_transformer_tokenizer_config(model, revision='main'):
"sentence_albert_config.json",
"sentence_xlm-roberta_config.json",
"sentence_xlnet_config.json",
]:
bert_dict = get_hf_file_to_dict(config_name,
model,
revision)
]:
bert_dict = get_hf_file_to_dict(config_name, model, revision)
if bert_dict:
break

Expand Down
6 changes: 4 additions & 2 deletions vllm/transformers_utils/tokenizer_group/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ def init_tokenizer_from_configs(model_config: ModelConfig,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)

if model_config.bert_config is not None and "do_lower_case" in model_config.bert_config:
init_kwargs["do_lower_case"] = model_config.bert_config["do_lower_case"]
if (model_config.bert_config is not None
and "do_lower_case" in model_config.bert_config):
init_kwargs["do_lower_case"] = model_config.bert_config[
"do_lower_case"]

return get_tokenizer_group(parallel_config.tokenizer_pool_config,
**init_kwargs)
Expand Down

0 comments on commit 8dc6ed3

Please sign in to comment.