Skip to content

Commit

Permalink
Rename deltas
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 19, 2024
1 parent e94f503 commit a53762d
Show file tree
Hide file tree
Showing 25 changed files with 81 additions and 95 deletions.
4 changes: 2 additions & 2 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .prompt_stack.contents.base_prompt_stack_content import BasePromptStackContent
from .prompt_stack.contents.base_delta_prompt_stack_content import BaseDeltaPromptStackContent
from .prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent
from .prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent
from .prompt_stack.contents.text_prompt_stack_content import TextPromptStackContent
from .prompt_stack.contents.image_prompt_stack_content import ImagePromptStackContent

Expand All @@ -16,7 +16,7 @@
"BasePromptStackContent",
"DeltaPromptStackMessage",
"PromptStackMessage",
"DeltaTextPromptStackContent",
"TextDeltaPromptStackContent",
"TextPromptStackContent",
"ImagePromptStackContent",
"PromptStack",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from abc import ABC
from typing import Optional

from attrs import define, field

Expand All @@ -11,4 +10,3 @@
@define
class BaseDeltaPromptStackContent(ABC, SerializableMixin):
index: int = field(kw_only=True, default=0, metadata={"serializable": True})
role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@


@define
class DeltaTextPromptStackContent(BaseDeltaPromptStackContent):
class TextDeltaPromptStackContent(BaseDeltaPromptStackContent):
text: str = field(metadata={"serializable": True})
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Sequence

from griptape.artifacts import TextArtifact
from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent, DeltaTextPromptStackContent
from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent, TextDeltaPromptStackContent


@define
Expand All @@ -13,7 +13,7 @@ class TextPromptStackContent(BasePromptStackContent):

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> TextPromptStackContent:
text_deltas = [delta for delta in deltas if isinstance(delta, DeltaTextPromptStackContent)]
text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaPromptStackContent)]

artifact = TextArtifact(value="".join(delta.text for delta in text_deltas))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,40 @@
from __future__ import annotations

from abc import ABC
from typing import Optional, Union
from attrs import Factory, define, field

from attrs import define, field

from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent
from griptape.mixins import SerializableMixin


@define
class BasePromptStackMessage(ABC, SerializableMixin):
@define
class Usage(SerializableMixin):
input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})
output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})

@property
def total_tokens(self) -> float:
return (self.input_tokens or 0) + (self.output_tokens or 0)

def __add__(self, other: BasePromptStackMessage.Usage) -> BasePromptStackMessage.Usage:
return BasePromptStackMessage.Usage(
input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0),
output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0),
)

USER_ROLE = "user"
ASSISTANT_ROLE = "assistant"
SYSTEM_ROLE = "system"

content: list[Union[BasePromptStackContent, BaseDeltaPromptStackContent]] = field(metadata={"serializable": True})
role: str = field(kw_only=True, metadata={"serializable": True})
usage: Usage = field(
kw_only=True, default=Factory(lambda: BasePromptStackMessage.Usage()), metadata={"serializable": True}
)

def is_system(self) -> bool:
return self.role == self.SYSTEM_ROLE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,13 @@

from attrs import define, field

from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent
from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent


from .base_prompt_stack_message import BasePromptStackMessage


@define
class DeltaPromptStackMessage(BasePromptStackMessage):
@define
class DeltaUsage:
input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})
output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})

@property
def total_tokens(self) -> float:
return (self.input_tokens or 0) + (self.output_tokens or 0)

def __add__(self, other: DeltaPromptStackMessage.DeltaUsage) -> DeltaPromptStackMessage.DeltaUsage:
return DeltaPromptStackMessage.DeltaUsage(
input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0),
output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0),
)

role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
delta_content: Optional[DeltaTextPromptStackContent] = field(
kw_only=True, default=None, metadata={"serializable": True}
)
delta_usage: DeltaUsage = field(kw_only=True, default=DeltaUsage(), metadata={"serializable": True})
content: Optional[TextDeltaPromptStackContent] = field(kw_only=True, default=None, metadata={"serializable": True})
17 changes: 2 additions & 15 deletions griptape/common/prompt_stack/messages/prompt_stack_message.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,23 @@
from __future__ import annotations

from typing import Any, Optional
from typing import Any

