From 8dc6ed31825e4857246d9f574d68c910f659452a Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 24 Oct 2024 16:16:16 -0300 Subject: [PATCH] Linting fixes Signed-off-by: Flavia Beo --- .../test_model_load_with_params.py | 1 + vllm/config.py | 6 ++-- vllm/engine/llm_engine.py | 2 +- vllm/model_executor/model_loader/loader.py | 28 +++++++++---------- vllm/transformers_utils/config.py | 6 ++-- .../tokenizer_group/__init__.py | 6 ++-- 6 files changed, 25 insertions(+), 24 deletions(-) diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 6d4fb5e81194b..e3480f0c8f6db 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -6,6 +6,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") REVISION = os.environ.get("REVISION", "main") + def test_model_loading_with_params(vllm_runner): """ Test parameter weight loading with tp>1. diff --git a/vllm/config.py b/vllm/config.py index 67572d3edceb1..88bde618a65d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -112,8 +112,8 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. - pooling_config: pooling and normalize config from the model - only applies - to sentence-transformers models. + pooling_config: pooling and normalize config from the model - + only applies to sentence-transformers models. bert_config: tokenization configuration dictionary for a given Sentence Transformer BERT model. 
mm_processor_kwargs: Arguments to be forwarded to the model's processor @@ -180,7 +180,7 @@ def __init__(self, config_format) self.hf_text_config = get_hf_text_config(self.hf_config) self.pooling_config = self.get_pooling_config() - self.bert_config = self._get_bert_config() + self.bert_config = self._get_bert_config() self.do_lower_case = self._get_bert_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e98d5b6525a2d..e96ad40224233 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -293,7 +293,7 @@ def __init__( model_config.pooling_config.pooling_type, model_config.pooling_config.normalize, model_config.chat_template_text_format, - model_config.mm_processor_kwargs,) + model_config.mm_processor_kwargs) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config self.cache_config = cache_config diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 12a9fa0aee6ee..133654e53c5d4 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -148,17 +148,18 @@ def _get_model_initialization_kwargs( return extra_kwargs -def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, +def build_model(model_class: Type[nn.Module], + hf_config: PretrainedConfig, cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], *, + quant_config: Optional[QuantizationConfig], + *, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], scheduler_config: Optional[SchedulerConfig], pooling_config: Optional[PoolingConfig] = None) -> nn.Module: extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, multimodal_config, - scheduler_config - ) + scheduler_config) return model_class(config=hf_config, cache_config=cache_config, @@ -176,16 +177,15 @@ def 
_initialize_model( """Initialize a model with the given configurations.""" model_class, _ = get_model_architecture(model_config) - return build_model( - model_class, - model_config.hf_config, - cache_config=cache_config, - quant_config=_get_quantization_config(model_config, load_config), - lora_config=lora_config, - multimodal_config=model_config.multimodal_config, - scheduler_config=scheduler_config, - pooling_config=model_config.pooling_config - ) + return build_model(model_class, + model_config.hf_config, + cache_config=cache_config, + quant_config=_get_quantization_config( + model_config, load_config), + lora_config=lora_config, + multimodal_config=model_config.multimodal_config, + scheduler_config=scheduler_config, + pooling_config=model_config.pooling_config) class BaseModelLoader(ABC): diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index efeac9f2f1a01..02ad9bba39ead 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -328,10 +328,8 @@ def get_sentence_transformer_tokenizer_config(model, revision='main'): "sentence_albert_config.json", "sentence_xlm-roberta_config.json", "sentence_xlnet_config.json", - ]: - bert_dict = get_hf_file_to_dict(config_name, - model, - revision) + ]: + bert_dict = get_hf_file_to_dict(config_name, model, revision) if bert_dict: break diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 0b5b704d0a7c7..95ac1d4e6baf7 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -25,8 +25,10 @@ def init_tokenizer_from_configs(model_config: ModelConfig, trust_remote_code=model_config.trust_remote_code, revision=model_config.tokenizer_revision) - if model_config.bert_config is not None and "do_lower_case" in model_config.bert_config: - init_kwargs["do_lower_case"] = model_config.bert_config["do_lower_case"] + if 
(model_config.bert_config is not None + and "do_lower_case" in model_config.bert_config): + init_kwargs["do_lower_case"] = model_config.bert_config[ + "do_lower_case"] return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs)