[FEAT] Support Cohere Embeddings for SemanticChunker and SDPMChunker #118 #130

Merged · 7 commits · Feb 6, 2025
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -44,7 +44,8 @@ model2vec = ["model2vec>=0.3.0", "numpy>=1.23.0, <2.2"]
st = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2"]
openai = ["tiktoken>=0.5.0", "openai>=1.0.0", "numpy>=1.23.0, <2.2"]
semantic = ["model2vec>=0.3.0", "numpy>=1.23.0, <2.2"]
all = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2", "openai>=1.0.0", "model2vec>=0.3.0"]
cohere = ["cohere>=5.13.0", "numpy>=1.23.0, <2.2"]
all = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2", "openai>=1.0.0", "model2vec>=0.3.0", "cohere>=5.13.0"]
dev = [
"tiktoken>=0.5.0",
"pytest>=6.2.0",
2 changes: 2 additions & 0 deletions src/chonkie/__init__.py
@@ -16,6 +16,7 @@
Model2VecEmbeddings,
OpenAIEmbeddings,
SentenceTransformerEmbeddings,
CohereEmbeddings,
)
from .refinery import (
BaseRefinery,
@@ -77,6 +78,7 @@
"Model2VecEmbeddings",
"SentenceTransformerEmbeddings",
"OpenAIEmbeddings",
"CohereEmbeddings",
"AutoEmbeddings",
]

2 changes: 2 additions & 0 deletions src/chonkie/embeddings/__init__.py
@@ -3,12 +3,14 @@
from .model2vec import Model2VecEmbeddings
from .openai import OpenAIEmbeddings
from .sentence_transformer import SentenceTransformerEmbeddings
from .cohere import CohereEmbeddings

# Add all embeddings classes to __all__
__all__ = [
"BaseEmbeddings",
"Model2VecEmbeddings",
"SentenceTransformerEmbeddings",
"OpenAIEmbeddings",
"CohereEmbeddings",
"AutoEmbeddings",
]
6 changes: 6 additions & 0 deletions src/chonkie/embeddings/auto.py
@@ -24,6 +24,9 @@ class AutoEmbeddings:
# Get Anthropic embeddings
embeddings = AutoEmbeddings.get_embeddings("anthropic://claude-v1", api_key="...")

# Get Cohere embeddings
embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...")

"""

@classmethod
@@ -52,6 +55,9 @@ def get_embeddings(
# Get Anthropic embeddings
embeddings = AutoEmbeddings.get_embeddings("anthropic://claude-v1", api_key="...")

# Get Cohere embeddings
embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...")

"""
# Load embeddings instance if already provided
if isinstance(model, BaseEmbeddings):
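For reference, a minimal usage sketch of the provider-prefixed form documented above (assuming `COHERE_API_KEY` is set in the environment; the model name is one of those listed in the new `AVAILABLE_MODELS` table):

```python
import os

from chonkie import AutoEmbeddings

# The "cohere://" prefix routes to the new CohereEmbeddings backend;
# the API key can also be passed explicitly via api_key=...
embeddings = AutoEmbeddings.get_embeddings(
    "cohere://embed-english-light-v3.0",
    api_key=os.environ["COHERE_API_KEY"],
)
print(embeddings)  # CohereEmbeddings(model=embed-english-light-v3.0)
```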
203 changes: 203 additions & 0 deletions src/chonkie/embeddings/cohere.py
@@ -0,0 +1,203 @@
import importlib
import os
import warnings
from typing import List, Optional

import numpy as np

from .base import BaseEmbeddings

class CohereEmbeddings(BaseEmbeddings):
"""Cohere embeddings implementation using their API"""

AVAILABLE_MODELS = {
# cohere v3.0 models
"embed-english-v3.0": (True, 1024), # tokenizer from tokenizers
"embed-multilingual-v3.0": (False, 1024), # not listed in the cohere models api list
"embed-english-light-v3.0": (True, 384), # from tokenizers
"embed-multilingual-light-v3.0": (False, 384), # not listed in the cohere models api list
# cohere v2.0 models
"embed-english-v2.0": (False, 4096), # url is not available in the cohere models api list
"embed-english-light-v2.0": (False, 1024), # not listed in the models list
"embed-multilingual-v2.0": (True, 768) # from tokenizers
}

DEFAULT_MODEL = "embed-english-light-v3.0"
TOKENIZER_BASE_URL = "https://storage.googleapis.com/cohere-public/tokenizers/"

def __init__(
self,
model: str = DEFAULT_MODEL,
api_key: Optional[str] = None,
client_name: Optional[str] = None,
max_retries: int = 3,
timeout: float = 60.0,
batch_size: int = 96,
show_warnings: bool = True
):
"""Initialize Cohere embeddings.

Args:
model: name of the Cohere embedding model to use
api_key: (optional) Cohere API key (if not provided, looks for COHERE_API_KEY environment variable)
client_name: (optional) client name for API requests
max_retries: maximum number of retries for failed requests
timeout: timeout in seconds for API requests
batch_size: maximum number of texts to embed in one API call (maximum allowed by Cohere is 96)
show_warnings: whether to show warnings about token usage and truncation
"""

super().__init__()
if not self.is_available():
raise ImportError(
"Cohere package is not available. Please install it via pip."
)
else:
global cohere
import tokenizers
import requests
from cohere import ClientV2 # using v2

if model not in self.AVAILABLE_MODELS:
raise ValueError(
f"Model {model} is not available. Choose from: {list(self.AVAILABLE_MODELS.keys())}"
)

self.model = model
self._dimension = self.AVAILABLE_MODELS[model][1]
tokenizer_url = self.TOKENIZER_BASE_URL + (model if self.AVAILABLE_MODELS[model][0] else self.DEFAULT_MODEL) + ".json"
response = requests.get(tokenizer_url)
self._tokenizer = tokenizers.Tokenizer.from_str(response.text)
self._batch_size = min(batch_size, 96) # max batch size for cohere is 96
self._show_warnings = show_warnings
self._max_retries = max_retries
self._api_key = api_key or os.getenv("COHERE_API_KEY")

if self._api_key is None:
raise ValueError(
"Cohere API key not found. Either pass it as api_key or set COHERE_API_KEY environment variable."
)

# setup Cohere client
self.client = ClientV2(
api_key=api_key or os.getenv("COHERE_API_KEY"),
client_name=client_name,
timeout=timeout
)


def embed(self, text: str) -> np.ndarray:
"""Generate embeddings for a single text"""
token_count = self.count_tokens(text)
if token_count > 512 and self._show_warnings: # Cohere models max_context_length
warnings.warn(
f"Text has {token_count} tokens which exceeds the model's context length of 512."
"Generation may not be optimal"
)

for _ in range(self._max_retries):
try:
response = self.client.embed(
model=self.model,
input_type="search_document",
embedding_types=["float"],
texts=[text]
)

return np.array(response.embeddings.float_[0], dtype=np.float32)
except Exception as e:
if self._show_warnings:
warnings.warn(
f"There was an exception while generating embeddings. Exception: {str(e)}. Retrying..."
)

raise RuntimeError(
"Unable to generate embeddings through Cohere."
)

def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
"""Get embeddings for multiple texts using batched API calls."""
if not texts:
return []

all_embeddings = []

# process in batches
for i in range(0, len(texts), self._batch_size):
batch = texts[i: i + self._batch_size]

# check token_counts and warn if necessary
token_counts = self.count_tokens_batch(batch)
if self._show_warnings:
for _, count in zip(batch, token_counts):
if count > 512:
warnings.warn(
f"Text has {count} tokens which exceeds the model's context length of 512."
"Generation may not be optimal."
)

try:
for _ in range(self._max_retries):
try:
response = self.client.embed(
model=self.model,
input_type="search_document",
embedding_types=["float"],
texts=batch
)

embeddings = [
np.array(e, dtype=np.float32) for e in response.embeddings.float_
]
all_embeddings.extend(embeddings)
break
except Exception as e:
if self._show_warnings:
warnings.warn(
f"There was an exception while generating embeddings. Exception: {str(e)}. Retrying..."
)

except Exception as e:
# If the batch fails, try one by one
if len(batch) > 1:
warnings.warn(
f"Batch embedding failed: {str(e)}. Trying one by one."
)
individual_embeddings = [self.embed(text) for text in batch]
all_embeddings.extend(individual_embeddings)
else:
raise e

return all_embeddings

def count_tokens(self, text: str) -> int:
"""Count tokens in text using the model's tokenizer."""
return len(self._tokenizer.encode(text, add_special_tokens=False))

def count_tokens_batch(self, texts: List[str]) -> List[int]:
"""Count tokens in multiple texts"""
tokens = self._tokenizer.encode_batch(texts, add_special_tokens=False)
return [len(t) for t in tokens]

def similarity(self, u: np.ndarray, v: np.ndarray) -> float:
"""Compute cosine similarity between two embeddings."""
return np.divide(
np.dot(u, v), np.linalg.norm(u) * np.linalg.norm(v), dtype=float
)

@property
def dimension(self) -> int:
"""Return the embedding dimension"""
return self._dimension

def get_tokenizer_or_token_counter(self):
"""Return a tokenizers tokenizer object of the current model"""
return self._tokenizer

@classmethod
def is_available(cls) -> bool:
"""Check if the Cohere package is available."""
return importlib.util.find_spec("cohere") is not None

def __repr__(self) -> str:
return f"CohereEmbeddings(model={self.model})"
13 changes: 13 additions & 0 deletions src/chonkie/embeddings/registry.py
@@ -6,6 +6,7 @@
from .model2vec import Model2VecEmbeddings
from .openai import OpenAIEmbeddings
from .sentence_transformer import SentenceTransformerEmbeddings
from .cohere import CohereEmbeddings


@dataclass
@@ -159,3 +160,15 @@ def list_available(cls) -> List[str]:
pattern=r"^minishlab/|^minishlab/potion-base-|^minishlab/potion-|^potion-",
supported_types=["Model2Vec", "model2vec"],
)

# Register Cohere embeddings with pattern
EmbeddingsRegistry.register(
"cohere", CohereEmbeddings, pattern=r"^cohere|^embed-"
)
EmbeddingsRegistry.register("embed-english-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-multilingual-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-english-light-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-multilingual-light-v3.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-english-v2.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-english-light-v2.0", CohereEmbeddings)
EmbeddingsRegistry.register("embed-multilingual-v2.0", CohereEmbeddings)
36 changes: 35 additions & 1 deletion tests/chunker/test_semantic_chunker.py
@@ -4,7 +4,7 @@
import pytest

from chonkie import SemanticChunker
from chonkie.embeddings import Model2VecEmbeddings, OpenAIEmbeddings
from chonkie.embeddings import Model2VecEmbeddings, OpenAIEmbeddings, CohereEmbeddings
from chonkie.types import Chunk, SemanticChunk


@@ -44,6 +44,19 @@ def openai_embedding_model():
return OpenAIEmbeddings(model="text-embedding-3-small", api_key=api_key)


@pytest.fixture
def cohere_embedding_model():
"""Fixture that returns an Cohere embedding model for testing.

Returns:
CohereEmbeddings: A Cohere model initialized with 'embed-english-light-v3.0'
and the API key from environment variables.

"""
api_key = os.environ.get("COHERE_API_KEY")
return CohereEmbeddings(model="embed-english-light-v3.0", api_key=api_key)


@pytest.fixture
def sample_complex_markdown_text():
"""Fixture that returns a sample markdown text with complex formatting.
@@ -127,6 +140,27 @@ def test_semantic_chunker_initialization_sentence_transformer():
assert chunker.min_chunk_size == 2


@pytest.mark.skipif(
"COHERE_API_KEY" not in os.environ,
reason="Skipping test because COHERE_API_KEY is not defined",
)
def test_semantic_chunker_initialization_cohere(cohere_embedding_model):
"""Test that the SemanticChunker can be initialized with required parameters."""
chunker = SemanticChunker(
embedding_model=cohere_embedding_model,
chunk_size=512,
threshold=0.5,
)

assert chunker is not None
assert chunker.chunk_size == 512
assert chunker.threshold == 0.5
assert chunker.mode == "window"
assert chunker.similarity_window == 1
assert chunker.min_sentences == 1
assert chunker.min_chunk_size == 2


def test_semantic_chunker_chunking(embedding_model, sample_text):
"""Test that the SemanticChunker can chunk a sample text."""
chunker = SemanticChunker(
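Beyond the initialization checks above, a short sketch of what this feature enables end to end (the sample text is illustrative, and SDPMChunker is expected to take the same `embedding_model` argument):

```python
import os

from chonkie import SemanticChunker
from chonkie.embeddings import CohereEmbeddings

embeddings = CohereEmbeddings(
    model="embed-english-light-v3.0",
    api_key=os.environ["COHERE_API_KEY"],
)

chunker = SemanticChunker(
    embedding_model=embeddings,
    chunk_size=512,
    threshold=0.5,
)

# Chunk a small document; each chunk carries its text and token count.
chunks = chunker.chunk(
    "Cohere embeddings are now supported in Chonkie. "
    "Semantic chunking groups related sentences into the same chunk."
)
for chunk in chunks:
    print(chunk.token_count, chunk.text)
```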
16 changes: 16 additions & 0 deletions tests/embeddings/test_auto_embeddings.py
@@ -6,6 +6,7 @@
from chonkie.embeddings.model2vec import Model2VecEmbeddings
from chonkie.embeddings.openai import OpenAIEmbeddings
from chonkie.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from chonkie.embeddings.cohere import CohereEmbeddings


@pytest.fixture
@@ -32,6 +33,12 @@ def openai_identifier():
return "text-embedding-3-small"


@pytest.fixture
def cohere_identifier():
"""Fixture providing an Cohere identifier."""
return "embed-english-light-v3.0"


@pytest.fixture
def invalid_identifier():
"""Fixture providing an invalid identifier."""
@@ -70,6 +77,15 @@ def test_auto_embeddings_openai(openai_identifier):
assert embeddings.model == openai_identifier


def test_auto_embeddings_cohere(cohere_identifier):
"""Test that the AutoEmbeddings class can get Cohere embeddings."""
embeddings = AutoEmbeddings.get_embeddings(
cohere_identifier, api_key="your_cohere_api_key"
)
assert isinstance(embeddings, CohereEmbeddings)
assert embeddings.model == cohere_identifier


def test_auto_embeddings_invalid_identifier(invalid_identifier):
"""Test that the AutoEmbeddings class raises an error for an invalid identifier."""
with pytest.raises(ValueError):