Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ResponseRagStage and PromptResponseRagModule updates #1056

Merged
merged 15 commits into from
Aug 12, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Method `try_find_task` to `Structure`.
- `TranslateQueryRagModule` `RagEngine` module for translating input queries.
- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events.
- Unique name generation for all `RagEngine` modules.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
- **BREAKING**: Removed `EventPublisherMixin`.
- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly.
- **BREAKING**: Removed before and after response modules from `ResponseRagStage`.
- **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`.
- `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible.

## [0.29.0] - 2024-07-30
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/src/query_webpage_astra_db_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/src/talk_to_a_pdf_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
)
vector_store_tool = RagClient(
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/src/talk_to_a_webpage_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
)

Expand Down
23 changes: 17 additions & 6 deletions docs/griptape-framework/engines/src/rag_engines_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule
from griptape.engines.rag.stages import QueryRagStage, ResponseRagStage, RetrievalRagStage
from griptape.loaders import WebLoader
from griptape.rules import Rule, Ruleset

prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0)

vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())

artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai")

if isinstance(artifacts, ErrorArtifact):
raise ValueError(artifacts.value)
raise Exception(artifacts.value)

vector_store.upsert_text_artifacts({"griptape": artifacts})
vector_store.upsert_text_artifacts(
{
"griptape": artifacts,
}
)

rag_engine = RagEngine(
query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="English")]),
query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]),
retrieval_stage=RetrievalRagStage(
max_chunks=5,
retrieval_modules=[
Expand All @@ -25,12 +30,18 @@
)
],
),
response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)),
response_stage=ResponseRagStage(
response_modules=[
PromptResponseRagModule(
prompt_driver=prompt_driver, rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])]
)
]
),
)

rag_context = RagContext(
query="¿Qué ofrecen los servicios en la nube de Griptape?",
module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}},
)

print(rag_engine.process(rag_context).output.to_text())
print(rag_engine.process(rag_context).outputs[0].to_text())
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/task_memory_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
),
retrieval_rag_module_name="VectorStoreRetrievalRagModule",
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/tasks_9.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
),
)
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-tools/official-tools/src/rag_client_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
),
)
Expand Down
4 changes: 0 additions & 4 deletions griptape/engines/rag/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from .response.base_after_response_rag_module import BaseAfterResponseRagModule
from .response.base_response_rag_module import BaseResponseRagModule
from .response.prompt_response_rag_module import PromptResponseRagModule
from .response.rulesets_before_response_rag_module import RulesetsBeforeResponseRagModule
from .response.metadata_before_response_rag_module import MetadataBeforeResponseRagModule
from .response.text_chunks_response_rag_module import TextChunksResponseRagModule
from .response.footnote_prompt_response_rag_module import FootnotePromptResponseRagModule

Expand All @@ -28,8 +26,6 @@
"BaseAfterResponseRagModule",
"BaseResponseRagModule",
"PromptResponseRagModule",
"RulesetsBeforeResponseRagModule",
"MetadataBeforeResponseRagModule",
"TextChunksResponseRagModule",
"FootnotePromptResponseRagModule",
]
5 changes: 4 additions & 1 deletion griptape/engines/rag/modules/base_rag_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import uuid
from abc import ABC
from concurrent import futures
from typing import TYPE_CHECKING, Any, Callable, Optional
Expand All @@ -14,7 +15,9 @@

@define(kw_only=True)
class BaseRagModule(ABC):
name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True)
name: str = field(
default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True
)
Comment on lines +18 to +20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does name uniqueness get us?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much easier to add multiple modules of the same type without having to explicitly define names. Maybe that's something we should add to tools as well?

futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from attrs import define

from griptape.artifacts import BaseArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseRagModule


@define(kw_only=True)
class BaseResponseRagModule(BaseRagModule, ABC):
@abstractmethod
def run(self, context: RagContext) -> RagContext: ...
def run(self, context: RagContext) -> BaseArtifact: ...

This file was deleted.

29 changes: 17 additions & 12 deletions griptape/engines/rag/modules/response/prompt_response_rag_module.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional

from attrs import Factory, define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.engines.rag.modules import BaseResponseRagModule
from griptape.mixins import RuleMixin
from griptape.utils import J2

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
from griptape.drivers import BasePromptDriver
from griptape.engines.rag import RagContext


@define(kw_only=True)
class PromptResponseRagModule(BaseResponseRagModule):
answer_token_offset: int = field(default=400)
class PromptResponseRagModule(BaseResponseRagModule, RuleMixin):
prompt_driver: BasePromptDriver = field()
answer_token_offset: int = field(default=400)
metadata: Optional[str] = field(default=None)
generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
default=Factory(lambda self: self.default_system_template_generator, takes_self=True),
)

