-
Notifications
You must be signed in to change notification settings - Fork 109
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- cohere embeddings support - added to the registry for autoembeddings - added tests
- Loading branch information
Showing
9 changed files
with
398 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
import importlib | ||
import os | ||
import warnings | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
|
||
from .base import BaseEmbeddings | ||
|
||
class CohereEmbeddings(BaseEmbeddings): | ||
"""Cohere embeddings implementation using their API""" | ||
|
||
AVAILABLE_MODELS = { | ||
# cohere v3.0 models | ||
"embed-english-v3.0": (True, 1024), # tokenizer from tokenizers | ||
"embed-multilingual-v3.0": (False, 1024), # not listed in the cohere models api list | ||
"embed-english-light-v3.0": (True, 384), # from tokenizers | ||
"embed-multilingual-light-v3.0": (False, 384), # not listed in the cohere models api list | ||
# cohere v2.0 models | ||
"embed-english-v2.0": (False, 4096), # url is not available in the cohere models api list | ||
"embed-english-light-v2.0": (False, 1024), # not listed in the models list | ||
"embed-multilingual-v2.0": (True, 768) # from tokenizers | ||
} | ||
|
||
DEFAULT_MODEL = "embed-english-light-v3.0" | ||
TOKENIZER_BASE_URL = "https://storage.googleapis.com/cohere-public/tokenizers/" | ||
|
||
def __init__( | ||
self, | ||
model: str = DEFAULT_MODEL, | ||
api_key: Optional[str] = None, | ||
client_name: Optional[str] = None, | ||
max_retries: int = 3, | ||
timeout: float = 60.0, | ||
batch_size: int = 128, | ||
show_warnings: bool = True | ||
): | ||
"""Initialize Cohere embeddings. | ||
Args: | ||
model: name of the Cohere embedding model to use | ||
api_key: (optional) Cohere API key (if not provided, looks for COHERE_API_KEY environment variable) | ||
organization: (optional) client name for API requests | ||
max_retries: maximum number of retries for failed requests | ||
timeout: timeout in seconds for API requests | ||
batch_size: maximum number of texts to embed in one API call (maximum allowed by Cohere is 96) | ||
show_warnings: whether to show warnings about token usage and truncation | ||
""" | ||
|
||
super().__init__() | ||
if not self.is_available(): | ||
raise ImportError( | ||
"Cohere package is not available. Please install it via pip." | ||
) | ||
else: | ||
global cohere | ||
import tokenizers | ||
import requests | ||
from cohere import ClientV2 # using v2 | ||
|
||
if model not in self.AVAILABLE_MODELS: | ||
raise ValueError( | ||
f"Model {model} is not available. Choose from: {list(self.AVAILABLE_MODELS.keys())}" | ||
) | ||
|
||
self.model = model | ||
self._dimension = self.AVAILABLE_MODELS[model][1] | ||
tokenizer_url = self.TOKENIZER_BASE_URL + (model if self.AVAILABLE_MODELS[model][0] else self.DEFAULT_MODEL) + ".json" | ||
response = requests.get(tokenizer_url) | ||
self._tokenizer = tokenizers.Tokenizer.from_str(response.text) | ||
self._batch_size = min(batch_size, 96) # max batch size for cohere is 96 | ||
self._show_warnings = show_warnings | ||
self._max_retries = max_retries | ||
self._api_key = api_key or os.getenv("COHERE_API_KEY") | ||
|
||
if self._api_key is None: | ||
raise ValueError( | ||
"Cohere API key not found. Either pass it as api_key or set COHERE_API_KEY environment variable." | ||
) | ||
|
||
# setup Cohere client | ||
self.client = ClientV2( | ||
api_key=api_key or os.getenv("COHERE_API_KEY"), | ||
client_name=client_name, | ||
timeout=timeout | ||
) | ||
|
||
|
||
def embed(self, text: str) -> np.ndarray: | ||
"""Generate embeddings for a single text""" | ||
token_count = self.count_tokens(text) | ||
if token_count > 512 and self._show_warnings: # Cohere models max_context_length | ||
warnings.warn( | ||
f"Text has {token_count} tokens which exceeds the model's context length of 512." | ||
"Generation may not be optimal" | ||
) | ||
|
||
for _ in range(self._max_retries): | ||
try: | ||
response = self.client.embed( | ||
model=self.model, | ||
input_type="search_document", | ||
embedding_types=["float"], | ||
texts=[text] | ||
) | ||
|
||
return np.array(response.embeddings.float_[0], dtype=np.float32) | ||
except Exception as e: | ||
if self._show_warnings: | ||
warnings.warn( | ||
f"There was an exception while generating embeddings. Exception: {str(e)}. Retrying..." | ||
) | ||
|
||
raise RuntimeError( | ||
"Unable to generate embeddings through Cohere." | ||
) | ||
|
||
def embed_batch(self, texts: List[str]) -> List[np.ndarray]: | ||
"""Get embeddings for multiple texts using batched API calls.""" | ||
if not texts: | ||
return [] | ||
|
||
all_embeddings = [] | ||
|
||
# process in batches | ||
for i in range(0, len(texts), self._batch_size): | ||
batch = texts[i: i + self._batch_size] | ||
|
||
# check token_counts and warn if necessary | ||
token_counts = self.count_tokens_batch(batch) | ||
if self._show_warnings: | ||
for _, count in zip(batch, token_counts): | ||
if count > 512: | ||
warnings.warn( | ||
f"Text has {count} tokens which exceeds the model's context length of 512." | ||
"Generation may not be optimal." | ||
) | ||
|
||
try: | ||
for _ in range(self._max_retries): | ||
try: | ||
response = self.client.embed( | ||
model=self.model, | ||
input_type="search_document", | ||
embedding_types=["float"], | ||
texts=batch | ||
) | ||
|
||
embeddings = [ | ||
np.array(e, dtype=np.float32) for e in response.embeddings.float_ | ||
] | ||
all_embeddings.extend(embeddings) | ||
break | ||
except Exception as e: | ||
if self._show_warnings: | ||
warnings.warn( | ||
f"There was an exception while generating embeddings. Exception: {str(e)}. Retrying..." | ||
) | ||
|
||
except Exception as e: | ||
# If the batch fails, try one by one | ||
if len(batch) > 1: | ||
warnings.warn( | ||
f"Batch embedding failed: {str(e)}. Trying one by one." | ||
) | ||
individual_embeddings = [self.embed(text) for text in batch] | ||
all_embeddings.extend(individual_embeddings) | ||
else: | ||
raise e | ||
|
||
return all_embeddings | ||
|
||
def count_tokens(self, text: str) -> int: | ||
"""Count tokens in text using the model's tokenizer.""" | ||
return len(self._tokenizer.encode(text, add_special_tokens=False)) | ||
|
||
def count_tokens_batch(self, texts: List[str]) -> List[int]: | ||
"""Count tokens in multiple texts""" | ||
tokens = self._tokenizer.encode_batch(texts, add_special_tokens=False) | ||
return [len(t) for t in tokens] | ||
|
||
def similarity(self, u: np.ndarray, v: np.ndarray) -> float: | ||
"""Compute cosine similarity between two embeddings.""" | ||
return np.divide( | ||
np.dot(u, v), np.linalg.norm(u) * np.linalg.norm(v), dtype=float | ||
) | ||
|
||
@property | ||
def dimension(self) -> int: | ||
"""Return the embedding dimension""" | ||
return self._dimension | ||
|
||
def get_tokenizer_or_token_counter(self): | ||
"""Return a tokenizers tokenizer object of the current model""" | ||
return self._tokenizer | ||
|
||
@classmethod | ||
def is_available(cls) -> bool: | ||
"""Check if the Cohere package is available.""" | ||
return importlib.util.find_spec("cohere") is not None | ||
|
||
def __repr__(self) -> str: | ||
return f"CohereEmbeddings(model={self.model})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.