Skip to content

Commit

Permalink
Method to treat the pooling name string from file
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Oct 28, 2024
1 parent 2cd2450 commit 9c32660
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
12 changes: 10 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,11 @@ def _verify_bnb_config(self) -> None:
self.enforce_eager = True

def get_pooling_type(self, pooling_type_name: str) -> PoolingType:
mapping = {i.name: i for i in PoolingType}
return mapping[pooling_type_name.strip().upper()]
pooling_types = {i.name: i for i in PoolingType}
pooling_type = pooling_types[pooling_type_name]
if pooling_type:
return PoolingType(pooling_type)
return None

def get_pooling_config(
self, pooling_type_arg: Optional[str],
Expand All @@ -433,6 +436,11 @@ def get_pooling_config(
if pooling_config is not None:
pooling_type = pooling_config["pooling_type"]
normalize = pooling_config["normalize"]
if not pooling_type_arg and not normalize_arg:
return PoolingConfig(
pooling_type=self.get_pooling_type(pooling_type),
normalize=normalize
)
if pooling_type_arg:
pooling_type = pooling_type_arg
if normalize_arg:
Expand Down
17 changes: 17 additions & 0 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,29 @@ def get_pooling_config(model, revision='main', token: Optional[str] = None):
token)
pooling_type_name = next(
(item for item, val in pooling_dict.items() if val is True), None)

pooling_type_name = get_pooling_config_name(pooling_type_name)

return {"pooling_type": pooling_type_name, "normalize": normalize}

return None


def get_pooling_config_name(pooling_name: str):
if "pooling_mode_" in pooling_name:
pooling_name = pooling_name.replace("pooling_mode_", "")

if "_" in pooling_name:
pooling_name = pooling_name.split("_")[0]

if "lasttoken" in pooling_name:
pooling_name = "last"

pooling_type_name = pooling_name.upper()

return pooling_type_name


def get_sentence_transformer_tokenizer_config(model,
revision='main',
token: Optional[str] = None):
Expand Down

0 comments on commit 9c32660

Please sign in to comment.