def run(self, context: RagContext) -> RagContext:
def run(self, context: RagContext) -> BaseArtifact:
query = context.query
tokenizer = self.prompt_driver.tokenizer
included_chunks = []
Expand All @@ -45,15 +48,17 @@ def run(self, context: RagContext) -> RagContext:
output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()

if isinstance(output, TextArtifact):
context.output = output
return output
else:
raise ValueError("Prompt driver did not return a TextArtifact")

return context

def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
return J2("engines/rag/modules/response/prompt/system.j2").render(
text_chunks=[c.to_text() for c in artifacts],
before_system_prompt="\n\n".join(context.before_query),
after_system_prompt="\n\n".join(context.after_query),
)
params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

if len(self.all_rulesets) > 0:
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)

if self.metadata is not None:
params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)

return J2("engines/rag/modules/response/prompt/system.j2").render(**params)

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from attrs import define

from griptape.artifacts import ListArtifact
from griptape.artifacts import BaseArtifact, ListArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseResponseRagModule


@define(kw_only=True)
class TextChunksResponseRagModule(BaseResponseRagModule):
def run(self, context: RagContext) -> RagContext:
context.output = ListArtifact(context.text_chunks)

return context
def run(self, context: RagContext) -> BaseArtifact:
return ListArtifact(context.text_chunks)
6 changes: 3 additions & 3 deletions griptape/engines/rag/rag_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define, field

Expand All @@ -22,15 +22,15 @@ class RagContext(SerializableMixin):
before_query: An optional list of strings to add before the query in response modules.
after_query: An optional list of strings to add after the query in response modules.
text_chunks: A list of text chunks to pass around from the retrieval stage to the response stage.
output: Final output from the response stage.
outputs: List of outputs from the response stage.
"""

query: str = field(metadata={"serializable": True})
module_configs: dict[str, dict] = field(factory=dict, metadata={"serializable": True})
before_query: list[str] = field(factory=list, metadata={"serializable": True})
after_query: list[str] = field(factory=list, metadata={"serializable": True})
text_chunks: list[TextArtifact] = field(factory=list, metadata={"serializable": True})
output: Optional[BaseArtifact] = field(default=None, metadata={"serializable": True})
outputs: list[BaseArtifact] = field(factory=list, metadata={"serializable": True})

def get_references(self) -> list[Reference]:
return utils.references_from_artifacts(self.text_chunks)
2 changes: 1 addition & 1 deletion griptape/engines/rag/stages/query_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def modules(self) -> Sequence[BaseRagModule]:
return self.query_modules

def run(self, context: RagContext) -> RagContext:
logging.info("QueryStage: running %s query generation modules sequentially", len(self.query_modules))
logging.info("QueryRagStage: running %s query generation modules sequentially", len(self.query_modules))

[qm.run(context) for qm in self.query_modules]

Expand Down
28 changes: 7 additions & 21 deletions griptape/engines/rag/stages/response_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,35 @@

from attrs import define, field

from griptape import utils
from griptape.engines.rag.stages import BaseRagStage

if TYPE_CHECKING:
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import (
BaseAfterResponseRagModule,
BaseBeforeResponseRagModule,
BaseRagModule,
BaseResponseRagModule,
)


@define(kw_only=True)
class ResponseRagStage(BaseRagStage):
before_response_modules: list[BaseBeforeResponseRagModule] = field(factory=list)
response_module: BaseResponseRagModule = field()
after_response_modules: list[BaseAfterResponseRagModule] = field(factory=list)
response_modules: list[BaseResponseRagModule] = field()

@property
def modules(self) -> list[BaseRagModule]:
ms = []

ms.extend(self.before_response_modules)
ms.extend(self.after_response_modules)

if self.response_module is not None:
ms.append(self.response_module)
ms.extend(self.response_modules)

return ms
Comment on lines 24 to 29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this property?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do because we test for module name uniqueness by using this property.


def run(self, context: RagContext) -> RagContext:
logging.info("GenerationStage: running %s before modules sequentially", len(self.before_response_modules))

for generator in self.before_response_modules:
context = generator.run(context)

logging.info("GenerationStage: running generation module")

context = self.response_module.run(context)
        logging.info("ResponseRagStage: running %s response modules in parallel", len(self.response_modules))

logging.info("GenerationStage: running %s after modules sequentially", len(self.after_response_modules))
with self.futures_executor_fn() as executor:
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules])

for generator in self.after_response_modules:
context = generator.run(context)
context.outputs = results

return context
Loading
Loading