Skip to content

Commit

Permalink
Update cohere prompt driver, add cohere embedding driver, cohere stru… (
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Jun 5, 2024
1 parent 8630317 commit 7a137e1
Show file tree
Hide file tree
Showing 15 changed files with 348 additions and 171 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseTask.add_parent()` to add a parent task to a child task.
- `BaseTask.add_parents()` to add multiple parent tasks to a child task.
- `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`.
- `CohereEmbeddingDriver` for using Cohere's embeddings API.
- `CohereStructureConfig` for providing Structures with quick Cohere configuration.

### Changed
- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`.
- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`.
- Updated `CoherePromptDriver` to use Cohere's latest SDK.

### Fixed
- `Workflow.insert_task()` no longer inserts duplicate tasks when given multiple parent tasks.
Expand Down
23 changes: 23 additions & 0 deletions docs/griptape-framework/drivers/embedding-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,29 @@ embeddings = driver.embed_string("Hello world!")
print(embeddings[:3])
```

### Cohere Embeddings

The [CohereEmbeddingDriver](../../reference/griptape/drivers/embedding/cohere_embedding_driver.md) uses the [Cohere Embeddings API](https://docs.cohere.com/docs/embeddings).

!!! info
    This driver requires the `drivers-embedding-cohere` [extra](../index.md#extras).

```python
import os
from griptape.drivers import CohereEmbeddingDriver

embedding_driver = CohereEmbeddingDriver(
    model="embed-english-v3.0",
    api_key=os.environ["COHERE_API_KEY"],
    input_type="search_document",
)

embeddings = embedding_driver.embed_string("Hello world!")

# display the first 3 values of the embedding vector
print(embeddings[:3])
```

### Override Default Structure Embedding Driver
Here is how you can override the Embedding Driver that is used by default in Structures.

Expand Down
2 changes: 2 additions & 0 deletions griptape/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig
from .anthropic_structure_config import AnthropicStructureConfig
from .google_structure_config import GoogleStructureConfig
from .cohere_structure_config import CohereStructureConfig


__all__ = [
Expand All @@ -19,4 +20,5 @@
"AmazonBedrockStructureConfig",
"AnthropicStructureConfig",
"GoogleStructureConfig",
"CohereStructureConfig",
]
37 changes: 37 additions & 0 deletions griptape/config/cohere_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from attrs import Factory, define, field

from griptape.config import StructureConfig
from griptape.drivers import (
BaseEmbeddingDriver,
BasePromptDriver,
CoherePromptDriver,
CohereEmbeddingDriver,
BaseVectorStoreDriver,
LocalVectorStoreDriver,
)


@define
class CohereStructureConfig(StructureConfig):
    """Structure configuration that wires up Cohere-backed drivers.

    Attributes:
        api_key: Cohere API key handed to the default prompt and embedding drivers.
            Marked as not serializable.
        prompt_driver: Prompt Driver; defaults to a `CoherePromptDriver` using the
            `command-r` model.
        embedding_driver: Embedding Driver; defaults to a `CohereEmbeddingDriver` using
            the `embed-english-v3.0` model with the `search_document` input type.
        vector_store_driver: Vector Store Driver; defaults to a `LocalVectorStoreDriver`
            built on top of `embedding_driver`.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": False})

    # The default factories below receive the partially built instance, so each one can
    # read the attributes declared before it (api_key first, then embedding_driver).
    prompt_driver: BasePromptDriver = field(
        kw_only=True,
        metadata={"serializable": True},
        default=Factory(
            lambda config: CoherePromptDriver(model="command-r", api_key=config.api_key),
            takes_self=True,
        ),
    )
    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True,
        metadata={"serializable": True},
        default=Factory(
            lambda config: CohereEmbeddingDriver(
                model="embed-english-v3.0",
                api_key=config.api_key,
                input_type="search_document",
            ),
            takes_self=True,
        ),
    )
    vector_store_driver: BaseVectorStoreDriver = field(
        kw_only=True,
        metadata={"serializable": True},
        default=Factory(
            lambda config: LocalVectorStoreDriver(embedding_driver=config.embedding_driver),
            takes_self=True,
        ),
    )
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .embedding.huggingface_hub_embedding_driver import HuggingFaceHubEmbeddingDriver
from .embedding.google_embedding_driver import GoogleEmbeddingDriver
from .embedding.dummy_embedding_driver import DummyEmbeddingDriver
from .embedding.cohere_embedding_driver import CohereEmbeddingDriver

from .embedding_model.base_embedding_model_driver import BaseEmbeddingModelDriver
from .embedding_model.sagemaker_huggingface_embedding_model_driver import SageMakerHuggingFaceEmbeddingModelDriver
Expand Down Expand Up @@ -143,6 +144,7 @@
"GoogleEmbeddingDriver",
"DummyEmbeddingDriver",
"BaseEmbeddingModelDriver",
"CohereEmbeddingDriver",
"SageMakerHuggingFaceEmbeddingModelDriver",
"SageMakerTensorFlowHubEmbeddingModelDriver",
"BaseVectorStoreDriver",
Expand Down
43 changes: 43 additions & 0 deletions griptape/drivers/embedding/cohere_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from attrs import define, field, Factory
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import CohereTokenizer
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from cohere import Client


@define
class CohereEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding Driver that produces embeddings through a `cohere.Client`.

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
        tokenizer: Custom `CohereTokenizer`.
        input_type: Cohere embedding input type.
    """

    # Default Cohere embedding model name.
    DEFAULT_MODEL = "embed-english-v3.0"

    api_key: str = field(kw_only=True, metadata={"serializable": False})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
        kw_only=True,
    )
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    input_type: str = field(kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk of text with the configured model and input type.

        Args:
            chunk: Text to embed.

        Returns:
            The first entry of the embeddings list returned by the client.

        Raises:
            ValueError: If the client response does not carry its embeddings as a list.
        """
        result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

        # Only a list-shaped embeddings payload is handled; anything else is rejected.
        if not isinstance(result.embeddings, list):
            raise ValueError("Non-float embeddings are not supported.")

        return result.embeddings[0]
50 changes: 27 additions & 23 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.artifacts import TextArtifact
Expand All @@ -21,7 +21,7 @@ class CoherePromptDriver(BasePromptDriver):
tokenizer: Custom `CohereTokenizer`.
"""

api_key: str = field(kw_only=True, metadata={"serializable": True})
api_key: str = field(kw_only=True, metadata={"serializable": False})
model: str = field(kw_only=True, metadata={"serializable": True})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
Expand All @@ -33,33 +33,37 @@ class CoherePromptDriver(BasePromptDriver):
)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
result = self.client.generate(**self._base_params(prompt_stack))
result = self.client.chat(**self._base_params(prompt_stack))

if result.generations:
if len(result.generations) == 1:
generation = result.generations[0]

return TextArtifact(value=generation.text.strip())
else:
raise Exception("completion with more than one choice is not supported yet")
else:
raise Exception("model response is empty")
return TextArtifact(value=result.text)

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
result = self.client.generate(
**self._base_params(prompt_stack),
stream=True, # pyright: ignore[reportCallIssue]
)
result = self.client.chat_stream(**self._base_params(prompt_stack))

for chunk in result:
yield TextArtifact(value=chunk.text)
for event in result:
if event.event_type == "text-generation":
yield TextArtifact(value=event.text)

def _base_params(self, prompt_stack: PromptStack) -> dict:
prompt = self.prompt_stack_to_string(prompt_stack)
user_message = prompt_stack.inputs[-1].content
history_messages = [self.__to_cohere_message(input) for input in prompt_stack.inputs[:-1]]

return {
"prompt": self.prompt_stack_to_string(prompt_stack),
"model": self.model,
"message": user_message,
"chat_history": history_messages,
"temperature": self.temperature,
"end_sequences": self.tokenizer.stop_sequences,
"max_tokens": self.max_output_tokens(prompt),
"stop_sequences": self.tokenizer.stop_sequences,
}

def __to_cohere_message(self, input: PromptStack.Input) -> dict[str, Any]:
    """Build a Cohere chat-history entry from a prompt stack input.

    Args:
        input: Prompt stack input whose `role` and `content` are read.

    Returns:
        A dict with the converted role under "role" and the input content under "text".
    """
    cohere_role = self.__to_cohere_role(input.role)

    return {"role": cohere_role, "text": input.content}

def __to_cohere_role(self, role: str) -> str:
    """Translate a prompt stack role into the role string used for Cohere messages.

    Args:
        role: Role name taken from a prompt stack input.

    Returns:
        "SYSTEM", "USER" or "CHATBOT"; any role not listed in the table is treated as "USER".
    """
    cohere_roles = {
        PromptStack.SYSTEM_ROLE: "SYSTEM",
        PromptStack.USER_ROLE: "USER",
        PromptStack.ASSISTANT_ROLE: "CHATBOT",
    }

    return cohere_roles.get(role, "USER")
2 changes: 2 additions & 0 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
from typing import Any

boto3 = import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any
Client = import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any

attrs.resolve_types(
attrs_cls,
Expand All @@ -122,6 +123,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseTokenizer": BaseTokenizer,
"BasePromptModelDriver": BasePromptModelDriver,
"boto3": boto3,
"Client": Client,
},
)

Expand Down
4 changes: 2 additions & 2 deletions griptape/tokenizers/cohere_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

@define()
class CohereTokenizer(BaseTokenizer):
MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"command": 4096}
MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"command": 4096}
MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"command-r": 128000, "command": 4096, "embed": 512}
MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"command": 4096, "embed": 512}

client: Client = field(kw_only=True)

Expand Down
Loading

0 comments on commit 7a137e1

Please sign in to comment.