
Commit c3166f1

add pooling_config to models with a Pooler layer
Signed-off-by: Max de Bayser <[email protected]>
maxdebayser authored and flaviabeo committed Oct 28, 2024
1 parent 0b948a4 commit c3166f1
Showing 5 changed files with 47 additions and 12 deletions.
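
All five diffs below apply one pattern: the model constructor gains an optional pooling_config argument, and when it is None the model keeps its previous hard-coded Pooler defaults. A minimal sketch of that fallback follows, using simplified stand-ins for the pooler classes; only the attribute names pooling_type and normalize are taken from the diffs, and the real classes live in vllm.model_executor.layers.pooler.

# Sketch of the fallback pattern this commit repeats in all five models.
# PoolingType and PoolingConfig here are simplified stand-ins, not vllm's
# actual classes; the attribute names match the diffs below.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Tuple


class PoolingType(Enum):
    LAST = auto()  # pool only the last token's hidden state
    ALL = auto()   # keep every token's hidden state (used by the reward model)


@dataclass
class PoolingConfig:
    pooling_type: PoolingType
    normalize: bool


def build_pooler_args(pooling_config: Optional[PoolingConfig],
                      default_type: PoolingType,
                      default_normalize: bool) -> Tuple[PoolingType, bool]:
    # A user-supplied config wins; otherwise fall back to the model's
    # historical defaults (LAST/normalize=True for the embedding models,
    # ALL/normalize=False for the Qwen2 reward model).
    if pooling_config is not None:
        return pooling_config.pooling_type, pooling_config.normalize
    return default_type, default_normalize
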
11 changes: 9 additions & 2 deletions vllm/model_executor/models/gemma2.py
@@ -31,7 +31,8 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -473,12 +474,18 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
 
     def __init__(
         self,
+        pooling_config: Optional[PoolingConfig] = None,
         **kwargs,
     ) -> None:
         super().__init__()
 
         self.model = Gemma2Model(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
11 changes: 9 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -38,7 +38,8 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
@@ -627,12 +628,18 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
 
     def __init__(
         self,
+        pooling_config: Optional[PoolingConfig] = None,
         **kwargs,
     ) -> None:
         super().__init__()
 
         self.model = LlamaModel(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
13 changes: 10 additions & 3 deletions vllm/model_executor/models/llava_next.py
@@ -13,7 +13,8 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -285,7 +286,8 @@ def __init__(self,
                  config: LlavaNextConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 pooling_config: Optional[PoolingConfig] = None) -> None:
         super().__init__()
 
         self.config = config
@@ -306,7 +308,12 @@
 
         # The same model class supports both language generation and embedding
         # because the architecture name is the same
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
13 changes: 10 additions & 3 deletions vllm/model_executor/models/phi3v.py
@@ -30,7 +30,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -525,7 +526,8 @@ def __init__(self,
                  config: PretrainedConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 pooling_config: Optional[PoolingConfig] = None) -> None:
         super().__init__()
 
         self.config = config
@@ -547,7 +549,12 @@
 
         # The same model class supports both language generation and embedding
         # because the architecture name is the same
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                                  normalize=True)
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
11 changes: 9 additions & 2 deletions vllm/model_executor/models/qwen2_rm.py
@@ -14,7 +14,8 @@
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
@@ -64,6 +65,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        pooling_config: Optional[PoolingConfig] = None,
     ) -> None:
         # TODO (@robertgshaw2): see if this can be moved out
         if (cache_config.sliding_window is not None
@@ -93,7 +95,12 @@ def __init__(
             RowParallelLinear(config.hidden_size, 1,
                               quant_config=quant_config),
         )
-        self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
+        if pooling_config is not None:
+            self._pooler = Pooler(pooling_config.pooling_type,
+                                  pooling_config.normalize)
+        else:
+            self._pooler = Pooler(pooling_type=PoolingType.ALL,
+                                  normalize=False)
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

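For illustration, a hypothetical call site showing how the new argument overrides a model's baked-in pooling defaults. This is not part of the commit: how vllm's model loader actually threads pooling_config through, and whether additional constructor arguments are required, are assumptions here.

# Hypothetical usage sketch, not taken from this commit. Assumes PoolingConfig
# accepts these two keyword arguments (the diffs only show attribute access)
# and that Gemma2Model can be built from a HuggingFace config alone.
from transformers import AutoConfig

from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel

hf_config = AutoConfig.from_pretrained("google/gemma-2-2b")  # example checkpoint

# Override the LAST/normalize=True default: keep last-token pooling but
# return unnormalized embeddings.
pooling_config = PoolingConfig(pooling_type=PoolingType.LAST, normalize=False)
model = Gemma2EmbeddingModel(pooling_config=pooling_config, config=hf_config)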