Skip to content

Commit

Permalink
Update pooling strategy for embedding models (#10536)
Browse files Browse the repository at this point in the history
Update pooling strategy for embedding models
  • Loading branch information
ravi03071991 authored Feb 8, 2024
1 parent c5daa1d commit 22ef01d
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 15 deletions.
24 changes: 13 additions & 11 deletions llama_index/embeddings/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, List, Optional, Sequence

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
Expand All @@ -12,6 +12,7 @@
DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
format_query,
format_text,
get_pooling_mode,
)
from llama_index.embeddings.pooling import Pooling
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
Expand All @@ -28,7 +29,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
max_length: int = Field(
default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
)
pooling: Pooling = Field(default=Pooling.CLS, description="Pooling strategy.")
pooling: Pooling = Field(default=None, description="Pooling strategy.")
normalize: bool = Field(default=True, description="Normalize embeddings or not.")
query_instruction: Optional[str] = Field(
description="Instruction to prepend to query text."
Expand All @@ -48,7 +49,7 @@ def __init__(
self,
model_name: Optional[str] = None,
tokenizer_name: Optional[str] = None,
pooling: Union[str, Pooling] = "cls",
pooling: Optional[str] = None,
max_length: Optional[int] = None,
query_instruction: Optional[str] = None,
text_instruction: Optional[str] = None,
Expand Down Expand Up @@ -105,14 +106,15 @@ def __init__(
"Unable to find max_length from model config. Please specify max_length."
) from exc

if isinstance(pooling, str):
try:
pooling = Pooling(pooling)
except ValueError as exc:
raise NotImplementedError(
f"Pooling {pooling} unsupported, please pick one in"
f" {[p.value for p in Pooling]}."
) from exc
if not pooling:
pooling = get_pooling_mode(model_name)
try:
pooling = Pooling(pooling)
except ValueError as exc:
raise NotImplementedError(
f"Pooling {pooling} unsupported, please pick one in"
f" {[p.value for p in Pooling]}."
) from exc

super().__init__(
embed_batch_size=embed_batch_size,
Expand Down
20 changes: 16 additions & 4 deletions llama_index/embeddings/huggingface_optimum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
from llama_index.embeddings.huggingface_utils import format_query, format_text
from llama_index.embeddings.huggingface_utils import (
format_query,
format_text,
get_pooling_mode,
)
from llama_index.embeddings.pooling import Pooling
from llama_index.utils import infer_torch_device


Expand All @@ -29,7 +34,7 @@ class OptimumEmbedding(BaseEmbedding):
def __init__(
self,
folder_name: str,
pooling: str = "cls",
pooling: Optional[str] = None,
max_length: Optional[int] = None,
normalize: bool = True,
query_instruction: Optional[str] = None,
Expand Down Expand Up @@ -63,8 +68,15 @@ def __init__(
"Please provide max_length."
)

if pooling not in ["cls", "mean"]:
raise ValueError(f"Pooling {pooling} not supported.")
if not pooling:
pooling = get_pooling_mode(model)
try:
pooling = Pooling(pooling)
except ValueError as exc:
raise NotImplementedError(
f"Pooling {pooling} unsupported, please pick one in"
f" {[p.value for p in Pooling]}."
) from exc

super().__init__(
embed_batch_size=embed_batch_size,
Expand Down
25 changes: 25 additions & 0 deletions llama_index/embeddings/huggingface_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional

import requests

DEFAULT_HUGGINGFACE_EMBEDDING_MODEL = "BAAI/bge-small-en"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-base"

Expand Down Expand Up @@ -72,3 +74,26 @@ def format_text(
# NOTE: strip() enables backdoor for defeating instruction prepend by
# passing empty string
return f"{instruction} {text}".strip()


def get_pooling_mode(model_name: Optional[str]) -> str:
    """Infer the pooling mode ("mean" or "cls") for a HuggingFace model.

    Fetches the sentence-transformers ``1_Pooling/config.json`` from the
    HuggingFace Hub and inspects its flags. "mean" wins when both flags are
    set; falls back to "cls" when the config is missing, unreachable, not
    valid JSON, or sets neither flag.

    Args:
        model_name: Hub repo id (e.g. "BAAI/bge-small-en"). May be None,
            in which case the request fails and "cls" is returned.

    Returns:
        Either "mean" or "cls".
    """
    pooling_config_url = (
        f"https://huggingface.co/{model_name}/raw/main/1_Pooling/config.json"
    )

    try:
        # Bound the request so a slow/unreachable Hub cannot hang
        # embedding-model construction indefinitely.
        response = requests.get(pooling_config_url, timeout=10)
        config_data = response.json()

        cls_token = config_data.get("pooling_mode_cls_token", False)
        mean_tokens = config_data.get("pooling_mode_mean_tokens", False)

        if mean_tokens:
            return "mean"
        elif cls_token:
            return "cls"
    # A missing config (404) returns a non-JSON error page, so .json()
    # raises a JSON decode error; on older requests versions that is a
    # plain ValueError, not a RequestException, so catch both and default.
    except (requests.exceptions.RequestException, ValueError):
        print(
            "Warning: Pooling config file not found; pooling mode is defaulted to 'cls'."
        )
    return "cls"

0 comments on commit 22ef01d

Please sign in to comment.