-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update cohere prompt driver, add cohere embedding driver, cohere structure config (#831)
- Loading branch information
1 parent
8630317
commit 7a137e1
Showing
15 changed files
with
348 additions
and
171 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import StructureConfig | ||
from griptape.drivers import ( | ||
BaseEmbeddingDriver, | ||
BasePromptDriver, | ||
CoherePromptDriver, | ||
CohereEmbeddingDriver, | ||
BaseVectorStoreDriver, | ||
LocalVectorStoreDriver, | ||
) | ||
|
||
|
||
@define
class CohereStructureConfig(StructureConfig):
    """Structure configuration whose default drivers are backed by Cohere.

    Attributes:
        api_key: Cohere API key passed to the default prompt and embedding drivers.
            Marked as not serializable.
        prompt_driver: Prompt driver; defaults to a `CoherePromptDriver` using the
            `command-r` model.
        embedding_driver: Embedding driver; defaults to a `CohereEmbeddingDriver` using the
            `embed-english-v3.0` model with `input_type="search_document"`.
        vector_store_driver: Vector store driver; defaults to a `LocalVectorStoreDriver`
            built on top of `embedding_driver`.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": False})

    # Each default below is a `Factory(..., takes_self=True)`: attrs hands the partially
    # initialised instance to the factory, so a default may only read fields declared
    # before it. Keep the declaration order (api_key -> embedding_driver -> vector store).
    prompt_driver: BasePromptDriver = field(
        kw_only=True,
        metadata={"serializable": True},
        default=Factory(
            lambda config: CoherePromptDriver(api_key=config.api_key, model="command-r"),
            takes_self=True,
        ),
    )
    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True,
        metadata={"serializable": True},
        default=Factory(
            lambda config: CohereEmbeddingDriver(
                api_key=config.api_key,
                model="embed-english-v3.0",
                input_type="search_document",
            ),
            takes_self=True,
        ),
    )
    vector_store_driver: BaseVectorStoreDriver = field(
        kw_only=True,
        metadata={"serializable": True},
        default=Factory(
            lambda config: LocalVectorStoreDriver(embedding_driver=config.embedding_driver),
            takes_self=True,
        ),
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING | ||
from attrs import define, field, Factory | ||
from griptape.drivers import BaseEmbeddingDriver | ||
from griptape.tokenizers import CohereTokenizer | ||
from griptape.utils import import_optional_dependency | ||
|
||
if TYPE_CHECKING: | ||
from cohere import Client | ||
|
||
|
||
@define
class CohereEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that requests embeddings from Cohere through `cohere.Client.embed`.

    Attributes:
        api_key: Cohere API key. Marked as not serializable.
        model: Cohere model name (not declared in this class; presumably inherited
            from `BaseEmbeddingDriver`).
        client: Custom `cohere.Client`. Defaults to a client built from `api_key`.
        tokenizer: Custom `CohereTokenizer`. Defaults to one built from `model` and `client`.
        input_type: Cohere embedding input type, forwarded unchanged to the embed call.
    """

    # Must name a Cohere embedding model; a model id from another provider here would be
    # rejected by the Cohere API if this constant were ever passed as `model`.
    DEFAULT_MODEL = "embed-english-v3.0"

    api_key: str = field(kw_only=True, metadata={"serializable": False})
    # The `cohere` package is resolved lazily through `import_optional_dependency`, so it is
    # only looked up when the default client is actually built.
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
        kw_only=True,
    )
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    input_type: str = field(kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single text chunk with the configured Cohere model.

        Args:
            chunk: Text to embed; sent to Cohere as a one-element `texts` list.

        Returns:
            The first entry of the response's `embeddings` list.

        Raises:
            ValueError: If the response's `embeddings` attribute is not a list.
        """
        result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
        embeddings = result.embeddings

        # Only a plain list of embeddings is handled; any other response shape is rejected.
        if not isinstance(embeddings, list):
            raise ValueError("Non-float embeddings are not supported.")

        return embeddings[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.