Skip to content

Commit

Permalink
Merge branch 'dev' into cjkindel/gtc_vector_store_driver
Browse files Browse the repository at this point in the history
  • Loading branch information
cjkindel authored Jul 8, 2024
2 parents 80e0aa2 + bb158ea commit fa89918
Show file tree
Hide file tree
Showing 42 changed files with 454 additions and 128 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- RAG modules:
- Retrieval:
- `VectorStoreRetrievalRagModule` for retrieving text chunks from vector stores.
- `TextLoaderRetrievalRagModule` for retrieving data with text loaders in real time.
- `TextChunksRerankRagModule` for re-ranking retrieved results.
- Response:
- `MetadataBeforeResponseRagModule` for appending metadata.
- `RulesetsBeforeResponseRagModule` for appending rulesets.
- `PromptResponseRagModule` for generating responses based on retrieved text chunks.
- `TextChunksResponseRagModule` for responding with retrieved text chunks.
- `FootnotePromptResponseRagModule` for responding with automatic footnotes from text chunk references.
- `RagClient` tool for exposing `RagEngines` to LLM agents.
- `RagTask` task for including `RagEngines` in any structure.
- Rerank drivers:
Expand Down Expand Up @@ -52,6 +54,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Agent.input` for passing Artifacts as input.
- Support for `PromptTask`s to take `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s as input.
- Parameters `sort_key` and `sort_key_value` on `AmazonDynamoDbConversationMemoryDriver` for tables with sort keys.
- `Reference` for supporting artifact citations in loaders and RAG engine modules.
- `GriptapeCloudKnowledgeBaseVectorStoreDriver` to query Griptape Cloud Knowledge Bases.

### Changed
Expand Down
4 changes: 3 additions & 1 deletion docs/griptape-framework/engines/rag-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ No modules implemented yet.

#### Retrieval
- `VectorStoreRetrievalRagModule` is for retrieving text chunks from vector stores.
- `TextLoaderRetrievalRagModule` is for retrieving data with text loaders in real time.
- `TextChunksRerankRagModule` is for re-ranking retrieved results.

#### Response
- `MetadataBeforeResponseRagModule` is for appending metadata.
- `RulesetsBeforeResponseRagModule` is for appending rulesets.
- `PromptResponseRagModule` is for generating responses based on retrieved text chunks.
- `TextChunksResponseRagModule` for responding with retrieved text chunks.
- `TextChunksResponseRagModule` is for responding with retrieved text chunks.
- `FootnotePromptResponseRagModule` is for responding with automatic footnotes from text chunk references.

### Example

Expand Down
9 changes: 7 additions & 2 deletions griptape/artifacts/base_artifact.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations
from griptape.mixins import SerializableMixin
from typing import Any
from typing import Any, TYPE_CHECKING, Optional
import json
import uuid
from abc import ABC, abstractmethod
from attrs import define, field, Factory

if TYPE_CHECKING:
from griptape.common import Reference

@define()

@define
class BaseArtifact(SerializableMixin, ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
reference: Optional[Reference] = field(default=None, kw_only=True, metadata={"serializable": True})
meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
name: str = field(
default=Factory(lambda self: self.id, takes_self=True), kw_only=True, metadata={"serializable": True}
)
Expand Down
3 changes: 1 addition & 2 deletions griptape/artifacts/boolean_artifact.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations
from typing import Any, Union
from typing import Union
from attrs import define, field
from griptape.artifacts import BaseArtifact


@define
class BooleanArtifact(BaseArtifact):
value: bool = field(converter=bool, metadata={"serializable": True})
meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})

@classmethod
def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact:
Expand Down
3 changes: 1 addition & 2 deletions griptape/artifacts/text_artifact.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
from attrs import define, field
from griptape.artifacts import BaseArtifact

Expand All @@ -13,7 +13,6 @@ class TextArtifact(BaseArtifact):
value: str = field(converter=str, metadata={"serializable": True})
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)
meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
_embedding: list[float] = field(factory=list, kw_only=True)

