From c3166f16419a6e4c2ed6296fc14b1085bdd8cbc8 Mon Sep 17 00:00:00 2001
From: Max de Bayser
Date: Mon, 28 Oct 2024 10:22:42 -0300
Subject: [PATCH] add pooling_config to models with a Pooler layer

Signed-off-by: Max de Bayser
---
 vllm/model_executor/models/gemma2.py     | 11 +++++++++--
 vllm/model_executor/models/llama.py      | 11 +++++++++--
 vllm/model_executor/models/llava_next.py | 13 ++++++++++---
 vllm/model_executor/models/phi3v.py      | 13 ++++++++++---
 vllm/model_executor/models/qwen2_rm.py   | 11 +++++++++--
 5 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index d79248f93f5ae..6e62ef28926fd 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -31,7 +31,8 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -473,12 +474,18 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
 
     def __init__(
         self,
+        pooling_config: Optional[PoolingConfig] = None,
         **kwargs,
     ) -> None:
         super().__init__()
 
         self.model = Gemma2Model(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c346e3e808e3f..c3c992cf2e17f 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -38,7 +38,8 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
@@ -627,12 +628,18 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
 
     def __init__(
        self,
+        pooling_config: Optional[PoolingConfig] = None,
         **kwargs,
     ) -> None:
         super().__init__()
 
         self.model = LlamaModel(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 2a582deeaa2c9..7cb719f5c57aa 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -13,7 +13,8 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -285,7 +286,8 @@ def __init__(self,
                  config: LlavaNextConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 pooling_config: Optional[PoolingConfig] = None) -> None:
         super().__init__()
 
         self.config = config
@@ -306,7 +308,12 @@ def __init__(self,
 
         # The same model class supports both language generation and embedding
         # because the architecture name is the same
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 855a9b17585a4..6e8b323e89fb4 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -30,7 +30,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -525,7 +526,8 @@ def __init__(self,
                  config: PretrainedConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 pooling_config: Optional[PoolingConfig] = None) -> None:
         super().__init__()
 
         self.config = config
@@ -547,7 +549,12 @@ def __init__(self,
 
         # The same model class supports both language generation and embedding
         # because the architecture name is the same
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index ee0eeb9db3808..3493aeffd2062 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -14,7 +14,8 @@
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
@@ -64,6 +65,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        pooling_config: Optional[PoolingConfig] = None,
     ) -> None:
         # TODO (@robertgshaw2): see if this can be moved out
         if (cache_config.sliding_window is not None
@@ -93,7 +95,12 @@ def __init__(
             RowParallelLinear(config.hidden_size, 1,
                               quant_config=quant_config),
         )
-        self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.ALL,
+                                  normalize=False)
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
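
For reviewers: every hunk above applies the same pattern, prefer the pooling_type and
normalize values from an explicitly passed pooling_config, and otherwise keep the
defaults each model used before (LAST with normalization for the embedding models,
ALL without normalization for the reward model). The snippet below is a self-contained
sketch of that fallback, not code from this patch or from vLLM: PoolingType,
PoolingConfig, and Pooler are stand-ins reduced to the two attributes the diff actually
reads, and build_pooler is a hypothetical helper mirroring the added if/else block.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class PoolingType(Enum):
    """Stand-in for the PoolingType enum imported in the diff."""
    LAST = auto()
    ALL = auto()


@dataclass
class PoolingConfig:
    """Minimal shape assumed by the patch: only .pooling_type and
    .normalize are read."""
    pooling_type: PoolingType
    normalize: bool


class Pooler:
    """Stub that accepts the same constructor arguments the patch passes."""

    def __init__(self, pooling_type: PoolingType, normalize: bool) -> None:
        self.pooling_type = pooling_type
        self.normalize = normalize


def build_pooler(pooling_config: Optional[PoolingConfig],
                 default_type: PoolingType,
                 default_normalize: bool) -> Pooler:
    """Mirrors the if/else block added to each model's __init__."""
    if pooling_config is not None:
        return Pooler(pooling_config.pooling_type, pooling_config.normalize)
    return Pooler(pooling_type=default_type, normalize=default_normalize)


# Without a config the previous per-model behavior is preserved, e.g. the
# embedding models keep LAST pooling with normalization:
default_pooler = build_pooler(None, PoolingType.LAST, True)
# With a config, its values take precedence over the model default:
custom_pooler = build_pooler(PoolingConfig(PoolingType.ALL, False),
                             PoolingType.LAST, True)

Keeping the fallback inside each model, rather than centralizing it, means behavior is
unchanged for every caller that does not pass a pooling_config.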