Commit: Reformat

andrewfrench committed Nov 8, 2023
1 parent 40fdf35 commit c4ea1c8
Showing 246 changed files with 1,429 additions and 5,270 deletions. Every diff below follows the same pattern: statements previously wrapped across several lines are collapsed onto single lines under a wider maximum line length, with no change in behavior.
13 changes: 3 additions & 10 deletions griptape/artifacts/base_artifact.py
@@ -10,14 +10,9 @@
 @define
 class BaseArtifact(ABC):
     id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
-    name: str = field(
-        default=Factory(lambda self: self.id, takes_self=True), kw_only=True
-    )
+    name: str = field(default=Factory(lambda self: self.id, takes_self=True), kw_only=True)
     value: any = field()
-    type: str = field(
-        default=Factory(lambda self: self.__class__.__name__, takes_self=True),
-        kw_only=True,
-    )
+    type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True)
 
     @classmethod
     def value_to_bytes(cls, value: any) -> bytes:
@@ -54,9 +49,7 @@ def from_dict(cls, artifact_dict: dict) -> BaseArtifact:
         class_registry.register("ListArtifact", ListArtifactSchema)
 
         try:
-            return class_registry.get_class(artifact_dict["type"])().load(
-                artifact_dict
-            )
+            return class_registry.get_class(artifact_dict["type"])().load(artifact_dict)
         except RegistryError:
             raise ValueError("Unsupported artifact type")
 
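The diffs above repeatedly collapse attrs field definitions like these. As a quick illustration of the pattern being reformatted, here is a minimal self-contained sketch (separate from the commit) of how attrs' Factory defaults work, including takes_self=True, which passes the partially built instance to the default callable so one field can default to another:

import uuid

from attr import Factory, define, field


@define
class Record:
    # A fresh hex id is generated per instance.
    id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
    # takes_self=True hands the instance to the lambda, so name defaults to id.
    name: str = field(default=Factory(lambda self: self.id, takes_self=True), kw_only=True)


record = Record()
assert record.name == record.id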
10 changes: 2 additions & 8 deletions griptape/artifacts/blob_artifact.py
@@ -17,16 +17,10 @@ def __add__(self, other: BlobArtifact) -> BlobArtifact:
 
     @property
     def full_path(self) -> str:
-        return (
-            os.path.join(self.dir_name, self.name)
-            if self.dir_name
-            else self.name
-        )
+        return os.path.join(self.dir_name, self.name) if self.dir_name else self.name
 
     def to_text(self) -> str:
-        return self.value.decode(
-            encoding=self.encoding, errors=self.encoding_error_handler
-        )
+        return self.value.decode(encoding=self.encoding, errors=self.encoding_error_handler)
 
     def to_dict(self) -> dict:
         from griptape.schemas import BlobArtifactSchema
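to_text above forwards an encoding and an error handler to bytes.decode. A small standalone example (values assumed, not from the commit) of what that error handler controls:

data = "café".encode("utf-8")

print(data.decode(encoding="utf-8", errors="strict"))        # "café"
# Truncating the final byte makes the data invalid UTF-8; "replace" substitutes
# U+FFFD where "strict" would raise UnicodeDecodeError.
print(data[:-1].decode(encoding="utf-8", errors="replace"))  # "caf\ufffd"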
5 changes: 1 addition & 4 deletions griptape/artifacts/csv_row_artifact.py
@@ -16,10 +16,7 @@ def __add__(self, other: CsvRowArtifact) -> CsvRowArtifact:
     def to_text(self) -> str:
         with io.StringIO() as csvfile:
             writer = csv.DictWriter(
-                csvfile,
-                fieldnames=self.value.keys(),
-                quoting=csv.QUOTE_MINIMAL,
-                delimiter=self.delimiter,
+                csvfile, fieldnames=self.value.keys(), quoting=csv.QUOTE_MINIMAL, delimiter=self.delimiter
             )
 
             writer.writerow(self.value)
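The DictWriter call above serializes a single dict as one CSV row into an in-memory buffer. A runnable sketch of the same pattern with assumed values:

import csv
import io

row = {"name": "ada", "age": 36}

with io.StringIO() as csvfile:
    # fieldnames fixes the column order; QUOTE_MINIMAL only quotes when needed.
    writer = csv.DictWriter(csvfile, fieldnames=row.keys(), quoting=csv.QUOTE_MINIMAL, delimiter=",")
    writer.writerow(row)
    text = csvfile.getvalue().strip()

print(text)  # ada,36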
4 changes: 1 addition & 3 deletions griptape/artifacts/list_artifact.py
@@ -14,9 +14,7 @@ def validate_value(self, _, value: list[BaseArtifact]) -> None:
         first_type = type(value[0])
 
         if not all(isinstance(v, first_type) for v in value):
-            raise ValueError(
-                f"list elements in 'value' are not the same type"
-            )
+            raise ValueError(f"list elements in 'value' are not the same type")
 
     @property
     def child_type(self) -> Optional[type]:
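The validator above enforces that a ListArtifact holds elements of a single type. The check reduces to this small standalone function (a sketch, not the library's API):

def all_same_type(value: list) -> bool:
    # An empty list is trivially homogeneous.
    if not value:
        return True
    first_type = type(value[0])
    return all(isinstance(v, first_type) for v in value)


assert all_same_type([1, 2, 3])
assert not all_same_type([1, "two"])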
8 changes: 1 addition & 7 deletions griptape/chunkers/__init__.py
@@ -5,10 +5,4 @@
 from .markdown_chunker import MarkdownChunker
 
 
-__all__ = [
-    "ChunkSeparator",
-    "BaseChunker",
-    "TextChunker",
-    "PdfChunker",
-    "MarkdownChunker",
-]
+__all__ = ["ChunkSeparator", "BaseChunker", "TextChunker", "PdfChunker", "MarkdownChunker"]
57 changes: 12 additions & 45 deletions griptape/chunkers/base_chunker.py
@@ -12,32 +12,19 @@ class BaseChunker(ABC):
     DEFAULT_SEPARATORS = [ChunkSeparator(" ")]
 
     separators: list[ChunkSeparator] = field(
-        default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True),
-        kw_only=True,
+        default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True), kw_only=True
     )
     tokenizer: BaseTokenizer = field(
-        default=Factory(
-            lambda: OpenAiTokenizer(
-                model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
-            )
-        ),
-        kw_only=True,
-    )
-    max_tokens: int = field(
-        default=Factory(
-            lambda self: self.tokenizer.max_tokens, takes_self=True
-        ),
-        kw_only=True,
+        default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), kw_only=True
     )
+    max_tokens: int = field(default=Factory(lambda self: self.tokenizer.max_tokens, takes_self=True), kw_only=True)
 
     def chunk(self, text: TextArtifact | str) -> list[TextArtifact]:
         text = text.value if isinstance(text, TextArtifact) else text
 
         return [TextArtifact(c) for c in self._chunk_recursively(text)]
 
-    def _chunk_recursively(
-        self, chunk: str, current_separator: Optional[ChunkSeparator] = None
-    ) -> list[str]:
+    def _chunk_recursively(self, chunk: str, current_separator: Optional[ChunkSeparator] = None) -> list[str]:
         token_count = self.tokenizer.count_tokens(chunk)
 
         if token_count <= self.max_tokens:
@@ -50,9 +37,7 @@
 
         # If a separator is provided, only use separators after it.
         if current_separator:
-            separators = self.separators[
-                self.separators.index(current_separator) :
-            ]
+            separators = self.separators[self.separators.index(current_separator) :]
         else:
             separators = self.separators
 
@@ -81,32 +66,16 @@
                 # Create the two subchunks based on the best separator.
                 if separator.is_prefix:
                     # If the separator is a prefix, append it before this subchunk.
-                    first_subchunk = separator.value + separator.value.join(
-                        subchunks[: balance_index + 1]
-                    )
-                    second_subchunk = (
-                        separator.value
-                        + separator.value.join(
-                            subchunks[balance_index + 1 :]
-                        )
-                    )
+                    first_subchunk = separator.value + separator.value.join(subchunks[: balance_index + 1])
+                    second_subchunk = separator.value + separator.value.join(subchunks[balance_index + 1 :])
                 else:
                     # If the separator is not a prefix, append it after this subchunk.
-                    first_subchunk = (
-                        separator.value.join(subchunks[: balance_index + 1])
-                        + separator.value
-                    )
-                    second_subchunk = separator.value.join(
-                        subchunks[balance_index + 1 :]
-                    )
+                    first_subchunk = separator.value.join(subchunks[: balance_index + 1]) + separator.value
+                    second_subchunk = separator.value.join(subchunks[balance_index + 1 :])
 
                 # Continue recursively chunking the subchunks.
-                first_subchunk_rec = self._chunk_recursively(
-                    first_subchunk.strip(), separator
-                )
-                second_subchunk_rec = self._chunk_recursively(
-                    second_subchunk.strip(), separator
-                )
+                first_subchunk_rec = self._chunk_recursively(first_subchunk.strip(), separator)
+                second_subchunk_rec = self._chunk_recursively(second_subchunk.strip(), separator)
 
                 # Return the concatenated results of the subchunks if both are non-empty.
                 if first_subchunk_rec and second_subchunk_rec:
@@ -120,6 +89,4 @@
                     return []
         # If none of the separators result in a balanced split, split the chunk in half.
         midpoint = len(chunk) // 2
-        return self._chunk_recursively(
-            chunk[:midpoint]
-        ) + self._chunk_recursively(chunk[midpoint:])
+        return self._chunk_recursively(chunk[:midpoint]) + self._chunk_recursively(chunk[midpoint:])
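The method reformatted above implements balanced recursive chunking: if a chunk fits the token budget, keep it; otherwise split on the highest-priority separator at the point that most evenly balances the two halves, recurse on both halves, and fall back to splitting at the midpoint. A simplified standalone sketch that uses character counts in place of the tokenizer (an assumption; the real code counts model tokens and also handles prefix separators):

def chunk_recursively(text: str, max_len: int, separators: list[str]) -> list[str]:
    if len(text) <= max_len:
        return [text] if text else []

    midpoint = len(text) // 2
    for sep in separators:
        # Drop empty fragments, mirroring the filter in the real implementation.
        parts = [p for p in text.split(sep) if p]
        if len(parts) < 2:
            continue  # separator absent; try the next one

        # Choose the join point that best balances the two halves around the middle.
        best = min(range(1, len(parts)), key=lambda i: abs(len(sep.join(parts[:i])) - midpoint))
        first = sep.join(parts[:best]) + sep
        second = sep.join(parts[best:])

        left = chunk_recursively(first.strip(), max_len, separators)
        right = chunk_recursively(second.strip(), max_len, separators)
        if left or right:
            return left + right

    # No separator produced a usable split: fall back to halving the text.
    return chunk_recursively(text[:midpoint], max_len, separators) + chunk_recursively(text[midpoint:], max_len, separators)


print(chunk_recursively("one two three four", 10, [" "]))  # ['one two', 'three four']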
48 changes: 12 additions & 36 deletions griptape/drivers/__init__.py
@@ -2,35 +2,23 @@
 from .prompt.openai_chat_prompt_driver import OpenAiChatPromptDriver
 from .prompt.openai_completion_prompt_driver import OpenAiCompletionPromptDriver
 from .prompt.azure_openai_chat_prompt_driver import AzureOpenAiChatPromptDriver
-from .prompt.azure_openai_completion_prompt_driver import (
-    AzureOpenAiCompletionPromptDriver,
-)
+from .prompt.azure_openai_completion_prompt_driver import AzureOpenAiCompletionPromptDriver
 from .prompt.cohere_prompt_driver import CoherePromptDriver
-from .prompt.hugging_face_pipeline_prompt_driver import (
-    HuggingFacePipelinePromptDriver,
-)
+from .prompt.hugging_face_pipeline_prompt_driver import HuggingFacePipelinePromptDriver
 from .prompt.hugging_face_hub_prompt_driver import HuggingFaceHubPromptDriver
 from .prompt.anthropic_prompt_driver import AnthropicPromptDriver
 from .prompt.amazon_sagemaker_prompt_driver import AmazonSageMakerPromptDriver
 from .prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver
 from .prompt.base_multi_model_prompt_driver import BaseMultiModelPromptDriver
 
-from .memory.conversation.base_conversation_memory_driver import (
-    BaseConversationMemoryDriver,
-)
-from .memory.conversation.local_conversation_memory_driver import (
-    LocalConversationMemoryDriver,
-)
-from .memory.conversation.amazon_dynamodb_conversation_memory_driver import (
-    AmazonDynamoDbConversationMemoryDriver,
-)
+from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
+from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
+from .memory.conversation.amazon_dynamodb_conversation_memory_driver import AmazonDynamoDbConversationMemoryDriver
 
 from .embedding.base_embedding_driver import BaseEmbeddingDriver
 from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver
 from .embedding.azure_openai_embedding_driver import AzureOpenAiEmbeddingDriver
-from .embedding.bedrock_titan_embedding_driver import (
-    BedrockTitanEmbeddingDriver,
-)
+from .embedding.bedrock_titan_embedding_driver import BedrockTitanEmbeddingDriver
 
 from .vector.base_vector_store_driver import BaseVectorStoreDriver
 from .vector.local_vector_store_driver import LocalVectorStoreDriver
@@ -39,9 +27,7 @@
 from .vector.mongodb_vector_store_driver import MongoDbAtlasVectorStoreDriver
 from .vector.redis_vector_store_driver import RedisVectorStoreDriver
 from .vector.opensearch_vector_store_driver import OpenSearchVectorStoreDriver
-from .vector.amazon_opensearch_vector_store_driver import (
-    AmazonOpenSearchVectorStoreDriver,
-)
+from .vector.amazon_opensearch_vector_store_driver import AmazonOpenSearchVectorStoreDriver
 from .vector.pgvector_vector_store_driver import PgVectorVectorStoreDriver
 
 from .sql.base_sql_driver import BaseSqlDriver
@@ -50,21 +36,11 @@
 from .sql.sql_driver import SqlDriver
 
 from .prompt_model.base_prompt_model_driver import BasePromptModelDriver
-from .prompt_model.sagemaker_llama_prompt_model_driver import (
-    SageMakerLlamaPromptModelDriver,
-)
-from .prompt_model.sagemaker_falcon_prompt_model_driver import (
-    SageMakerFalconPromptModelDriver,
-)
-from .prompt_model.bedrock_titan_prompt_model_driver import (
-    BedrockTitanPromptModelDriver,
-)
-from .prompt_model.bedrock_claude_prompt_model_driver import (
-    BedrockClaudePromptModelDriver,
-)
-from .prompt_model.bedrock_jurassic_prompt_model_driver import (
-    BedrockJurassicPromptModelDriver,
-)
+from .prompt_model.sagemaker_llama_prompt_model_driver import SageMakerLlamaPromptModelDriver
+from .prompt_model.sagemaker_falcon_prompt_model_driver import SageMakerFalconPromptModelDriver
+from .prompt_model.bedrock_titan_prompt_model_driver import BedrockTitanPromptModelDriver
+from .prompt_model.bedrock_claude_prompt_model_driver import BedrockClaudePromptModelDriver
+from .prompt_model.bedrock_jurassic_prompt_model_driver import BedrockJurassicPromptModelDriver
 
 
 __all__ = [
5 changes: 1 addition & 4 deletions griptape/drivers/embedding/azure_openai_embedding_driver.py
@@ -23,10 +23,7 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
     api_type: str = field(default="azure", kw_only=True)
     api_version: str = field(default="2023-05-15", kw_only=True)
     tokenizer: OpenAiTokenizer = field(
-        default=Factory(
-            lambda self: OpenAiTokenizer(model=self.model), takes_self=True
-        ),
-        kw_only=True,
+        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
     )
 
     def _params(self, chunk: list[int] | str) -> dict:
9 changes: 2 additions & 7 deletions griptape/drivers/embedding/base_embedding_driver.py
@@ -28,10 +28,7 @@ def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
     def embed_string(self, string: str) -> list[float]:
         for attempt in self.retrying():
             with attempt:
-                if (
-                    self.tokenizer.count_tokens(string)
-                    > self.tokenizer.max_tokens
-                ):
+                if self.tokenizer.count_tokens(string) > self.tokenizer.max_tokens:
                     return self._embed_long_string(string)
                 else:
                     return self.try_embed_chunk(string)
@@ -57,9 +54,7 @@ def _embed_long_string(self, string: str) -> list[float]:
             length_chunks.append(len(chunk))
 
         # generate weighted averages
-        embedding_chunks = np.average(
-            embedding_chunks, axis=0, weights=length_chunks
-        )
+        embedding_chunks = np.average(embedding_chunks, axis=0, weights=length_chunks)
 
         # normalize length to 1
         embedding_chunks = embedding_chunks / np.linalg.norm(embedding_chunks)
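The _embed_long_string logic above chunks an over-long string, embeds each chunk, then combines the chunk embeddings into a single vector: a weighted average (weights = chunk lengths) normalized to unit length. A small numeric sketch with assumed values:

import numpy as np

embedding_chunks = [[1.0, 0.0], [0.0, 1.0]]  # embeddings of two chunks
length_chunks = [3, 1]                       # the first chunk is three times longer

# Longer chunks contribute proportionally more to the combined embedding.
embedding = np.average(embedding_chunks, axis=0, weights=length_chunks)  # [0.75, 0.25]

# Normalize length to 1, matching the norm step above.
embedding = embedding / np.linalg.norm(embedding)

print(embedding)  # ~[0.949, 0.316]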
22 changes: 4 additions & 18 deletions griptape/drivers/embedding/bedrock_titan_embedding_driver.py
@@ -26,34 +26,20 @@ class BedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
 
     model: str = field(default=DEFAULT_MODEL, kw_only=True)
     dimensions: int = field(default=DEFAULT_MAX_TOKENS, kw_only=True)
-    session: boto3.Session = field(
-        default=Factory(lambda: import_optional_dependency("boto3").Session()),
-        kw_only=True,
-    )
+    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
     tokenizer: BedrockTitanTokenizer = field(
-        default=Factory(
-            lambda self: BedrockTitanTokenizer(
-                model=self.model, session=self.session
-            ),
-            takes_self=True,
-        ),
+        default=Factory(lambda self: BedrockTitanTokenizer(model=self.model, session=self.session), takes_self=True),
         kw_only=True,
     )
     bedrock_client: Any = field(
-        default=Factory(
-            lambda self: self.session.client("bedrock-runtime"), takes_self=True
-        ),
-        kw_only=True,
+        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
     )
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         payload = {"inputText": chunk}
 
         response = self.bedrock_client.invoke_model(
-            body=json.dumps(payload),
-            modelId=self.model,
-            accept="application/json",
-            contentType="application/json",
+            body=json.dumps(payload), modelId=self.model, accept="application/json", contentType="application/json"
         )
         response_body = json.loads(response.get("body").read())
 
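A standalone sketch of the invoke_model call reformatted above. The model id and the "embedding" response key are assumptions based on the Titan embeddings response format; the method's actual return statement is elided from this diff.

import json

import boto3

client = boto3.Session().client("bedrock-runtime")

payload = {"inputText": "Hello, world"}
response = client.invoke_model(
    body=json.dumps(payload), modelId="amazon.titan-embed-text-v1", accept="application/json", contentType="application/json"
)
response_body = json.loads(response.get("body").read())
embedding = response_body.get("embedding")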
17 changes: 4 additions & 13 deletions griptape/drivers/embedding/openai_embedding_driver.py
@@ -29,17 +29,10 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
     api_type: str = field(default=openai.api_type, kw_only=True)
     api_version: Optional[str] = field(default=openai.api_version, kw_only=True)
     api_base: str = field(default=openai.api_base, kw_only=True)
-    api_key: Optional[str] = field(
-        default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True
-    )
-    organization: Optional[str] = field(
-        default=openai.organization, kw_only=True
-    )
+    api_key: Optional[str] = field(default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True)
+    organization: Optional[str] = field(default=openai.organization, kw_only=True)
     tokenizer: OpenAiTokenizer = field(
-        default=Factory(
-            lambda self: OpenAiTokenizer(model=self.model), takes_self=True
-        ),
-        kw_only=True,
+        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
     )
 
     def __attrs_post_init__(self) -> None:
@@ -54,9 +47,7 @@ def try_embed_chunk(self, chunk: str) -> list[float]:
         # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
         if self.model.endswith("001"):
             chunk = chunk.replace("\n", " ")
-        return openai.Embedding.create(**self._params(chunk))["data"][0][
-            "embedding"
-        ]
+        return openai.Embedding.create(**self._params(chunk))["data"][0]["embedding"]
 
     def _params(self, chunk: str) -> dict:
         return {
griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@@ -11,10 +11,7 @@
 
 @define
 class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
-    session: boto3.Session = field(
-        default=Factory(lambda: import_optional_dependency("boto3").Session()),
-        kw_only=True,
-    )
+    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
     table_name: str = field(kw_only=True)
     partition_key: str = field(kw_only=True)
     value_attribute_key: str = field(kw_only=True)
@@ -36,9 +33,7 @@ def store(self, memory: ConversationMemory) -> None:
         )
 
     def load(self) -> Optional[ConversationMemory]:
-        response = self.table.get_item(
-            Key={self.partition_key: self.partition_key_value}
-        )
+        response = self.table.get_item(Key={self.partition_key: self.partition_key_value})
 
         if "Item" in response and self.value_attribute_key in response["Item"]:
             memory_value = response["Item"][self.value_attribute_key]
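A hedged usage sketch of the load() path above, with assumed table and attribute names. DynamoDB's get_item returns a dict that contains "Item" only when the key exists, which is why load() checks before reading the stored value.

import boto3

table = boto3.Session().resource("dynamodb").Table("griptape_memory")

response = table.get_item(Key={"id": "session-1"})
if "Item" in response and "memory" in response["Item"]:
    memory_value = response["Item"]["memory"]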
(Diffs for the remaining changed files are not shown.)