@property
Expand Down
3 changes: 3 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from .prompt_stack.prompt_stack import PromptStack

from .reference import Reference

__all__ = [
"BaseMessage",
"BaseDeltaMessageContent",
Expand All @@ -20,4 +22,5 @@
"TextMessageContent",
"ImageMessageContent",
"PromptStack",
"Reference",
]
9 changes: 4 additions & 5 deletions griptape/common/prompt_stack/contents/base_message_content.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Sequence

from typing import TYPE_CHECKING
from attrs import define, field

from griptape.artifacts.base_artifact import BaseArtifact
from griptape.mixins import SerializableMixin

from .base_delta_message_content import BaseDeltaMessageContent

if TYPE_CHECKING:
from griptape.artifacts.base_artifact import BaseArtifact


@define
class BaseMessageContent(ABC, SerializableMixin):
Expand Down
14 changes: 14 additions & 0 deletions griptape/common/reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import uuid
from typing import Optional
from attrs import define, field, Factory
from griptape.mixins import SerializableMixin


@define(kw_only=True)
class Reference(SerializableMixin):
    """Serializable citation record for an artifact's provenance.

    All fields are keyword-only. `id` is auto-generated; `title` is the
    only required field. All fields are marked serializable.
    """

    # Unique identifier: a fresh hex UUID per instance.
    id: str = field(factory=lambda: uuid.uuid4().hex, metadata={"serializable": True})
    # Human-readable title of the referenced work (required).
    title: str = field(metadata={"serializable": True})
    # Author names; defaults to an empty list.
    authors: list[str] = field(factory=list, metadata={"serializable": True})
    # Free-form source description (e.g. publication or dataset name).
    source: Optional[str] = field(default=None, metadata={"serializable": True})
    # Publication year, kept as a string.
    year: Optional[str] = field(default=None, metadata={"serializable": True})
    # Link to the referenced material.
    url: Optional[str] = field(default=None, metadata={"serializable": True})
5 changes: 3 additions & 2 deletions griptape/drivers/rerank/cohere_rerank_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ class CohereRerankDriver(BaseRerankDriver):
)

def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
artifacts_dict = {str(hash(a.value)): a for a in artifacts}
response = self.client.rerank(
model=self.model,
query=query,
documents=[a.value for a in artifacts],
documents=[a.value for a in artifacts_dict.values()],
return_documents=True,
top_n=self.top_n,
)

return [TextArtifact(r.document.text) for r in response.results]
return [artifacts_dict[str(hash(r.document.text))] for r in response.results]
26 changes: 17 additions & 9 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,21 @@ def to_artifact(self) -> BaseArtifact:
)

def upsert_text_artifacts(
self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs
self, artifacts: list[TextArtifact] | dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs
) -> None:
with self.futures_executor_fn() as executor:
utils.execute_futures_dict(
{
namespace: executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
}
)
if isinstance(artifacts, list):
utils.execute_futures_list(
[executor.submit(self.upsert_text_artifact, a, None, meta, **kwargs) for a in artifacts]
)
else:
utils.execute_futures_dict(
{
namespace: executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
}
)

