-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Description New Cohere Embedder Class and a cookbook example ## Type of change Please check the options that are relevant: - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Model update - [ ] Infrastructure change ## Checklist - [x] My code follows Phidata's style guidelines and best practices - [x] I have performed a self-review of my code - [x] I have added docstrings and comments for complex logic - [x] My changes generate no new warnings or errors - [x] I have added cookbook examples for my new addition (if needed) - [x] I have updated requirements.txt/pyproject.toml (if needed) - [x] I have verified my changes in a clean environment --------- Co-authored-by: Dirk Brand <[email protected]>
- Loading branch information
Showing
3 changed files
with
91 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from phi.agent import AgentKnowledge | ||
from phi.vectordb.pgvector import PgVector | ||
from phi.embedder.cohere import CohereEmbedder | ||
|
||
embeddings = CohereEmbedder().get_embedding("The quick brown fox jumps over the lazy dog.") | ||
# Print the embeddings and their dimensions | ||
print(f"Embeddings: {embeddings[:5]}") | ||
print(f"Dimensions: {len(embeddings)}") | ||
|
||
# Example usage: | ||
knowledge_base = AgentKnowledge( | ||
vector_db=PgVector( | ||
db_url="postgresql+psycopg://ai:ai@localhost:5532/ai", | ||
table_name="cohere_embeddings", | ||
embedder=CohereEmbedder(), | ||
), | ||
num_documents=2, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from typing import Optional, Dict, List, Tuple, Any, Union | ||
|
||
from phi.embedder.base import Embedder | ||
from phi.utils.log import logger | ||
|
||
try: | ||
from cohere import Client as CohereClient | ||
from cohere.types.embed_response import EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse | ||
except ImportError: | ||
raise ImportError("`cohere` not installed. Please install using `pip install cohere`.") | ||
|
||
|
||
class CohereEmbedder(Embedder): | ||
model: str = "embed-english-v3.0" | ||
input_type: str = "search_query" | ||
embedding_types: Optional[List[str]] = None | ||
api_key: Optional[str] = None | ||
request_params: Optional[Dict[str, Any]] = None | ||
client_params: Optional[Dict[str, Any]] = None | ||
cohere_client: Optional[CohereClient] = None | ||
|
||
@property | ||
def client(self) -> CohereClient: | ||
if self.cohere_client: | ||
return self.cohere_client | ||
client_params: Dict[str, Any] = {} | ||
if self.api_key: | ||
client_params["api_key"] = self.api_key | ||
return CohereClient(**client_params) | ||
|
||
def response(self, text: str) -> Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse]: | ||
request_params: Dict[str, Any] = {} | ||
|
||
if self.model: | ||
request_params["model"] = self.model | ||
if self.input_type: | ||
request_params["input_type"] = self.input_type | ||
if self.embedding_types: | ||
request_params["embedding_types"] = self.embedding_types | ||
if self.request_params: | ||
request_params.update(self.request_params) | ||
return self.client.embed(texts=[text], **request_params) | ||
|
||
def get_embedding(self, text: str) -> List[float]: | ||
response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text) | ||
try: | ||
if isinstance(response, EmbeddingsFloatsEmbedResponse): | ||
return response.embeddings[0] | ||
elif isinstance(response, EmbeddingsByTypeEmbedResponse): | ||
return response.embeddings.float_[0] if response.embeddings.float_ else [] | ||
else: | ||
logger.warning("No embeddings found") | ||
return [] | ||
except Exception as e: | ||
logger.warning(e) | ||
return [] | ||
|
||
def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]: | ||
response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text) | ||
|
||
embedding: List[float] = [] | ||
if isinstance(response, EmbeddingsFloatsEmbedResponse): | ||
embedding = response.embeddings[0] | ||
elif isinstance(response, EmbeddingsByTypeEmbedResponse): | ||
embedding = response.embeddings.float_[0] if response.embeddings.float_ else [] | ||
|
||
usage = response.meta.billed_units if response.meta else None | ||
if usage: | ||
return embedding, usage.model_dump() | ||
return embedding, None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters