From 9c326601bf86b93243c9dc9dde4e86b333303838 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 28 Oct 2024 17:42:09 -0300 Subject: [PATCH] Method to treat the pooling name string from file Signed-off-by: Flavia Beo --- vllm/config.py | 12 ++++++++++-- vllm/transformers_utils/config.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4ea783356d4e3..01c4a7af0fc4d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -421,8 +421,11 @@ def _verify_bnb_config(self) -> None: self.enforce_eager = True def get_pooling_type(self, pooling_type_name: str) -> PoolingType: - mapping = {i.name: i for i in PoolingType} - return mapping[pooling_type_name.strip().upper()] + pooling_types = {i.name: i for i in PoolingType} + pooling_type = pooling_types[pooling_type_name] + if pooling_type: + return PoolingType(pooling_type) + return None def get_pooling_config( self, pooling_type_arg: Optional[str], @@ -433,6 +436,11 @@ def get_pooling_config( if pooling_config is not None: pooling_type = pooling_config["pooling_type"] normalize = pooling_config["normalize"] + if not pooling_type_arg and not normalize_arg: + return PoolingConfig( + pooling_type=self.get_pooling_type(pooling_type), + normalize=normalize + ) if pooling_type_arg: pooling_type = pooling_type_arg if normalize_arg: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 662e0374abc43..e2adbb5a3da05 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -317,12 +317,29 @@ def get_pooling_config(model, revision='main', token: Optional[str] = None): token) pooling_type_name = next( (item for item, val in pooling_dict.items() if val is True), None) + + pooling_type_name = get_pooling_config_name(pooling_type_name) return {"pooling_type": pooling_type_name, "normalize": normalize} return None +def get_pooling_config_name(pooling_name: str): + if "pooling_mode_" in pooling_name: + pooling_name = pooling_name.replace("pooling_mode_", "") + + if "_" in pooling_name: + pooling_name = pooling_name.split("_")[0] + + if "lasttoken" in pooling_name: + pooling_name = "last" + + pooling_type_name = pooling_name.upper() + + return pooling_type_name + + def get_sentence_transformer_tokenizer_config(model, revision='main', token: Optional[str] = None):