diff --git a/llama_index/embeddings/huggingface.py b/llama_index/embeddings/huggingface.py
index 56da5bb27fbd2..20842d6fdb413 100644
--- a/llama_index/embeddings/huggingface.py
+++ b/llama_index/embeddings/huggingface.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence
 
 from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks import CallbackManager
@@ -12,6 +12,7 @@
     DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
     format_query,
     format_text,
+    get_pooling_mode,
 )
 from llama_index.embeddings.pooling import Pooling
 from llama_index.llms.huggingface import HuggingFaceInferenceAPI
@@ -28,7 +29,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
     max_length: int = Field(
         default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
     )
-    pooling: Pooling = Field(default=Pooling.CLS, description="Pooling strategy.")
+    pooling: Pooling = Field(default=None, description="Pooling strategy.")
     normalize: bool = Field(default=True, description="Normalize embeddings or not.")
     query_instruction: Optional[str] = Field(
         description="Instruction to prepend to query text."
@@ -48,7 +49,7 @@ def __init__(
         self,
         model_name: Optional[str] = None,
         tokenizer_name: Optional[str] = None,
-        pooling: Union[str, Pooling] = "cls",
+        pooling: Optional[str] = None,
         max_length: Optional[int] = None,
         query_instruction: Optional[str] = None,
         text_instruction: Optional[str] = None,
@@ -105,14 +106,15 @@ def __init__(
                     "Unable to find max_length from model config. Please specify max_length."
                 ) from exc
 
-        if isinstance(pooling, str):
-            try:
-                pooling = Pooling(pooling)
-            except ValueError as exc:
-                raise NotImplementedError(
-                    f"Pooling {pooling} unsupported, please pick one in"
-                    f" {[p.value for p in Pooling]}."
-                ) from exc
+        if not pooling:
+            pooling = get_pooling_mode(model_name)
+        try:
+            pooling = Pooling(pooling)
+        except ValueError as exc:
+            raise NotImplementedError(
+                f"Pooling {pooling} unsupported, please pick one in"
+                f" {[p.value for p in Pooling]}."
+            ) from exc
 
         super().__init__(
             embed_batch_size=embed_batch_size,
diff --git a/llama_index/embeddings/huggingface_optimum.py b/llama_index/embeddings/huggingface_optimum.py
index 668dd58386795..6c69e37f97327 100644
--- a/llama_index/embeddings/huggingface_optimum.py
+++ b/llama_index/embeddings/huggingface_optimum.py
@@ -3,7 +3,12 @@
 from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks import CallbackManager
 from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
-from llama_index.embeddings.huggingface_utils import format_query, format_text
+from llama_index.embeddings.huggingface_utils import (
+    format_query,
+    format_text,
+    get_pooling_mode,
+)
+from llama_index.embeddings.pooling import Pooling
 from llama_index.utils import infer_torch_device
 
 
@@ -29,7 +34,7 @@ class OptimumEmbedding(BaseEmbedding):
     def __init__(
         self,
         folder_name: str,
-        pooling: str = "cls",
+        pooling: Optional[str] = None,
         max_length: Optional[int] = None,
         normalize: bool = True,
         query_instruction: Optional[str] = None,
@@ -63,8 +68,15 @@ def __init__(
                     "Please provide max_length."
                 )
 
-        if pooling not in ["cls", "mean"]:
-            raise ValueError(f"Pooling {pooling} not supported.")
+        if not pooling:
+            pooling = get_pooling_mode(model)
+        try:
+            pooling = Pooling(pooling)
+        except ValueError as exc:
+            raise NotImplementedError(
+                f"Pooling {pooling} unsupported, please pick one in"
+                f" {[p.value for p in Pooling]}."
+            ) from exc
 
         super().__init__(
             embed_batch_size=embed_batch_size,
diff --git a/llama_index/embeddings/huggingface_utils.py b/llama_index/embeddings/huggingface_utils.py
index 606bced13b6b6..009aaab7649ba 100644
--- a/llama_index/embeddings/huggingface_utils.py
+++ b/llama_index/embeddings/huggingface_utils.py
@@ -1,5 +1,7 @@
 from typing import Optional
 
+import requests
+
 DEFAULT_HUGGINGFACE_EMBEDDING_MODEL = "BAAI/bge-small-en"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-base"
 
@@ -72,3 +74,26 @@ def format_text(
     # NOTE: strip() enables backdoor for defeating instruction prepend by
     # passing empty string
     return f"{instruction} {text}".strip()
+
+
+def get_pooling_mode(model_name: Optional[str]) -> str:
+    pooling_config_url = (
+        f"https://huggingface.co/{model_name}/raw/main/1_Pooling/config.json"
+    )
+
+    try:
+        response = requests.get(pooling_config_url)
+        config_data = response.json()
+
+        cls_token = config_data.get("pooling_mode_cls_token", False)
+        mean_tokens = config_data.get("pooling_mode_mean_tokens", False)
+
+        if mean_tokens:
+            return "mean"
+        elif cls_token:
+            return "cls"
+    except requests.exceptions.RequestException:
+        print(
+            "Warning: Pooling config file not found; pooling mode is defaulted to 'cls'."
+        )
+    return "cls"
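The change above can be exercised as in the short sketch below. This is an illustrative usage example rather than part of the patch itself; it only uses names introduced or referenced by the diff (get_pooling_mode, HuggingFaceEmbedding, and the default BAAI/bge-small-en model), and it assumes network access to huggingface.co so the 1_Pooling/config.json lookup can succeed.

from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.huggingface_utils import get_pooling_mode

# Resolve the pooling strategy from the model repo's 1_Pooling/config.json;
# returns "mean" or "cls", falling back to "cls" when the file is unreachable.
mode = get_pooling_mode("BAAI/bge-small-en")
print(mode)

# Leaving `pooling` unset now triggers the same auto-detection inside the class.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
print(embed_model.pooling)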