diff --git a/vllm/config.py b/vllm/config.py index e8462421b73fc..0540dc4170584 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -421,7 +421,7 @@ def _verify_bnb_config(self) -> None: self.enforce_eager = True def get_pooling_type(self, - pooling_type_name: str) -> Optional[PoolingType]: + pooling_type_name: str) -> Union[PoolingType, None]: pooling_types = {i.name: i for i in PoolingType} return pooling_types.get(pooling_type_name) @@ -433,8 +433,8 @@ def get_pooling_config( pooling_type = self.get_pooling_type( pooling_config["pooling_type"]) normalize = pooling_config["normalize"] - pooling_config = PoolingConfig( - pooling_type=PoolingType(pooling_type), normalize=normalize) + pooling_config = PoolingConfig(pooling_type=pooling_type, + normalize=normalize) if pooling_type_arg is not None: pooling_config.pooling_type = self.get_pooling_type( pooling_type_arg)