# cohere-embedder-phi-2214 (#1586)
## Description

Adds a new `CohereEmbedder` class and a cookbook example showing how to use it.
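
For a quick check of the new class, the snippet below is a minimal sketch (not part of the diff); it assumes a Cohere API key is available to the client, either via the `api_key` field or an environment variable read by the `cohere` SDK:

```python
from phi.embedder.cohere import CohereEmbedder

# Sketch only: assumes a Cohere API key is configured for the cohere client.
embedder = CohereEmbedder()

embedding, usage = embedder.get_embedding_and_usage("The quick brown fox jumps over the lazy dog.")
print(f"Dimensions: {len(embedding)}")  # 1024 for the default embed-english-v3.0 model
print(f"Usage: {usage}")  # billed units reported by Cohere, or None
```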

## Type of change

Please check the options that are relevant:

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Model update
- [ ] Infrastructure change

## Checklist

- [x] My code follows Phidata's style guidelines and best practices
- [x] I have performed a self-review of my code
- [x] I have added docstrings and comments for complex logic
- [x] My changes generate no new warnings or errors
- [x] I have added cookbook examples for my new addition (if needed)
- [x] I have updated requirements.txt/pyproject.toml (if needed)
- [x] I have verified my changes in a clean environment

---------

Co-authored-by: Dirk Brand <[email protected]>
ysolanky and dirkbrnd authored Dec 17, 2024
1 parent 40fa745 commit eae8290
Showing 3 changed files with 91 additions and 3 deletions.
18 changes: 18 additions & 0 deletions cookbook/embedders/cohere_embedder.py
@@ -0,0 +1,18 @@
from phi.agent import AgentKnowledge
from phi.vectordb.pgvector import PgVector
from phi.embedder.cohere import CohereEmbedder

embeddings = CohereEmbedder().get_embedding("The quick brown fox jumps over the lazy dog.")
# Print the embeddings and their dimensions
print(f"Embeddings: {embeddings[:5]}")
print(f"Dimensions: {len(embeddings)}")

# Example usage:
knowledge_base = AgentKnowledge(
    vector_db=PgVector(
        db_url="postgresql+psycopg://ai:ai@localhost:5532/ai",
        table_name="cohere_embeddings",
        embedder=CohereEmbedder(),
    ),
    num_documents=2,
)
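
The cookbook stops after constructing the knowledge base. As a possible next step (a hedged sketch, not part of this commit, assuming the usual phidata `Agent` wiring and a `cohere_embeddings` table that has already been populated), the knowledge base could be handed to an agent:

```python
from phi.agent import Agent

# Sketch only: knowledge_base is the AgentKnowledge instance built above and is
# assumed to have been loaded with documents beforehand.
agent = Agent(
    knowledge=knowledge_base,
    search_knowledge=True,  # let the agent query the pgvector table when answering
)
agent.print_response("What does the quick brown fox jump over?")
```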
70 changes: 70 additions & 0 deletions phi/embedder/cohere.py
@@ -0,0 +1,70 @@
from typing import Optional, Dict, List, Tuple, Any, Union

from phi.embedder.base import Embedder
from phi.utils.log import logger

try:
    from cohere import Client as CohereClient
    from cohere.types.embed_response import EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse
except ImportError:
    raise ImportError("`cohere` not installed. Please install using `pip install cohere`.")


class CohereEmbedder(Embedder):
    model: str = "embed-english-v3.0"
    input_type: str = "search_query"
    embedding_types: Optional[List[str]] = None
    api_key: Optional[str] = None
    request_params: Optional[Dict[str, Any]] = None
    client_params: Optional[Dict[str, Any]] = None
    cohere_client: Optional[CohereClient] = None

    @property
    def client(self) -> CohereClient:
        if self.cohere_client:
            return self.cohere_client
        client_params: Dict[str, Any] = {}
        if self.api_key:
            client_params["api_key"] = self.api_key
        return CohereClient(**client_params)

    def response(self, text: str) -> Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse]:
        request_params: Dict[str, Any] = {}

        if self.model:
            request_params["model"] = self.model
        if self.input_type:
            request_params["input_type"] = self.input_type
        if self.embedding_types:
            request_params["embedding_types"] = self.embedding_types
        if self.request_params:
            request_params.update(self.request_params)
        return self.client.embed(texts=[text], **request_params)

    def get_embedding(self, text: str) -> List[float]:
        response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text)
        try:
            if isinstance(response, EmbeddingsFloatsEmbedResponse):
                return response.embeddings[0]
            elif isinstance(response, EmbeddingsByTypeEmbedResponse):
                return response.embeddings.float_[0] if response.embeddings.float_ else []
            else:
                logger.warning("No embeddings found")
                return []
        except Exception as e:
            logger.warning(e)
            return []

    def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
        response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text)

        embedding: List[float] = []
        if isinstance(response, EmbeddingsFloatsEmbedResponse):
            embedding = response.embeddings[0]
        elif isinstance(response, EmbeddingsByTypeEmbedResponse):
            embedding = response.embeddings.float_[0] if response.embeddings.float_ else []

        usage = response.meta.billed_units if response.meta else None
        if usage:
            return embedding, usage.model_dump()
        return embedding, None
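
Since the configuration fields above are forwarded straight into `client.embed()` by `response()`, a differently configured embedder needs no extra code. A hedged sketch (the model name, input type, embedding type, and extra request parameter below are illustrative choices, not values taken from this PR):

```python
from phi.embedder.cohere import CohereEmbedder

# Sketch only: illustrative configuration; request_params is merged into the
# keyword arguments passed to client.embed() by CohereEmbedder.response().
doc_embedder = CohereEmbedder(
    model="embed-multilingual-v3.0",
    input_type="search_document",
    embedding_types=["float"],           # yields an embeddings-by-type response, handled above
    request_params={"truncate": "END"},
)

vector = doc_embedder.get_embedding("Text to index in the knowledge base.")
print(len(vector))
```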
6 changes: 3 additions & 3 deletions phi/embedder/openai.py
@@ -39,7 +39,7 @@ def client(self) -> OpenAIClient:
             _client_params.update(self.client_params)
         return OpenAIClient(**_client_params)
 
-    def _response(self, text: str) -> CreateEmbeddingResponse:
+    def response(self, text: str) -> CreateEmbeddingResponse:
         _request_params: Dict[str, Any] = {
             "input": text,
             "model": self.model,
@@ -54,15 +54,15 @@ def _response(self, text: str) -> CreateEmbeddingResponse:
         return self.client.embeddings.create(**_request_params)
 
     def get_embedding(self, text: str) -> List[float]:
-        response: CreateEmbeddingResponse = self._response(text=text)
+        response: CreateEmbeddingResponse = self.response(text=text)
         try:
             return response.data[0].embedding
         except Exception as e:
             logger.warning(e)
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
-        response: CreateEmbeddingResponse = self._response(text=text)
+        response: CreateEmbeddingResponse = self.response(text=text)
 
         embedding = response.data[0].embedding
         usage = response.usage
