
Commit

feature #118
- Cohere embeddings support
- added to the registry for AutoEmbeddings
- added tests
Udayk02 committed Jan 7, 2025
1 parent 390b7ff commit 53afd5d
Showing 9 changed files with 398 additions and 2 deletions.
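
In short, this commit adds Cohere as an embeddings backend, registers it with AutoEmbeddings, and covers it with tests. A minimal usage sketch of the new path (assuming the new chonkie[cohere] extra from pyproject.toml is installed and a Cohere API key is available; the identifier below is taken from the updated auto.py docstring):

    from chonkie import AutoEmbeddings

    # Resolve the Cohere backend via the registry using the provider prefix.
    embeddings = AutoEmbeddings.get_embeddings(
        "cohere://embed-english-light-v3.0",
        api_key="...",  # or rely on the COHERE_API_KEY environment variable
    )
    vector = embeddings.embed("Chonkie chunks text before embedding it.")
    print(vector.shape)  # expected (384,) for embed-english-light-v3.0
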
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -44,7 +44,8 @@ model2vec = ["model2vec>=0.3.0", "numpy>=1.23.0, <2.2"]
st = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2"]
openai = ["openai>=1.0.0", "numpy>=1.23.0, <2.2"]
semantic = ["model2vec>=0.3.0", "numpy>=1.23.0, <2.2"]
all = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2", "openai>=1.0.0", "model2vec>=0.3.0"]
cohere = ["cohere>=5.13.0", "numpy>=1.23.0, <2.2"]
all = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2", "openai>=1.0.0", "model2vec>=0.3.0", "cohere>=5.13.0"]
dev = [
"pytest>=6.2.0",
"pytest-cov>=4.0.0",
2 changes: 2 additions & 0 deletions src/chonkie/__init__.py
@@ -16,6 +16,7 @@
Model2VecEmbeddings,
OpenAIEmbeddings,
SentenceTransformerEmbeddings,
CohereEmbeddings,
)
from .refinery import (
BaseRefinery,
@@ -77,6 +78,7 @@
"Model2VecEmbeddings",
"SentenceTransformerEmbeddings",
"OpenAIEmbeddings",
"CohereEmbeddings",
"AutoEmbeddings",
]

2 changes: 2 additions & 0 deletions src/chonkie/embeddings/__init__.py
@@ -3,12 +3,14 @@
from .model2vec import Model2VecEmbeddings
from .openai import OpenAIEmbeddings
from .sentence_transformer import SentenceTransformerEmbeddings
from .cohere import CohereEmbeddings

# Add all embeddings classes to __all__
__all__ = [
"BaseEmbeddings",
"Model2VecEmbeddings",
"SentenceTransformerEmbeddings",
"OpenAIEmbeddings",
"CohereEmbeddings",
"AutoEmbeddings",
]
6 changes: 6 additions & 0 deletions src/chonkie/embeddings/auto.py
@@ -25,6 +25,9 @@ class AutoEmbeddings:
# Get Anthropic embeddings
embeddings = AutoEmbeddings.get_embeddings("anthropic://claude-v1", api_key="...")
# Get Cohere embeddings
embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...")
"""

@classmethod
@@ -53,6 +56,9 @@ def get_embeddings(
# Get Anthropic embeddings
embeddings = AutoEmbeddings.get_embeddings("anthropic://claude-v1", api_key="...")
# Get Cohere embeddings
embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...")
"""
# Load embeddings instance if already provided
if isinstance(model, BaseEmbeddings):
203 changes: 203 additions & 0 deletions src/chonkie/embeddings/cohere.py
@@ -0,0 +1,203 @@
import importlib
import os
import warnings
from typing import List, Optional

import numpy as np

from .base import BaseEmbeddings

class CohereEmbeddings(BaseEmbeddings):
"""Cohere embeddings implementation using their API"""

AVAILABLE_MODELS = {
# cohere v3.0 models
"embed-english-v3.0": (True, 1024), # tokenizer from tokenizers
"embed-multilingual-v3.0": (False, 1024), # not listed in the cohere models api list
"embed-english-light-v3.0": (True, 384), # from tokenizers
"embed-multilingual-light-v3.0": (False, 384), # not listed in the cohere models api list
# cohere v2.0 models
"embed-english-v2.0": (False, 4096), # url is not available in the cohere models api list
"embed-english-light-v2.0": (False, 1024), # not listed in the models list
"embed-multilingual-v2.0": (True, 768) # from tokenizers
}

DEFAULT_MODEL = "embed-english-light-v3.0"
TOKENIZER_BASE_URL = "https://storage.googleapis.com/cohere-public/tokenizers/"

def __init__(
self,
model: str = DEFAULT_MODEL,
api_key: Optional[str] = None,
client_name: Optional[str] = None,
max_retries: int = 3,
timeout: float = 60.0,
batch_size: int = 128,
show_warnings: bool = True
):
"""Initialize Cohere embeddings.
Args:
model: name of the Cohere embedding model to use
api_key: (optional) Cohere API key (if not provided, looks for COHERE_API_KEY environment variable)
client_name: (optional) client name for API requests
max_retries: maximum number of retries for failed requests
timeout: timeout in seconds for API requests
batch_size: maximum number of texts to embed in one API call (maximum allowed by Cohere is 96)
show_warnings: whether to show warnings about token usage and truncation
"""

super().__init__()
if not self.is_available():
raise ImportError(
"Cohere package is not available. Please install it via pip."
)
else:
global cohere
import tokenizers
import requests
from cohere import ClientV2 # using v2

if model not in self.AVAILABLE_MODELS:
raise ValueError(
f"Model {model} is not available. Choose from: {list(self.AVAILABLE_MODELS.keys())}"
)

self.model = model
self._dimension = self.AVAILABLE_MODELS[model][1]
tokenizer_url = self.TOKENIZER_BASE_URL + (model if self.AVAILABLE_MODELS[model][0] else self.DEFAULT_MODEL) + ".json"
response = requests.get(tokenizer_url)
self._tokenizer = tokenizers.Tokenizer.from_str(response.text)
self._batch_size = min(batch_size, 96) # max batch size for cohere is 96
self._show_warnings = show_warnings
self._max_retries = max_retries
self._api_key = api_key or os.getenv("COHERE_API_KEY")

if self._api_key is None:
raise ValueError(
"Cohere API key not found. Either pass it as api_key or set COHERE_API_KEY environment variable."
)

# setup Cohere client
self.client = ClientV2(
api_key=self._api_key,
client_name=client_name,
timeout=timeout
)


def embed(self, text: str) -> np.ndarray:
"""Generate embeddings for a single text"""
token_count = self.count_tokens(text)
if token_count > 512 and self._show_warnings: # Cohere models max_context_length
warnings.warn(
f"Text has {token_count} tokens which exceeds the model's context length of 512."
"Generation may not be optimal"
)

for _ in range(self._max_retries):
try:
response = self.client.embed(
model=self.model,
input_type="search_document",
embedding_types=["float"],
texts=[text]
)

return np.array(response.embeddings.float_[0], dtype=np.float32)
except Exception as e:
if self._show_warnings:
warnings.warn(
f"There was an exception while generating embeddings. Exception: {str(e)}. Retrying..."
)

raise RuntimeError(
"Unable to generate embeddings through Cohere."
)

def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
"""Get embeddings for multiple texts using batched API calls."""
if not texts:
return []

all_embeddings = []

# process in batches
for i in range(0, len(texts), self._batch_size):
batch = texts[i: i + self._batch_size]

# check token_counts and warn if necessary
token_counts = self.count_tokens_batch(batch)
if self._show_warnings:
for _, count in zip(batch, token_counts):
if count > 512:
warnings.warn(
f"Text has {count} tokens which exceeds the model's context length of 512."
"Generation may not be optimal."
)

try:
for _ in range(self._max_retries):
try:
response = self.client.embed(
model=self.model,
input_type="search_document",
embedding_types=["float"],
texts=batch
)

embeddings = [
np.array(e, dtype=np.float32) for e in response.embeddings.float_
]
all_embeddings.extend(embeddings)
break
except Exception as e:
if self._show_warnings:
warnings.warn(
f"There was an exception while generating embeddings. Exception: {str(e)}. Retrying..."
)

except Exception as e:
# If the batch fails, try one by one
if len(batch) > 1:
warnings.warn(
f"Batch embedding failed: {str(e)}. Trying one by one."
)
individual_embeddings = [self.embed(text) for text in batch]
all_embeddings.extend(individual_embeddings)
else:
raise e

return all_embeddings

def count_tokens(self, text: str) -> int:
"""Count tokens in text using the model's tokenizer."""
return len(self._tokenizer.encode(text, add_special_tokens=False))

def count_tokens_batch(self, texts: List[str]) -> List[int]:
"""Count tokens in multiple texts"""
tokens = self._tokenizer.encode_batch(texts, add_special_tokens=False)
return [len(t) for t in tokens]

def similarity(self, u: np.ndarray, v: np.ndarray) -> float:
"""Compute cosine similarity between two embeddings."""
return np.divide(
np.dot(u, v), np.linalg.norm(u) * np.linalg.norm(v), dtype=float
)

@property
def dimension(self) -> int:
"""Return the embedding dimension"""
return self._dimension

def get_tokenizer_or_token_counter(self):
"""Return a tokenizers tokenizer object of the current model"""
return self._tokenizer

@classmethod
def is_available(cls) -> bool:
"""Check if the Cohere package is available."""
return importlib.util.find_spec("cohere") is not None

def __repr__(self) -> str:
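"""Return a string representation of the CohereEmbeddings instance."""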
return f"CohereEmbeddings(model={self.model})"
13 changes: 13 additions & 0 deletions src/chonkie/embeddings/registry.py
@@ -6,6 +6,7 @@
from .model2vec import Model2VecEmbeddings
from .openai import OpenAIEmbeddings
from .sentence_transformer import SentenceTransformerEmbeddings
from .cohere import CohereEmbeddings


@dataclass
@@ -159,3 +160,15 @@ def list_available(cls) -> List[str]:
pattern=r"^minishlab/|^minishlab/potion-base-|^minishlab/potion-|^potion-",
supported_types=["Model2Vec", "model2vec"],
)

# Register Cohere embeddings with pattern
EmbeddingsRegistry.register(
"cohere", CohereEmbeddings, pattern=r"^cohere|^embed-"
)
EmbeddingsRegistry.register("embed-english-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-multilingual-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-english-light-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-multilingual-light-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-english-v2.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-english-light-v2.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-multilingual-v2.0", CohereEmbeddings)
36 changes: 35 additions & 1 deletion tests/chunker/test_semantic_chunker.py
@@ -4,7 +4,7 @@
import pytest

from chonkie import SemanticChunker
from chonkie.embeddings import Model2VecEmbeddings, OpenAIEmbeddings
from chonkie.embeddings import Model2VecEmbeddings, OpenAIEmbeddings, CohereEmbeddings
from chonkie.types import Chunk, SemanticChunk


@@ -44,6 +44,19 @@ def openai_embedding_model():
return OpenAIEmbeddings(model="text-embedding-3-small", api_key=api_key)


@pytest.fixture
def cohere_embedding_model():
"""Fixture that returns an Cohere embedding model for testing.
Returns:
CohereEmbeddings: A Cohere model initialized with 'embed-english-light-v3.0'
and the API key from environment variables.
"""
api_key = os.environ.get("COHERE_API_KEY")
return CohereEmbeddings(model="embed-english-light-v3.0", api_key=api_key)


@pytest.fixture
def sample_complex_markdown_text():
"""Fixture that returns a sample markdown text with complex formatting.
@@ -127,6 +140,27 @@ def test_semantic_chunker_initialization_sentence_transformer():
assert chunker.min_chunk_size == 2


@pytest.mark.skipif(
"COHERE_API_KEY" not in os.environ,
reason="Skipping test because COHERE_API_KEY is not defined",
)
def test_semantic_chunker_initialization_cohere(cohere_embedding_model):
"""Test that the SemanticChunker can be initialized with required parameters."""
chunker = SemanticChunker(
embedding_model=cohere_embedding_model,
chunk_size=512,
threshold=0.5,
)

assert chunker is not None
assert chunker.chunk_size == 512
assert chunker.threshold == 0.5
assert chunker.mode == "window"
assert chunker.similarity_window == 1
assert chunker.min_sentences == 1
assert chunker.min_chunk_size == 2


def test_semantic_chunker_chunking(embedding_model, sample_text):
"""Test that the SemanticChunker can chunk a sample text."""
chunker = SemanticChunker(
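
As an illustration of what the new test exercises, a hedged sketch of building a SemanticChunker on the Cohere backend (parameters mirror the fixture and test above; the chunk() call is assumed from the existing chunking tests in this file):

    import os

    from chonkie import SemanticChunker
    from chonkie.embeddings import CohereEmbeddings

    embedding_model = CohereEmbeddings(
        model="embed-english-light-v3.0",
        api_key=os.environ.get("COHERE_API_KEY"),
    )
    chunker = SemanticChunker(
        embedding_model=embedding_model,
        chunk_size=512,
        threshold=0.5,
    )
    chunks = chunker.chunk("Some long document text to split semantically ...")
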
16 changes: 16 additions & 0 deletions tests/embeddings/test_auto_embeddings.py
@@ -6,6 +6,7 @@
from chonkie.embeddings.model2vec import Model2VecEmbeddings
from chonkie.embeddings.openai import OpenAIEmbeddings
from chonkie.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from chonkie.embeddings.cohere import CohereEmbeddings


@pytest.fixture
@@ -32,6 +33,12 @@ def openai_identifier():
return "text-embedding-3-small"


@pytest.fixture
def cohere_identifier():
"""Fixture providing an Cohere identifier."""
return "embed-english-light-v3.0"


@pytest.fixture
def invalid_identifier():
"""Fixture providing an invalid identifier."""
@@ -70,6 +77,15 @@ def test_auto_embeddings_openai(openai_identifier):
assert embeddings.model == openai_identifier


def test_auto_embeddings_cohere(cohere_identifier):
"""Test that the AutoEmbeddings class can get Cohere embeddings."""
embeddings = AutoEmbeddings.get_embeddings(
cohere_identifier, api_key="your_cohere_api_key"
)
assert isinstance(embeddings, CohereEmbeddings)
assert embeddings.model == cohere_identifier


def test_auto_embeddings_invalid_identifier(invalid_identifier):
"""Test that the AutoEmbeddings class raises an error for an invalid identifier."""
with pytest.raises(ValueError):