diff --git a/vllm/config.py b/vllm/config.py index 4533fb017188c..467b3683e6991 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,11 +10,13 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, + get_pooling_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, is_hip, is_neuron, is_openvino, is_xpu, @@ -163,6 +165,7 @@ def __init__(self, code_revision, rope_scaling, rope_theta, config_format) self.hf_text_config = get_hf_text_config(self.hf_config) + self.pooling_type = self.get_pooling_type() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -362,6 +365,16 @@ def _verify_bnb_config(self) -> None: "fallback to the eager mode.") self.enforce_eager = True + def get_pooling_type(self) -> PoolingType: + pooling_type_name = get_pooling_config(self.model, + self.revision) + pooling_types = PoolingType.__dict__.items() + pooling_type = next((value for key, + value in pooling_types if key.lower() + in pooling_type_name), + None) + return PoolingType(pooling_type) + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 61c21887e6816..a0cbfdd8d6911 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -251,7 +251,7 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s)", + "mm_processor_kwargs=%s, pooling_type=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -287,6 +287,7 @@ def __init__( model_config.use_async_output_proc, use_cached_outputs, model_config.mm_processor_kwargs, + model_config.pooling_type ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index f5f1d47a4c27c..89295ec26c1ac 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -6,7 +6,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.backends.xformers import XFormersImpl -from vllm.config import CacheConfig +from vllm.config import CacheConfig, ModelConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -21,7 +21,6 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from vllm.transformers_utils.config import get_pooling_config class BertEmbedding(nn.Module): @@ -391,8 +390,11 @@ def __init__( ) -> None: super().__init__() self.model = BertModel(config, cache_config, quant_config) - self.pooling_type = self.get_pooling_type() - self._pooler = Pooler(pooling_type=self.pooling_type, normalize=True) + print("====================================================") + print(cache_config.__dict__) + # self.pooling_type = ModelConfig.get_pooling_type(config) + # self._pooler = Pooler(pooling_type=self.pooling_type, normalize=True) + self._pooler = Pooler(PoolingType.CLS, normalize=True) def forward( self, @@ -419,12 +421,3 @@ def pooler( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.model.load_weights(weights) - - def get_pooling_type(self): - pooling_type_name = get_pooling_config(self.model) - pooling_types = PoolingType.__dict__.items() - pooling_type = next((value for key, - value in pooling_types if key.lower() - in pooling_type_name), - None) - return PoolingType(pooling_type) \ No newline at end of file