from attrs import Factory, define, field
from attrs import define, field

from griptape.artifacts import TextArtifact
from griptape.common import BasePromptStackContent, TextPromptStackContent
from griptape.mixins.serializable_mixin import SerializableMixin

from .base_prompt_stack_message import BasePromptStackMessage


@define
class PromptStackMessage(BasePromptStackMessage):
@define
class Usage(SerializableMixin):
input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})
output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})

@property
def total_tokens(self) -> float:
return (self.input_tokens or 0) + (self.output_tokens or 0)

def __init__(self, content: str | list[BasePromptStackContent], **kwargs: Any):
if isinstance(content, str):
content = [TextPromptStackContent(TextArtifact(value=content))]
self.__attrs_init__(content, **kwargs) # pyright: ignore[reportAttributeAccessIssue]

content: list[BasePromptStackContent] = field(metadata={"serializable": True})
usage: Usage = field(
kw_only=True, default=Factory(lambda: PromptStackMessage.Usage()), metadata={"serializable": True}
)

@property
def value(self) -> Any:
Expand Down
6 changes: 3 additions & 3 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
BaseDeltaPromptStackContent,
DeltaPromptStackMessage,
PromptStackMessage,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
BasePromptStackContent,
TextPromptStackContent,
ImagePromptStackContent,
Expand Down Expand Up @@ -56,13 +56,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess
for event in stream:
if "contentBlockDelta" in event:
content_block_delta = event["contentBlockDelta"]
yield DeltaTextPromptStackContent(
yield TextDeltaPromptStackContent(
content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"]
)
elif "metadata" in event:
usage = event["metadata"]["usage"]
yield DeltaPromptStackMessage(
delta_usage=DeltaPromptStackMessage.DeltaUsage(
usage=DeltaPromptStackMessage.Usage(
input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]
)
)
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
BaseDeltaPromptStackContent,
BasePromptStackContent,
DeltaPromptStackMessage,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
ImagePromptStackContent,
PromptStack,
PromptStackMessage,
Expand Down Expand Up @@ -68,11 +68,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess
yield self.__message_content_delta_to_prompt_stack_content_delta(event)
elif event.type == "message_start":
yield DeltaPromptStackMessage(
delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=event.message.usage.input_tokens)
usage=DeltaPromptStackMessage.Usage(input_tokens=event.message.usage.input_tokens)
)
elif event.type == "message_delta":
yield DeltaPromptStackMessage(
delta_usage=DeltaPromptStackMessage.DeltaUsage(output_tokens=event.usage.output_tokens)
usage=DeltaPromptStackMessage.Usage(output_tokens=event.usage.output_tokens)
)

def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]:
Expand Down Expand Up @@ -135,6 +135,6 @@ def __message_content_delta_to_prompt_stack_content_delta(
index = content_delta.index

if content_delta.delta.type == "text_delta":
return DeltaTextPromptStackContent(content_delta.delta.text, index=index)
return TextDeltaPromptStackContent(content_delta.delta.text, index=index)
else:
raise ValueError(f"Unsupported message content delta type : {content_delta.delta.type}")
12 changes: 6 additions & 6 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from griptape.common import (
BaseDeltaPromptStackContent,
DeltaPromptStackMessage,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
PromptStack,
PromptStackMessage,
TextPromptStackContent,
Expand Down Expand Up @@ -119,33 +119,33 @@ def __process_run(self, prompt_stack: PromptStack) -> PromptStackMessage:

def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage:
delta_contents: dict[int, list[BaseDeltaPromptStackContent]] = {}
delta_usage = DeltaPromptStackMessage.DeltaUsage()
usage = DeltaPromptStackMessage.Usage()

deltas = self.try_stream(prompt_stack)

for delta in deltas:
if isinstance(delta, DeltaPromptStackMessage):
delta_usage += delta.delta_usage
usage += delta.usage
elif isinstance(delta, BaseDeltaPromptStackContent):
if delta.index in delta_contents:
delta_contents[delta.index].append(delta)
else:
delta_contents[delta.index] = [delta]

if isinstance(delta, DeltaTextPromptStackContent):
if isinstance(delta, TextDeltaPromptStackContent):
self.structure.publish_event(CompletionChunkEvent(token=delta.text))

content = []
for index, deltas in delta_contents.items():
text_deltas = [delta for delta in deltas if isinstance(delta, DeltaTextPromptStackContent)]
text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaPromptStackContent)]
if text_deltas:
content.append(TextPromptStackContent.from_deltas(text_deltas))

