diff --git a/cookbook/embedders/cohere_embedder.py b/cookbook/embedders/cohere_embedder.py
new file mode 100644
index 000000000..be36af603
--- /dev/null
+++ b/cookbook/embedders/cohere_embedder.py
@@ -0,0 +1,18 @@
+from phi.agent import AgentKnowledge
+from phi.vectordb.pgvector import PgVector
+from phi.embedder.cohere import CohereEmbedder
+
+embeddings = CohereEmbedder().get_embedding("The quick brown fox jumps over the lazy dog.")
+# Print the embeddings and their dimensions
+print(f"Embeddings: {embeddings[:5]}")
+print(f"Dimensions: {len(embeddings)}")
+
+# Example usage:
+knowledge_base = AgentKnowledge(
+    vector_db=PgVector(
+        db_url="postgresql+psycopg://ai:ai@localhost:5532/ai",
+        table_name="cohere_embeddings",
+        embedder=CohereEmbedder(),
+    ),
+    num_documents=2,
+)
diff --git a/phi/embedder/cohere.py b/phi/embedder/cohere.py
new file mode 100644
index 000000000..6b3b3594c
--- /dev/null
+++ b/phi/embedder/cohere.py
@@ -0,0 +1,72 @@
+from typing import Optional, Dict, List, Tuple, Any, Union
+
+from phi.embedder.base import Embedder
+from phi.utils.log import logger
+
+try:
+    from cohere import Client as CohereClient
+    from cohere.types.embed_response import EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse
+except ImportError:
+    raise ImportError("`cohere` not installed. Please install using `pip install cohere`.")
+
+
+class CohereEmbedder(Embedder):
+    model: str = "embed-english-v3.0"
+    input_type: str = "search_query"
+    embedding_types: Optional[List[str]] = None
+    api_key: Optional[str] = None
+    request_params: Optional[Dict[str, Any]] = None
+    client_params: Optional[Dict[str, Any]] = None
+    cohere_client: Optional[CohereClient] = None
+
+    @property
+    def client(self) -> CohereClient:
+        # Reuse an explicitly-provided client; otherwise build one from config.
+        if self.cohere_client:
+            return self.cohere_client
+        client_params: Dict[str, Any] = {}
+        if self.api_key:
+            client_params["api_key"] = self.api_key
+        # Apply any extra client kwargs (mirrors the OpenAIEmbedder pattern);
+        # without this, the declared `client_params` field is silently ignored.
+        if self.client_params:
+            client_params.update(self.client_params)
+        return CohereClient(**client_params)
+
+    def response(self, text: str) -> Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse]:
+        request_params: Dict[str, Any] = {}
+
+        if self.model:
+            request_params["model"] = self.model
+        if self.input_type:
+            request_params["input_type"] = self.input_type
+        if self.embedding_types:
+            request_params["embedding_types"] = self.embedding_types
+        if self.request_params:
+            request_params.update(self.request_params)
+        return self.client.embed(texts=[text], **request_params)
+
+    def get_embedding(self, text: str) -> List[float]:
+        response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text)
+        try:
+            if isinstance(response, EmbeddingsFloatsEmbedResponse):
+                return response.embeddings[0]
+            elif isinstance(response, EmbeddingsByTypeEmbedResponse):
+                return response.embeddings.float_[0] if response.embeddings.float_ else []
+            else:
+                logger.warning("No embeddings found")
+                return []
+        except Exception as e:
+            logger.warning(e)
+            return []
+
+    def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
+        response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text)
+
+        embedding: List[float] = []
+        if isinstance(response, EmbeddingsFloatsEmbedResponse):
+            embedding = response.embeddings[0]
+        elif isinstance(response, EmbeddingsByTypeEmbedResponse):
+            embedding = response.embeddings.float_[0] if response.embeddings.float_ else []
+
+        usage = response.meta.billed_units if response.meta else None
+        if usage:
+            return embedding, usage.model_dump()
+        return embedding, None
diff --git a/phi/embedder/openai.py b/phi/embedder/openai.py
index dc223c121..db1d27a40 100644
--- a/phi/embedder/openai.py
+++ b/phi/embedder/openai.py
@@ -39,7 +39,7 @@ def client(self) -> OpenAIClient:
         _client_params.update(self.client_params)
         return OpenAIClient(**_client_params)
 
-    def _response(self, text: str) -> CreateEmbeddingResponse:
+    def response(self, text: str) -> CreateEmbeddingResponse:
         _request_params: Dict[str, Any] = {
             "input": text,
             "model": self.model,
@@ -54,7 +54,7 @@ def _response(self, text: str) -> CreateEmbeddingResponse:
         return self.client.embeddings.create(**_request_params)
 
     def get_embedding(self, text: str) -> List[float]:
-        response: CreateEmbeddingResponse = self._response(text=text)
+        response: CreateEmbeddingResponse = self.response(text=text)
         try:
             return response.data[0].embedding
         except Exception as e:
@@ -62,7 +62,7 @@ def get_embedding(self, text: str) -> List[float]:
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
-        response: CreateEmbeddingResponse = self._response(text=text)
+        response: CreateEmbeddingResponse = self.response(text=text)
         embedding = response.data[0].embedding
         usage = response.usage