Skip to content

Commit

Permalink
Adds method to read the pooling types from model's files
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Oct 18, 2024
1 parent 25aeb7d commit 4f3dd00
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 9 deletions.
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class PoolingType(IntEnum):
LAST = 0
ALL = 1
CLS = 2
MEAN = 3
MAX = 4


class Pooler(nn.Module):
Expand Down
10 changes: 9 additions & 1 deletion vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from vllm.transformers_utils.config import get_pooling_config

class BertEmbedding(nn.Module):

Expand Down Expand Up @@ -390,7 +391,8 @@ def __init__(
) -> None:
super().__init__()
self.model = BertModel(config, cache_config, quant_config)
self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
self.pooling_type = get_pooling_type()
self._pooler = Pooler(pooling_type=self.pooling_type, normalize=True)

def forward(
self,
Expand All @@ -417,3 +419,9 @@ def pooler(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Forward the given weights to the wrapped model's ``load_weights``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs, handed to
            ``self.model.load_weights`` unchanged.
    """
    wrapped_model = self.model
    wrapped_model.load_weights(weights)

def get_pooling_type(model=None, revision="main"):
    """Resolve the ``PoolingType`` to use for a model.

    The pooling mode name reported by ``get_pooling_config`` (the key of the
    enabled flag in the model's pooling config) is matched against the
    ``PoolingType`` member names, case-insensitively, by substring.

    Args:
        model: Name or local path of the model whose pooling config should
            be read. When omitted, no lookup is performed and the default
            is returned, so existing zero-argument calls keep working.
        revision: Model revision forwarded to ``get_pooling_config``.

    Returns:
        The first ``PoolingType`` member (in declaration order) whose
        lower-cased name occurs in the pooling mode name, or
        ``PoolingType.CLS`` when no model is given, the model declares no
        pooling mode, or no member name matches.

    NOTE(review): substring matching means a mode name that merely contains
    a member name (e.g. a "weighted mean" style mode containing "mean") is
    mapped to that member -- confirm this is the intended mapping.
    """
    # CLS is the pooling type this model hard-coded before the pooling mode
    # was read from the model's files, so it is the safe fallback.
    default_type = PoolingType.CLS
    if model is None:
        return default_type

    mode_name = get_pooling_config(model, revision)
    if not mode_name:
        return default_type

    mode_name = mode_name.lower()
    # ``__members__`` holds only real enum members; the class ``__dict__``
    # also carries dunder/private entries that must not be matched.
    for member_name, member in PoolingType.__members__.items():
        if member_name.lower() in mode_name:
            return member
    return default_type
58 changes: 50 additions & 8 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,63 @@ def get_config(

return config

def get_hf_file_to_dict(file_name, model, revision):
    """
    Loads a JSON file that belongs to a model and returns its parsed contents.

    The file is first looked up locally at ``Path(model) / file_name``. Only
    when no such file exists is it fetched with ``hf_hub_download``.

    Parameters:
    - file_name (str): Path of the file relative to the model root.
    - model (str): A local model directory, or the name of the model on the
      Hugging Face Hub.
    - revision (str): The specific version of the model; only used when the
      file has to be downloaded.
    Returns:
    - config_dict: The parsed JSON content. The top-level JSON value is
      returned as-is (a dictionary for JSON-object files).
    Raises:
    - json.JSONDecodeError: If the file does not contain valid JSON.
    - Whatever ``hf_hub_download`` raises when the file cannot be fetched.
    """
    file_path = Path(model) / file_name

    if not file_path.is_file():
        file_path = Path(
            hf_hub_download(model, file_name, revision=revision))

    # JSON is UTF-8 by specification; do not rely on the locale's default
    # encoding, which differs between platforms.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)

def get_pooling_config(model, revision='main'):
    """
    This function gets the pooling mode declared in a model's files.

    It reads the model's ``modules.json`` and, if one of its entries has the
    path ``1_Pooling``, reads ``1_Pooling/config.json`` and returns the name
    of the first key whose value is the boolean ``true``.

    Args:
        model (str): The name of the Hugging Face model, or a local model
            directory.
        revision (str): The specific version of the model.
            Default value is 'main'.
    Returns:
        str: The name of the enabled pooling flag as written in the pooling
            config, or None if the model has no ``1_Pooling`` module or no
            flag is enabled.
    """
    modules_dict = get_hf_file_to_dict("modules.json", model, revision)

    # ``.get`` keeps a module entry without a "path" key from aborting the
    # whole lookup.
    has_pooling_module = any(
        item.get("path") == "1_Pooling" for item in modules_dict)
    if not has_pooling_module:
        return None

    pooling_dict = get_hf_file_to_dict("1_Pooling/config.json", model,
                                       revision)

    # Identity check on purpose: ``== True`` would also accept a numeric
    # setting whose value happens to be 1, which is not an enabled flag.
    return next(
        (name for name, val in pooling_dict.items() if val is True), None)


def load_params_config(model, revision) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format

config_file_name = "params.json"

config_path = Path(model) / config_file_name

if not config_path.is_file():
config_path = Path(
hf_hub_download(model, config_file_name, revision=revision))

with open(config_path, "r") as file:
config_dict = json.load(file)
config_dict = get_hf_file_to_dict(config_file_name, model, revision)

config_mapping = {
"dim": "hidden_size",
Expand Down

0 comments on commit 4f3dd00

Please sign in to comment.