def upsert_text_artifact(
self,
Expand All @@ -57,7 +62,10 @@ def upsert_text_artifact(
**kwargs,
) -> str:
meta = {} if meta is None else meta
vector_id = self._get_default_vector_id(artifact.to_text()) if vector_id is None else vector_id

if vector_id is None:
value = artifact.to_text() if artifact.reference is None else artifact.to_text() + str(artifact.reference)
vector_id = self._get_default_vector_id(value)

if self.does_entry_exist(vector_id, namespace):
return vector_id
Expand Down
4 changes: 4 additions & 0 deletions griptape/engines/rag/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from .retrieval.base_rerank_rag_module import BaseRerankRagModule
from .retrieval.text_chunks_rerank_rag_module import TextChunksRerankRagModule
from .retrieval.vector_store_retrieval_rag_module import VectorStoreRetrievalRagModule
from .retrieval.text_loader_retrieval_rag_module import TextLoaderRetrievalRagModule
from .response.base_before_response_rag_module import BaseBeforeResponseRagModule
from .response.base_after_response_rag_module import BaseAfterResponseRagModule
from .response.base_response_rag_module import BaseResponseRagModule
from .response.prompt_response_rag_module import PromptResponseRagModule
from .response.rulesets_before_response_rag_module import RulesetsBeforeResponseRagModule
from .response.metadata_before_response_rag_module import MetadataBeforeResponseRagModule
from .response.text_chunks_response_rag_module import TextChunksResponseRagModule
from .response.footnote_prompt_response_rag_module import FootnotePromptResponseRagModule

__all__ = [
"BaseRagModule",
Expand All @@ -19,11 +21,13 @@
"BaseRerankRagModule",
"TextChunksRerankRagModule",
"VectorStoreRetrievalRagModule",
"TextLoaderRetrievalRagModule",
"BaseBeforeResponseRagModule",
"BaseAfterResponseRagModule",
"BaseResponseRagModule",
"PromptResponseRagModule",
"RulesetsBeforeResponseRagModule",
"MetadataBeforeResponseRagModule",
"TextChunksResponseRagModule",
"FootnotePromptResponseRagModule",
]
8 changes: 4 additions & 4 deletions griptape/engines/rag/modules/base_rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def generate_query_prompt_stack(self, system_prompt: str, query: str) -> PromptS
)

def get_context_param(self, context: RagContext, key: str) -> Optional[Any]:
return context.module_params.get(self.name, {}).get(key)
return context.module_configs.get(self.name, {}).get(key)

def set_context_param(self, context: RagContext, key: str, value: Any) -> None:
if not isinstance(context.module_params.get(self.name), dict):
context.module_params[self.name] = {}
if not isinstance(context.module_configs.get(self.name), dict):
context.module_configs[self.name] = {}

context.module_params[self.name][key] = value
context.module_configs[self.name][key] = value
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from attrs import define

from griptape import utils
from griptape.artifacts import TextArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import PromptResponseRagModule
from griptape.utils import J2


@define(kw_only=True)
class FootnotePromptResponseRagModule(PromptResponseRagModule):
    """Prompt response RAG module whose system prompt is rendered from a
    footnote-aware template, passing along references extracted from the
    retrieved text chunk artifacts."""

    def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
        """Render the footnote system template for the given context and chunks."""
        template = J2("engines/rag/modules/response/footnote_prompt/system.j2")

        # Supply the chunks themselves, their references, and any prompts
        # queued to run before/after the query.
        return template.render(
            text_chunk_artifacts=artifacts,
            references=utils.references_from_artifacts(artifacts),
            before_system_prompt="\n\n".join(context.before_query),
            after_system_prompt="\n\n".join(context.after_query),
        )
53 changes: 23 additions & 30 deletions griptape/engines/rag/modules/response/prompt_response_rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,50 +11,43 @@
class PromptResponseRagModule(BaseResponseRagModule):
answer_token_offset: int = field(default=400)
prompt_driver: BasePromptDriver = field()
generate_system_template: Callable[[list[str], list[str], list[str]], str] = field(
generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
default=Factory(lambda self: self.default_system_template_generator, takes_self=True)
)

def run(self, context: RagContext) -> RagContext:
query = context.query
before_query = context.before_query
after_query = context.after_query
text_artifact_chunks = context.text_chunks
tokenizer = self.prompt_driver.tokenizer
included_chunks = []
system_prompt = self.generate_system_template(context, included_chunks)

if query:
tokenizer = self.prompt_driver.tokenizer
text_chunks = []
system_prompt = self.generate_system_template(text_chunks, before_query, after_query)
for artifact in context.text_chunks:
included_chunks.append(artifact)

for artifact in text_artifact_chunks:
text_chunks.append(artifact.value)
system_prompt = self.generate_system_template(context, included_chunks)
message_token_count = self.prompt_driver.tokenizer.count_tokens(
self.prompt_driver.prompt_stack_to_string(self.generate_query_prompt_stack(system_prompt, query))
)

