Skip to content

Commit

Permalink
Moves logic to ModelConfig
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Oct 18, 2024
1 parent 3a37e64 commit 80f5874
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
13 changes: 13 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_pooling_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_neuron, is_openvino, is_xpu,
Expand Down Expand Up @@ -163,6 +165,7 @@ def __init__(self,
code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.pooling_type = self.get_pooling_type()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
Expand Down Expand Up @@ -362,6 +365,16 @@ def _verify_bnb_config(self) -> None:
"fallback to the eager mode.")
self.enforce_eager = True

def get_pooling_type(self) -> PoolingType:
    """Resolve the pooling strategy configured for this model.

    Fetches the pooling configuration for ``self.model`` at
    ``self.revision`` through ``get_pooling_config`` and returns the
    first ``PoolingType`` member whose lower-cased name passes an ``in``
    test against that configuration.

    Returns:
        The matching ``PoolingType`` member.

    Raises:
        ValueError: If no ``PoolingType`` member name is found in the
            pooling configuration.
    """
    pooling_type_name = get_pooling_config(self.model, self.revision)
    # NOTE(review): assumes get_pooling_config never returns None here;
    # if it can, the ``in`` test below raises TypeError - confirm.

    # Walk the enum's declared members instead of the raw class
    # ``__dict__``, which also carries dunder/private attributes that are
    # not pooling types.
    for member_name, member in PoolingType.__members__.items():
        if member_name.lower() in pooling_type_name:
            return member

    # Fail with an explicit message rather than feeding ``None`` into the
    # enum constructor, which reports "None is not a valid PoolingType".
    supported = ", ".join(PoolingType.__members__)
    raise ValueError(
        f"Unsupported pooling configuration {pooling_type_name!r} for "
        f"model {self.model!r}; expected one of: {supported}")

def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
Expand Down
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)",
"mm_processor_kwargs=%s, pooling_type=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
Expand Down Expand Up @@ -287,6 +287,7 @@ def __init__(
model_config.use_async_output_proc,
use_cached_outputs,
model_config.mm_processor_kwargs,
model_config.pooling_type
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
Expand Down
19 changes: 6 additions & 13 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig
from vllm.config import CacheConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
Expand All @@ -21,7 +21,6 @@
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from vllm.transformers_utils.config import get_pooling_config

class BertEmbedding(nn.Module):

Expand Down Expand Up @@ -391,8 +390,11 @@ def __init__(
) -> None:
super().__init__()
self.model = BertModel(config, cache_config, quant_config)
self.pooling_type = self.get_pooling_type()
self._pooler = Pooler(pooling_type=self.pooling_type, normalize=True)
print("====================================================")
print(cache_config.__dict__)
# self.pooling_type = ModelConfig.get_pooling_type(config)
# self._pooler = Pooler(pooling_type=self.pooling_type, normalize=True)
self._pooler = Pooler(PoolingType.CLS, normalize=True)

def forward(
self,
Expand All @@ -419,12 +421,3 @@ def pooler(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)

def get_pooling_type(self):
    """Resolve the pooling strategy for this embedding model.

    Asks ``get_pooling_config`` for the pooling configuration and returns
    the first ``PoolingType`` member whose lower-cased name passes an
    ``in`` test against it.

    Returns:
        The matching ``PoolingType`` member.

    Raises:
        ValueError: If no ``PoolingType`` member name is found in the
            pooling configuration.
    """
    # NOTE(review): ``self.model`` is assigned a BertModel instance in
    # __init__, not a model name/path - confirm that get_pooling_config
    # accepts it.
    pooling_type_name = get_pooling_config(self.model)

    # Walk the enum's declared members instead of the raw class
    # ``__dict__``, which also carries dunder/private attributes that are
    # not pooling types.
    for member_name, member in PoolingType.__members__.items():
        if member_name.lower() in pooling_type_name:
            return member

    # Fail with an explicit message rather than feeding ``None`` into the
    # enum constructor, which reports "None is not a valid PoolingType".
    supported = ", ".join(PoolingType.__members__)
    raise ValueError(
        f"Unsupported pooling configuration {pooling_type_name!r}; "
        f"expected one of: {supported}")

0 comments on commit 80f5874

Please sign in to comment.