Refactor how Prompt Drivers use Tokenizers
collindutter committed Jun 6, 2024
1 parent 30aa370 commit 9e1533e
Showing 44 changed files with 537 additions and 523 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`.
- `CohereEmbeddingDriver` for using Cohere's embeddings API.
- `CohereStructureConfig` for providing Structures with quick Cohere configuration.
- `BaseTokenizer.prompt_stack_to_string()` to convert a Prompt Stack to a string.
- `BaseTokenizer.prompt_stack_input_to_message()` to convert a Prompt Stack Input to a ChatML-style message dictionary.

### Changed
- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`.
@@ -27,6 +29,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Removed `BedrockLlamaTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `BedrockTitanTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `OpenAiChatCompletionPromptDriver` as it uses the legacy [OpenAi Completions API](https://platform.openai.com/docs/api-reference/completions).
- **BREAKING**: Removed `BasePromptDriver.count_tokens()`.
- **BREAKING**: Removed `BasePromptDriver.max_output_tokens()`.
- **BREAKING**: Moved `BasePromptDriver.prompt_stack_to_string()` to `BaseTokenizer`.
- **BREAKING**: Moved/renamed `PromptStack.add_to_conversation_memory` to `BaseConversationMemory.add_to_prompt_stack`.
- **BREAKING**: Moved `griptape.constants.RESPONSE_STOP_SEQUENCE` to `ToolkitTask`.
- `ToolkitTask.RESPONSE_STOP_SEQUENCE` is now only added when using `ToolkitTask`.
- `BaseTokenizer.count_tokens()` can now approximate token counts given a Prompt Stack.
- Updated Prompt Drivers to use `BasePromptDriver.max_tokens` instead of using `BasePromptDriver.max_output_tokens()`.
- Improved error message when `GriptapeCloudKnowledgeBaseClient` does not have a description set.
- Updated `AmazonBedrockPromptDriver` to use [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).
- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
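The two tokenizer helpers added to `BaseTokenizer` above are easiest to picture with a small, self-contained sketch. Nothing below is griptape's shipped code: `Input` merely stands in for `PromptStack.Input`, and `SketchTokenizer` only mirrors the intent of `BaseTokenizer.prompt_stack_to_string()` and `BaseTokenizer.prompt_stack_input_to_message()`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Input:
    """Stand-in for PromptStack.Input: one conversation turn with a role."""

    role: str
    content: str

    def is_system(self) -> bool:
        return self.role == "system"

    def is_assistant(self) -> bool:
        return self.role == "assistant"

    def is_user(self) -> bool:
        return self.role == "user"


class SketchTokenizer:
    """Illustrative only; shows the shape of the two new BaseTokenizer helpers."""

    def prompt_stack_input_to_message(self, prompt_input: Input) -> dict:
        # ChatML-style message dictionary.
        return {"role": prompt_input.role, "content": prompt_input.content}

    def prompt_stack_to_string(self, inputs: list[Input]) -> str:
        # Flatten the conversation into a single prompt string.
        lines = []
        for i in inputs:
            if i.is_user():
                lines.append(f"User: {i.content}")
            elif i.is_assistant():
                lines.append(f"Assistant: {i.content}")
            else:
                lines.append(i.content)
        lines.append("Assistant:")
        return "\n\n".join(lines)


if __name__ == "__main__":
    stack = [Input("system", "You are terse."), Input("user", "Hi!")]
    tokenizer = SketchTokenizer()
    print(tokenizer.prompt_stack_to_string(stack))
    print([tokenizer.prompt_stack_input_to_message(i) for i in stack])
```

The string form mirrors the `default_prompt_stack_to_string_converter()` that this commit removes from `base_prompt_driver.py` (shown further down).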
40 changes: 24 additions & 16 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -1,16 +1,20 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any

from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.drivers import BasePromptDriver
from typing import TYPE_CHECKING, Any

from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
from griptape.utils import import_optional_dependency
from griptape.tokenizers import SimpleTokenizer, BaseTokenizer

if TYPE_CHECKING:
from griptape.utils import PromptStack
import boto3

from griptape.utils import PromptStack


@define
class AmazonBedrockPromptDriver(BasePromptDriver):
@@ -19,7 +23,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
)
additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True)
tokenizer: BaseTokenizer = field(default=Factory(lambda: SimpleTokenizer(characters_per_token=4)), kw_only=True)
tokenizer: BaseTokenizer = field(default=Factory(lambda: AmazonBedrockTokenizer()), kw_only=True)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
response = self.bedrock_client.converse(**self._base_params(prompt_stack))
@@ -40,12 +44,24 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
else:
raise Exception("model response is empty")

def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
content = [{"text": prompt_input.content}]

if prompt_input.is_system():
return {"text": prompt_input.content}

Check warning on line 51 in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py#L51

Added line #L51 was not covered by tests
elif prompt_input.is_assistant():
return {"role": "assistant", "content": content}

Check warning on line 53 in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py#L53

Added line #L53 was not covered by tests
else:
return {"role": "user", "content": content}

Check warning on line 55 in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py#L55

Added line #L55 was not covered by tests

def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = [
{"text": input.content} for input in prompt_stack.inputs if input.is_system() and input.content
self.tokenizer.prompt_stack_input_to_message(input)
for input in prompt_stack.inputs
if input.is_system() and input.content
]
messages = [
{"role": self.__to_amazon_bedrock_role(input), "content": [{"text": input.content}]}
self.tokenizer.prompt_stack_input_to_message(input)
for input in prompt_stack.inputs
if not input.is_system()
]
@@ -57,11 +73,3 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"inferenceConfig": {"temperature": self.temperature},
"additionalModelRequestFields": self.additional_model_request_fields,
}

def __to_amazon_bedrock_role(self, prompt_input: PromptStack.Input) -> str:
if prompt_input.is_system():
return "system"
elif prompt_input.is_assistant():
return "assistant"
else:
return "user"
18 changes: 5 additions & 13 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -5,7 +5,7 @@
from griptape.artifacts import TextArtifact
from griptape.utils import PromptStack, import_optional_dependency
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AnthropicTokenizer
from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer


@define
@@ -15,7 +15,6 @@ class AnthropicPromptDriver(BasePromptDriver):
api_key: Anthropic API key.
model: Anthropic model name.
client: Custom `Anthropic` client.
tokenizer: Custom `AnthropicTokenizer`.
"""

api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
@@ -26,11 +25,12 @@ class AnthropicPromptDriver(BasePromptDriver):
),
kw_only=True,
)
tokenizer: AnthropicTokenizer = field(
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True
)
top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
response = self.client.messages.create(**self._base_params(prompt_stack))
@@ -46,7 +46,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:

def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
messages = [
{"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content}
self.tokenizer.prompt_stack_input_to_message(prompt_input)
for prompt_input in prompt_stack.inputs
if not prompt_input.is_system()
]
@@ -62,16 +62,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"model": self.model,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
"max_tokens": self.max_output_tokens(self.prompt_stack_to_string(prompt_stack)),
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.max_tokens,
**self._prompt_stack_to_model_input(prompt_stack),
}

def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
if prompt_input.is_system():
return "system"
elif prompt_input.is_assistant():
return "assistant"
else:
return "user"
39 changes: 6 additions & 33 deletions griptape/drivers/prompt/base_prompt_driver.py
@@ -1,6 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Callable
from typing import TYPE_CHECKING, Optional
from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.events import StartPromptEvent, FinishPromptEvent, CompletionChunkEvent
@@ -32,42 +32,30 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
temperature: float = field(default=0.1, kw_only=True, metadata={"serializable": True})
max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None, kw_only=True)
prompt_stack_to_string: Callable[[PromptStack], str] = field(
default=Factory(lambda self: self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True
)
ignored_exception_types: tuple[type[Exception], ...] = field(
default=Factory(lambda: (ImportError, ValueError)), kw_only=True
)
model: str = field(metadata={"serializable": True})
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

def max_output_tokens(self, text: str | list) -> int:
tokens_left = self.tokenizer.count_output_tokens_left(text)

if self.max_tokens:
return min(self.max_tokens, tokens_left)
else:
return tokens_left

def token_count(self, prompt_stack: PromptStack) -> int:
return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

def before_run(self, prompt_stack: PromptStack) -> None:
if self.structure:
self.structure.publish_event(
StartPromptEvent(
model=self.model,
token_count=self.token_count(prompt_stack),
token_count=self.tokenizer.count_tokens(prompt_stack),
prompt_stack=prompt_stack,
prompt=self.prompt_stack_to_string(prompt_stack),
prompt=self.tokenizer.prompt_stack_to_string(prompt_stack),
)
)

def after_run(self, result: TextArtifact) -> None:
if self.structure:
self.structure.publish_event(
FinishPromptEvent(model=self.model, token_count=result.token_count(self.tokenizer), result=result.value)
FinishPromptEvent(
model=self.model, result=result.value, token_count=self.tokenizer.count_tokens(result.value)
)
)

def run(self, prompt_stack: PromptStack) -> TextArtifact:
@@ -92,21 +80,6 @@ def run(self, prompt_stack: PromptStack) -> TextArtifact:
else:
raise Exception("prompt driver failed after all retry attempts")

def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
prompt_lines = []

for i in prompt_stack.inputs:
if i.is_user():
prompt_lines.append(f"User: {i.content}")
elif i.is_assistant():
prompt_lines.append(f"Assistant: {i.content}")
else:
prompt_lines.append(i.content)

prompt_lines.append("Assistant:")

return "\n\n".join(prompt_lines)

@abstractmethod
def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ...

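The removed `max_output_tokens()` clamped the requested completion size to whatever remained of the model's context window, whereas drivers now pass `max_tokens` through unchanged. A rough sketch of the old budgeting, with the tokenizer's accounting reduced to plain integers and a made-up 4,096-token window, clarifies what callers should now set explicitly:

```python
from __future__ import annotations


def max_output_tokens(prompt_tokens: int, context_window: int, max_tokens: int | None) -> int:
    # Illustrative stand-in for tokenizer.count_output_tokens_left(...).
    tokens_left = max(context_window - prompt_tokens, 0)
    if max_tokens is not None:
        # The old helper never asked for more than the window could hold.
        return min(max_tokens, tokens_left)
    return tokens_left


# 4_096 - 3_500 = 596 tokens remain, so the request is clamped from 1_000 to 596.
print(max_output_tokens(prompt_tokens=3_500, context_window=4_096, max_tokens=1_000))
```

After this change, a driver configured with `max_tokens=1000` simply forwards that value to the provider, so any window overflow is handled by the provider rather than pre-computed by the driver.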
32 changes: 14 additions & 18 deletions griptape/drivers/prompt/cohere_prompt_driver.py
@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import CohereTokenizer
from griptape.utils import PromptStack, import_optional_dependency
from griptape.tokenizers import BaseTokenizer, CohereTokenizer

if TYPE_CHECKING:
from cohere import Client
@@ -18,7 +18,6 @@ class CoherePromptDriver(BasePromptDriver):
api_key: Cohere API key.
model: Cohere model name.
client: Custom `cohere.Client`.
tokenizer: Custom `CohereTokenizer`.
"""

api_key: str = field(kw_only=True, metadata={"serializable": False})
@@ -27,7 +26,7 @@ class CoherePromptDriver(BasePromptDriver):
default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
kw_only=True,
)
tokenizer: CohereTokenizer = field(
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
kw_only=True,
)
@@ -44,26 +43,23 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
if event.event_type == "text-generation":
yield TextArtifact(value=event.text)

def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
if prompt_input.is_system():
return {"role": "SYSTEM", "text": prompt_input.content}

Check warning on line 48 in griptape/drivers/prompt/cohere_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/cohere_prompt_driver.py#L48

Added line #L48 was not covered by tests
elif prompt_input.is_user():
return {"role": "USER", "text": prompt_input.content}

Check warning on line 50 in griptape/drivers/prompt/cohere_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/cohere_prompt_driver.py#L50

Added line #L50 was not covered by tests
else:
return {"role": "ASSISTANT", "text": prompt_input.content}

Check warning on line 52 in griptape/drivers/prompt/cohere_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/cohere_prompt_driver.py#L52

Added line #L52 was not covered by tests

def _base_params(self, prompt_stack: PromptStack) -> dict:
user_message = prompt_stack.inputs[-1].content
history_messages = [self.__to_cohere_message(input) for input in prompt_stack.inputs[:-1]]

history_messages = [self.tokenizer.prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1]]

return {
"message": user_message,
"chat_history": history_messages,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
"max_tokens": self.max_tokens,
}

def __to_cohere_message(self, input: PromptStack.Input) -> dict[str, Any]:
return {"role": self.__to_cohere_role(input.role), "text": input.content}

def __to_cohere_role(self, role: str) -> str:
if role == PromptStack.SYSTEM_ROLE:
return "SYSTEM"
if role == PromptStack.USER_ROLE:
return "USER"
elif role == PromptStack.ASSISTANT_ROLE:
return "CHATBOT"
else:
return "USER"
24 changes: 12 additions & 12 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -19,7 +19,6 @@ class GooglePromptDriver(BasePromptDriver):
api_key: Google API key.
model: Google model name.
model_client: Custom `GenerativeModel` client.
tokenizer: Custom `GoogleTokenizer`.
top_p: Optional value for top_p.
top_k: Optional value for top_k.
"""
@@ -42,7 +41,7 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
inputs,
generation_config=GenerationConfig(
stop_sequences=self.tokenizer.stop_sequences,
max_output_tokens=self.max_output_tokens(inputs),
max_output_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
@@ -60,7 +59,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
stream=True,
generation_config=GenerationConfig(
stop_sequences=self.tokenizer.stop_sequences,
max_output_tokens=self.max_output_tokens(inputs),
max_output_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
@@ -70,6 +69,14 @@
for chunk in response:
yield TextArtifact(value=chunk.text)

def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
parts = [prompt_input.content]

if prompt_input.is_assistant():
return {"role": "model", "parts": parts}

Check warning on line 76 in griptape/drivers/prompt/google_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/google_prompt_driver.py#L76

Added line #L76 was not covered by tests
else:
return {"role": "user", "parts": parts}

Check warning on line 78 in griptape/drivers/prompt/google_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/google_prompt_driver.py#L78

Added line #L78 was not covered by tests

def _default_model_client(self) -> GenerativeModel:
genai = import_optional_dependency("google.generativeai")
genai.configure(api_key=self.api_key)
@@ -90,13 +97,6 @@ def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list[Conten

def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict:
ContentDict = import_optional_dependency("google.generativeai.types").ContentDict
message = self.tokenizer.prompt_stack_input_to_message(prompt_input)

return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]})

def __to_google_role(self, prompt_input: PromptStack.Input) -> str:
if prompt_input.is_system():
return "user"
elif prompt_input.is_assistant():
return "model"
else:
return "user"
return ContentDict(message)
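The Gemini-specific piece of this refactor is the role rename: assistant turns are sent as `model`, everything else (including system prompts) as `user`, and content travels in a `parts` list, exactly as the hunk's `prompt_stack_input_to_message()` shows. A minimal, library-free sketch of that mapping, assuming ChatML-style input dictionaries:

```python
def to_gemini_content(message: dict) -> dict:
    # Gemini has no separate system role; only "user" and "model" are used here.
    role = "model" if message["role"] == "assistant" else "user"
    return {"role": role, "parts": [message["content"]]}


if __name__ == "__main__":
    chat = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there"},
    ]
    print([to_gemini_content(m) for m in chat])
```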
15 changes: 4 additions & 11 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@@ -52,27 +52,20 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
prompt = self.__to_prompt(prompt_stack)
prompt = self.tokenizer.prompt_stack_to_string(prompt_stack)

response = self.client.text_generation(
prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), **self.params
prompt, return_full_text=False, max_new_tokens=self.max_tokens, **self.params
)

return TextArtifact(value=response)

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
prompt = self.__to_prompt(prompt_stack)
prompt = self.tokenizer.prompt_stack_to_string(prompt_stack)

response = self.client.text_generation(
prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), stream=True, **self.params
prompt, return_full_text=False, max_new_tokens=self.max_tokens, stream=True, **self.params
)

for token in response:
yield TextArtifact(value=token)

def __to_prompt(self, prompt_stack: PromptStack) -> str:
tokens = self.tokenizer.tokenizer.apply_chat_template(
[{"role": i.role, "content": i.content} for i in prompt_stack.inputs], add_generation_prompt=True
)

return self.tokenizer.tokenizer.decode(tokens)
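The removed `__to_prompt()` rendered the conversation through the model's Hugging Face chat template before calling `text_generation()`; that rendering is what a tokenizer-owned `prompt_stack_to_string()` would now perform. A hedged sketch of the equivalent rendering using `transformers` directly (the model name is only an example, and fetching its tokenizer requires network access):

```python
from transformers import AutoTokenizer

# Any chat-templated model works here; zephyr-7b-beta is just an example.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Summarize attention in one sentence."},
]

# tokenize=False returns the rendered prompt string directly, equivalent to the
# apply_chat_template(...) + decode() pair in the removed helper above.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```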