system_prompt = self.generate_system_template(text_chunks, before_query, after_query)
message_token_count = self.prompt_driver.tokenizer.count_tokens(
self.prompt_driver.prompt_stack_to_string(self.generate_query_prompt_stack(system_prompt, query))
)
if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
included_chunks.pop()

if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
text_chunks.pop()
system_prompt = self.generate_system_template(context, included_chunks)

system_prompt = self.generate_system_template(text_chunks, before_query, after_query)
break

break
output = self.prompt_driver.run(self.generate_query_prompt_stack(system_prompt, query)).to_artifact()

output = self.prompt_driver.run(self.generate_query_prompt_stack(system_prompt, query)).to_artifact()

if isinstance(output, TextArtifact):
context.output = output
else:
raise ValueError("Prompt driver did not return a TextArtifact")
if isinstance(output, TextArtifact):
context.output = output
else:
raise ValueError("Prompt driver did not return a TextArtifact")

return context

def default_system_template_generator(
self, text_chunks: list[str], before_system_prompt: list, after_system_prompt: list
) -> str:
def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
return J2("engines/rag/modules/response/prompt/system.j2").render(
text_chunks=text_chunks,
before_system_prompt="\n\n".join(before_system_prompt),
after_system_prompt="\n\n".join(after_system_prompt),
text_chunks=[c.to_text() for c in artifacts],
before_system_prompt="\n\n".join(context.before_query),
after_system_prompt="\n\n".join(context.after_query),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations
import uuid
from typing import TYPE_CHECKING, Sequence, Any, Callable
from attrs import define, field, Factory
from griptape import utils
from griptape.artifacts import TextArtifact, ErrorArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseRetrievalRagModule

if TYPE_CHECKING:
from griptape.drivers import BaseVectorStoreDriver
from griptape.loaders import BaseTextLoader


@define(kw_only=True)
class TextLoaderRetrievalRagModule(BaseRetrievalRagModule):
    """Retrieval RAG module that loads a source with a text loader at query
    time, upserts the result into a vector store under a one-off namespace,
    and queries that namespace."""

    # Loader used to fetch and chunk the source data.
    loader: BaseTextLoader = field()
    # Vector store used for the per-run index and query.
    vector_store_driver: BaseVectorStoreDriver = field()
    # Default source to load; may be overridden per run via context params.
    source: Any = field()
    # Default query parameters; merged with per-run overrides from context.
    query_params: dict[str, Any] = field(factory=dict)
    # Converts raw query entries into TextArtifacts.
    process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda es: [e.to_artifact() for e in es])
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        """Load the source, index it, and return processed query results."""
        # Fresh namespace per run so the query only sees data loaded right now.
        # NOTE(review): nothing removes the namespace afterwards — entries
        # appear to accumulate in the vector store across runs; confirm intended.
        namespace = uuid.uuid4().hex

        # A per-run source supplied via context overrides the configured one.
        override_source = self.get_context_param(context, "source")
        effective_source = override_source if override_source is not None else self.source

        # Merge configured params with per-run overrides, then pin the namespace.
        merged_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))
        merged_params["namespace"] = namespace

        result = self.loader.load(effective_source)

        if isinstance(result, ErrorArtifact):
            # Surface loader failure; prefer the underlying exception when present.
            raise Exception(result.to_text() if result.exception is None else result.exception)

        self.vector_store_driver.upsert_text_artifacts({namespace: result})

        return self.process_query_output_fn(self.vector_store_driver.query(context.query, **merged_params))
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Any, Callable
from attrs import define, field, Factory
from griptape import utils
from griptape.artifacts import TextArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseRetrievalRagModule
Expand All @@ -18,7 +19,6 @@ class VectorStoreRetrievalRagModule(BaseRetrievalRagModule):
)

def run(self, context: RagContext) -> Sequence[TextArtifact]:
context_query_params = self.get_context_param(context, "query_params")
query_params = self.query_params if context_query_params is None else context_query_params
query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))

return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params))
Loading

0 comments on commit fa89918

Please sign in to comment.