diff --git a/vllm/config.py b/vllm/config.py index f3dcd9f0d7afc..e381eef73da05 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -5,9 +5,9 @@ Mapping, Optional, Set, Tuple, Type, Union) import torch +from transformers import PretrainedConfig import vllm.envs as envs -from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -420,7 +420,8 @@ def _verify_bnb_config(self) -> None: "fallback to the eager mode.") self.enforce_eager = True - def get_pooling_type(self, pooling_type_name: str) -> PoolingType: + def get_pooling_type(self, + pooling_type_name: str) -> Optional[PoolingType]: pooling_types = {i.name: i for i in PoolingType} return pooling_types.get(pooling_type_name) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f4dd98ced80ee..2ee24471dc0a9 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -9,13 +9,13 @@ from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) +from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME -from transformers import GenerationConfig, PretrainedConfig from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.pooler import PoolingType