result = PromptStackMessage(
content=content,
role=PromptStackMessage.ASSISTANT_ROLE,
usage=PromptStackMessage.Usage(
input_tokens=delta_usage.input_tokens or 0, output_tokens=delta_usage.output_tokens or 0
input_tokens=usage.input_tokens or 0, output_tokens=usage.output_tokens or 0
),
)

Expand Down
6 changes: 3 additions & 3 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BaseDeltaPromptStackContent,
TextPromptStackContent,
BasePromptStackContent,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
)
from griptape.utils import import_optional_dependency
from griptape.tokenizers import BaseTokenizer
Expand Down Expand Up @@ -56,12 +56,12 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess

for event in result:
if event.event_type == "text-generation":
yield DeltaTextPromptStackContent(event.text, index=0)
yield TextDeltaPromptStackContent(event.text, index=0)
if event.event_type == "stream-end":
usage = event.response.meta.tokens

yield DeltaPromptStackMessage(
delta_usage=DeltaPromptStackMessage.DeltaUsage(
usage=DeltaPromptStackMessage.Usage(
input_tokens=usage.input_tokens, output_tokens=usage.output_tokens
)
)
Expand Down
6 changes: 3 additions & 3 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
BaseDeltaPromptStackContent,
BasePromptStackContent,
DeltaPromptStackMessage,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
ImagePromptStackContent,
PromptStack,
PromptStackMessage,
Expand Down Expand Up @@ -90,11 +90,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess
for chunk in response:
usage_metadata = chunk.usage_metadata

yield DeltaTextPromptStackContent(chunk.text)
yield TextDeltaPromptStackContent(chunk.text)

# TODO: Only yield the first one
yield DeltaPromptStackMessage(
delta_usage=DeltaPromptStackMessage.DeltaUsage(
usage=DeltaPromptStackMessage.Usage(
input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count
)
)
Expand Down
6 changes: 3 additions & 3 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
DeltaPromptStackMessage,
BaseDeltaPromptStackContent,
TextPromptStackContent,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
)
from griptape.utils import import_optional_dependency

Expand Down Expand Up @@ -81,11 +81,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess
full_text = ""
for token in response:
full_text += token
yield DeltaTextPromptStackContent(token, index=0)
yield TextDeltaPromptStackContent(token, index=0)

output_tokens = len(self.tokenizer.tokenizer.encode(full_text))
yield DeltaPromptStackMessage(
delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=input_tokens, output_tokens=output_tokens)
usage=DeltaPromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens)
)

def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PromptStackMessage,
BaseDeltaPromptStackContent,
DeltaPromptStackMessage,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -69,7 +69,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess

if isinstance(stream, Iterator):
for chunk in stream:
yield DeltaTextPromptStackContent(chunk["message"]["content"])
yield TextDeltaPromptStackContent(chunk["message"]["content"])
else:
raise Exception("invalid model response")

Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
BaseDeltaPromptStackContent,
BasePromptStackContent,
DeltaPromptStackMessage,
DeltaTextPromptStackContent,
TextDeltaPromptStackContent,
ImagePromptStackContent,
PromptStack,
PromptStackMessage,
Expand Down Expand Up @@ -98,7 +98,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess
for chunk in result:
if chunk.usage is not None:
yield DeltaPromptStackMessage(
delta_usage=DeltaPromptStackMessage.DeltaUsage(
usage=DeltaPromptStackMessage.Usage(
input_tokens=chunk.usage.prompt_tokens, output_tokens=chunk.usage.completion_tokens
)
)
Expand Down Expand Up @@ -168,8 +168,8 @@ def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> B

def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> BaseDeltaPromptStackContent:
if content_delta.content is None:
return DeltaTextPromptStackContent("")
return TextDeltaPromptStackContent("")
else:
delta_content = content_delta.content

return DeltaTextPromptStackContent(delta_content)
return TextDeltaPromptStackContent(delta_content)
Loading

0 comments on commit a53762d

Please sign in to comment.