
Adds method to read the pooling types from model's files
Signed-off-by: Flavia Beo <[email protected]>
flaviabeo committed Oct 21, 2024
1 parent 25aeb7d commit 03d72a7
Showing 6 changed files with 118 additions and 14 deletions.
9 changes: 9 additions & 0 deletions vllm/config.py
@@ -10,11 +10,13 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.pooler import PoolingConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
+                                            get_pooling_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                         is_hip, is_neuron, is_openvino, is_xpu,
@@ -163,6 +165,7 @@ def __init__(self,
                          code_revision, rope_scaling, rope_theta,
                          config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.pooling_config = self.get_pooling_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -362,6 +365,12 @@ def _verify_bnb_config(self) -> None:
                 "fallback to the eager mode.")
             self.enforce_eager = True

+    def get_pooling_config(self) -> PoolingConfig:
+        pooling_config = get_pooling_config(self.model,
+                                            self.revision)
+        return PoolingConfig(pooling_config["pooling_type"],
+                             pooling_config["normalize"])
+
     def verify_async_output_proc(self, parallel_config, speculative_config,
                                  device_config) -> None:
         if not self.use_async_output_proc:
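
To illustrate the new hook, a minimal usage sketch (hedged: the checkpoint name is only an example, and a None guard is added here because get_pooling_config returns None for models without a modules.json, even though the method above dereferences the result directly):

from vllm.model_executor.layers.pooler import PoolingConfig
from vllm.transformers_utils.config import get_pooling_config

# Any Sentence Transformers checkpoint that ships a modules.json should
# behave like this example model.
raw = get_pooling_config("sentence-transformers/all-MiniLM-L6-v2")
if raw is not None:  # None when the checkpoint has no Pooling module
    # raw["normalize"] is the raw Normalize module entry (truthy) or False.
    pooling_config = PoolingConfig(raw["pooling_type"], raw["normalize"])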
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
@@ -251,7 +251,7 @@ def __init__(
             "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
-            "mm_processor_kwargs=%s)",
+            "mm_processor_kwargs=%s, pooling_config=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -287,6 +287,7 @@ def __init__(
             model_config.use_async_output_proc,
             use_cached_outputs,
             model_config.mm_processor_kwargs,
+            model_config.pooling_config
         )
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config
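
One note on the %s placeholder added above: PoolingConfig defines no __repr__, so the log line will show the default object repr. A hedged sketch of what such a method could look like (an illustration only, not part of this commit):

# Hypothetical method on PoolingConfig, shown only to make the
# logged value readable; the commit itself does not add it.
def __repr__(self) -> str:
    return (f"PoolingConfig(pooling_type={self.pooling_type!r}, "
            f"normalize={self.normalize!r})")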
25 changes: 25 additions & 0 deletions vllm/model_executor/layers/pooler.py
@@ -13,6 +13,31 @@ class PoolingType(IntEnum):
     LAST = 0
     ALL = 1
     CLS = 2
+    MEAN = 3
+    MAX = 4
+
+
+class PoolingConfig():
+    """A class that configures the pooling operation.
+    Attributes:
+        pooling_type (str): The type of pooling to use.
+        normalize (bool): Whether to normalize the pooled data.
+    Methods:
+        get_pooling_type(pooling_type_name): Returns the pooling
+        type enum value corresponding to the given string.
+    """
+    def __init__(self, pooling_type: str, normalize: bool):
+        self.pooling_type = self.get_pooling_type(pooling_type)
+        self.normalize = normalize
+
+    def get_pooling_type(self, pooling_type_name: str) -> PoolingType:
+        pooling_types = PoolingType.__dict__.items()
+        return PoolingType(next((value for key,
+                                 value in pooling_types if key.lower()
+                                 in pooling_type_name),
+                                2))
+
+
 class Pooler(nn.Module):
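
A quick sketch of how the string-to-enum matching behaves in practice (the input names follow the Sentence Transformers pooling_mode_* keys, and the fallback default 2 is PoolingType.CLS):

from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType

cfg = PoolingConfig("pooling_mode_mean_tokens", normalize=True)
assert cfg.pooling_type is PoolingType.MEAN  # "mean" matches as a substring
assert cfg.normalize is True

# A name matching no enum member falls back to PoolingType(2), i.e. CLS.
assert PoolingConfig("unrecognized", normalize=False).pooling_type is PoolingType.CLS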
12 changes: 9 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -30,6 +30,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.pooler import PoolingConfig
 from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
@@ -122,7 +123,8 @@ def _get_model_initialization_kwargs(
         model_class: Type[nn.Module],
         lora_config: Optional[LoRAConfig],
         multimodal_config: Optional[MultiModalConfig],
-        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
+        scheduler_config: Optional[SchedulerConfig] = None
+) -> Dict[str, Any]:
     """Get extra kwargs for model initialization."""
     extra_kwargs: Dict[str, Any] = {}

@@ -152,14 +154,17 @@ def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig], *,
                 lora_config: Optional[LoRAConfig],
                 multimodal_config: Optional[MultiModalConfig],
-                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
+                scheduler_config: Optional[SchedulerConfig],
+                pooling_config: Optional[PoolingConfig] = None) -> nn.Module:
     extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                     multimodal_config,
-                                                    scheduler_config)
+                                                    scheduler_config
+                                                    )

     return model_class(config=hf_config,
                        cache_config=cache_config,
                        quant_config=quant_config,
+                       pooling_config=pooling_config,
                        **extra_kwargs)

@@ -180,6 +185,7 @@ def _initialize_model(
         lora_config=lora_config,
         multimodal_config=model_config.multimodal_config,
         scheduler_config=scheduler_config,
+        pooling_config=model_config.pooling_config
     )
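
As a sketch of what this plumbing enables, a direct build_model call can now thread the checkpoint's pooling settings down to the model constructor (placeholder arguments; in practice _initialize_model passes pooling_config internally, as shown above):

# Hypothetical invocation; variable names are placeholders.
model = build_model(
    model_class,        # e.g. BertEmbeddingModel
    hf_config,
    cache_config,
    quant_config,
    lora_config=None,
    multimodal_config=None,
    scheduler_config=None,
    pooling_config=model_config.pooling_config,  # new keyword in this commit
)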
8 changes: 6 additions & 2 deletions vllm/model_executor/models/bert.py
@@ -12,7 +12,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import Pooler, PoolingConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -387,10 +387,14 @@ def __init__(
         config: BertConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        pooling_config: Optional[PoolingConfig] = None
     ) -> None:
         super().__init__()
         self.model = BertModel(config, cache_config, quant_config)
-        self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        print(pooling_config.pooling_type)
+        print(pooling_config.normalize)
+        self._pooler = Pooler(pooling_config.pooling_type,
+                              pooling_config.normalize)

     def forward(
         self,
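
Note that pooling_config defaults to None in the signature above but is dereferenced unconditionally, including by the two print calls. A hedged sketch of a defensive fallback, shown for illustration only (not part of the commit), that keeps the previous hard-coded behavior:

from typing import Optional
from vllm.model_executor.layers.pooler import Pooler, PoolingConfig, PoolingType

def make_pooler(pooling_config: Optional[PoolingConfig]) -> Pooler:
    # Fall back to the old default (CLS pooling, normalized outputs)
    # when the checkpoint carries no pooling configuration.
    if pooling_config is None:
        return Pooler(pooling_type=PoolingType.CLS, normalize=True)
    return Pooler(pooling_config.pooling_type, pooling_config.normalize)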
75 changes: 67 additions & 8 deletions vllm/transformers_utils/config.py
@@ -231,21 +231,80 @@ def get_config(

     return config

+def get_hf_file_to_dict(file_name, model, revision):
+    """
+    Downloads a file from the Hugging Face Hub and returns
+    its contents as a dictionary.
+    Parameters:
+    - file_name (str): The name of the file to download.
+    - model (str): The name of the model on the Hugging Face Hub.
+    - revision (str): The specific version of the model.
+    Returns:
+    - config_dict (dict): A dictionary containing
+    the contents of the downloaded file.
+    """
+    file_path = Path(model) / file_name
+
+    if not file_path.is_file():
+        file_path = Path(
+            hf_hub_download(model, file_name, revision=revision))
+
+    with open(file_path, "r") as file:
+        config_dict = json.load(file)
+
+    return config_dict
+
+def get_pooling_config(model, revision='main'):
+    """
+    This function gets the pooling and normalize
+    config from the model.
+    Args:
+        model (str): The name of the Hugging Face model.
+        revision (str, optional): The specific version
+        of the model to use. Defaults to 'main'.
+    Returns:
+        dict: A dictionary containing the pooling
+        type and whether normalization is used.
+    """
+
+    modules_file_name = "modules.json"
+    modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)
+
+    pooling = next((item for item in modules_dict if
+                    item["type"] == "sentence_transformers.models.Pooling"),
+                   None)
+    normalize = next((item for item in modules_dict if
+                      item["type"] ==
+                      "sentence_transformers.models.Normalize"),
+                     False)
+
+    if pooling:
+
+        pooling_file_name = "{}/config.json".format(pooling["path"])
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
+        pooling_type_name = next((item for item,
+                                  val in pooling_dict.items() if val is True),
+                                 None)
+
+        return {
+            "pooling_type": pooling_type_name,
+            "normalize": normalize
+        }
+
+    return None
+
+
 def load_params_config(model, revision) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format

     config_file_name = "params.json"

-    config_path = Path(model) / config_file_name
-
-    if not config_path.is_file():
-        config_path = Path(
-            hf_hub_download(model, config_file_name, revision=revision))
-
-    with open(config_path, "r") as file:
-        config_dict = json.load(file)
+    config_dict = get_hf_file_to_dict(config_file_name, model, revision)

     config_mapping = {
         "dim": "hidden_size",
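
For context, a minimal sketch of the Hub files these helpers read, using the conventional Sentence Transformers layout (the "1_Pooling"/"2_Normalize" paths and the 384-dim value are illustrative, not guaranteed for every model):

# modules.json (abbreviated): an ordered list of module descriptors.
modules = [
    {"idx": 0, "name": "0", "path": "",
     "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling",
     "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "2_Normalize",
     "type": "sentence_transformers.models.Normalize"},
]

# 1_Pooling/config.json: exactly one pooling_mode_* flag is typically True.
pooling = {
    "word_embedding_dimension": 384,
    "pooling_mode_cls_token": False,
    "pooling_mode_mean_tokens": True,
    "pooling_mode_max_tokens": False,
}

# For files like these, get_pooling_config(model) returns
# {"pooling_type": "pooling_mode_mean_tokens", "normalize": <Normalize entry>}.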
