Skip to content

Commit

Permalink
feat: nomic embed v1.5 hf embeddings (#10762)
Browse files Browse the repository at this point in the history
  • Loading branch information
zanussbaum authored Feb 15, 2024
1 parent b06c888 commit b5e96d4
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from enum import Enum
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.base.embeddings.base import (
BaseEmbedding,
DEFAULT_EMBED_BATCH_SIZE,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager

from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.embeddings.huggingface.pooling import Pooling
import torch
import logging

DEFAULT_HUGGINGFACE_LENGTH = 512
logger = logging.getLogger(__name__)


class NomicAITaskType(str, Enum):
SEARCH_QUERY = "search_query"
Expand Down Expand Up @@ -109,3 +122,106 @@ async def _aget_text_embedding(self, text: str) -> List[float]:
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
return self._embed(texts, task_type=self.document_task_type)


class NomicHFEmbedding(HuggingFaceEmbedding):
tokenizer_name: str = Field(description="Tokenizer name from HuggingFace.")
max_length: int = Field(
default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
)
pooling: Pooling = Field(default=Pooling.MEAN, description="Pooling strategy.")
normalize: bool = Field(default=True, description="Normalize embeddings or not.")
query_instruction: Optional[str] = Field(
description="Instruction to prepend to query text."
)
text_instruction: Optional[str] = Field(
description="Instruction to prepend to text."
)
cache_folder: Optional[str] = Field(
description="Cache folder for huggingface files."
)
dimensionality: Optional[int] = Field(description="Dimensionality of embedding")

_model: Any = PrivateAttr()
_tokenizer: Any = PrivateAttr()
_device: str = PrivateAttr()

def __init__(
self,
model_name: Optional[str] = None,
tokenizer_name: Optional[str] = None,
pooling: Union[str, Pooling] = "cls",
max_length: Optional[int] = None,
query_instruction: Optional[str] = None,
text_instruction: Optional[str] = None,
normalize: bool = True,
model: Optional[Any] = None,
tokenizer: Optional[Any] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
cache_folder: Optional[str] = None,
trust_remote_code: bool = False,
device: Optional[str] = None,
callback_manager: Optional[CallbackManager] = None,
dimensionality: int = 768,
):
super().__init__(
model_name=model_name,
tokenizer_name=tokenizer_name,
pooling=pooling,
max_length=max_length,
query_instruction=query_instruction,
text_instruction=text_instruction,
normalize=normalize,
model=model,
tokenizer=tokenizer,
embed_batch_size=embed_batch_size,
cache_folder=cache_folder,
trust_remote_code=trust_remote_code,
device=device,
callback_manager=callback_manager,
)
self.dimensionality = dimensionality
self._model.eval()

def _embed(self, sentences: List[str]) -> List[List[float]]:
"""Embed sentences."""
encoded_input = self._tokenizer(
sentences,
padding=True,
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)

# pop token_type_ids
encoded_input.pop("token_type_ids", None)

# move tokenizer inputs to device
encoded_input = {
key: val.to(self._device) for key, val in encoded_input.items()
}

with torch.no_grad():
model_output = self._model(**encoded_input)

if self.pooling == Pooling.CLS:
context_layer: "torch.Tensor" = model_output[0]
embeddings = self.pooling.cls_pooling(context_layer)
else:
embeddings = self._mean_pooling(
token_embeddings=model_output[0],
attention_mask=encoded_input["attention_mask"],
)

if self.normalize:
import torch.nn.functional as F

if self.model_name == "nomic-ai/nomic-embed-text-v1.5":
emb_ln = F.layer_norm(
embeddings, normalized_shape=(embeddings.shape[1],)
)
embeddings = emb_ln[:, : self.dimensionality]

embeddings = F.normalize(embeddings, p=2, dim=1)

return embeddings.tolist()
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ version = "0.1.1"
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12"
llama-index-core = "^0.10.1"
llama-index-embeddings-huggingface = "^0.1.0"
einops = "^0.7.0"
nomic = "^3.0.12"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
Expand Down

0 comments on commit b5e96d4

Please sign in to comment.