diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d7cb111742836..221bb77434868 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -40,7 +40,7 @@ def get_pooling_type(self, pooling_type_name: str) -> PoolingType: pooling_types = PoolingType.__dict__.items() return PoolingType( next((value for key, value in pooling_types - if key.lower() in pooling_type_name), 2)) + if key.lower() in pooling_type_name), PoolingType.CLS)) class Pooler(nn.Module):