Skip to content

Commit

Permalink
Fix formatting and types
Browse files Browse the repository at this point in the history
  • Loading branch information
vasinov committed Jul 2, 2024
1 parent ed166c3 commit f228503
Show file tree
Hide file tree
Showing 18 changed files with 42 additions and 81 deletions.
2 changes: 1 addition & 1 deletion griptape/engines/rag/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
"PromptResponseRagModule",
"RulesetsBeforeResponseRagModule",
"MetadataBeforeResponseRagModule",
"TextChunksResponseRagModule"
"TextChunksResponseRagModule",
]
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ def run(self, context: RagContext) -> RagContext:
metadata = self.metadata if context_metadata is None else context_metadata

if metadata is not None:
context.before_query.append(
J2("engines/rag/modules/response/metadata/system.j2").render(metadata=metadata)
)
context.before_query.append(J2("engines/rag/modules/response/metadata/system.j2").render(metadata=metadata))

return context
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

@define(kw_only=True)
class TextChunksResponseRagModule(BaseResponseRagModule):

def run(self, context: RagContext) -> RagContext:
context.output = ListArtifact(context.text_chunks)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Any, Callable
from attrs import define, field, Factory
from griptape.artifacts import TextArtifact, BaseArtifact
from griptape.artifacts import TextArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseRetrievalRagModule

Expand All @@ -13,7 +13,7 @@
class VectorStoreRetrievalRagModule(BaseRetrievalRagModule):
vector_store_driver: BaseVectorStoreDriver = field()
query_params: dict[str, Any] = field(factory=dict)
process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field(
process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
default=Factory(lambda: lambda es: [e.to_artifact() for e in es])
)

Expand Down
2 changes: 1 addition & 1 deletion griptape/engines/rag/rag_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Any
from typing import Optional
from attrs import define, field
from griptape.artifacts import TextArtifact, BaseArtifact

Expand Down
3 changes: 1 addition & 2 deletions griptape/engines/rag/stages/base_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ def run(self, context: RagContext) -> RagContext: ...

@property
@abstractmethod
def modules(self) -> list[BaseRagModule]:
...
def modules(self) -> list[BaseRagModule]: ...
6 changes: 2 additions & 4 deletions griptape/engines/rag/stages/query_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ class QueryRagStage(BaseRagStage):

@property
def modules(self) -> list[BaseRagModule]:
return self.query_modules
return self.query_modules # pyright: ignore

Check warning on line 15 in griptape/engines/rag/stages/query_rag_stage.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/rag/stages/query_rag_stage.py#L15

Added line #L15 was not covered by tests

def run(self, context: RagContext) -> RagContext:
logging.info(f"QueryStage: running {len(self.query_modules)} query generation modules in parallel")

Check warning on line 18 in griptape/engines/rag/stages/query_rag_stage.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/rag/stages/query_rag_stage.py#L18

Added line #L18 was not covered by tests

with self.futures_executor_fn() as executor:
utils.execute_futures_list(
[executor.submit(r.run, context) for r in self.query_modules]
)
utils.execute_futures_list([executor.submit(r.run, context) for r in self.query_modules])

return context
3 changes: 2 additions & 1 deletion griptape/engines/rag/stages/response_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from griptape.engines.rag.modules import (
BaseResponseRagModule,
BaseBeforeResponseRagModule,
BaseAfterResponseRagModule, BaseRagModule,
BaseAfterResponseRagModule,
BaseRagModule,
)
from griptape.engines.rag.stages import BaseRagStage

Expand Down
2 changes: 1 addition & 1 deletion griptape/memory/task/storage/base_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ def can_store(self, artifact: BaseArtifact) -> bool: ...
def summarize(self, namespace: str) -> TextArtifact | InfoArtifact: ...

@abstractmethod
def query(self, namespace: str, query: str, metadata: Any = None) -> TextArtifact | InfoArtifact: ...
def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: ...
2 changes: 1 addition & 1 deletion griptape/memory/task/storage/blob_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ def load_artifacts(self, namespace: str) -> ListArtifact:
def summarize(self, namespace: str) -> InfoArtifact:
return InfoArtifact("can't summarize artifacts")

def query(self, namespace: str, query: str, metadata: Any = None) -> InfoArtifact:
def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact:
return InfoArtifact("can't query artifacts")
8 changes: 4 additions & 4 deletions griptape/memory/task/storage/text_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ def summarize(self, namespace: str) -> TextArtifact:

return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace))

def query(self, namespace: str, query: str, metadata: Any = None) -> TextArtifact | InfoArtifact:
def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact:
if self.rag_engine is None:
raise ValueError("RAG engine is not set.")

result = self.rag_engine.process(
RagContext(
query=query,
module_params={
self.retrieval_rag_module_name: {
self.retrieval_rag_module_name: { # pyright: ignore
"query_params": {
"namespace": namespace,
"metadata": None if metadata is None else str(metadata)
"metadata": None if metadata is None else str(metadata),
}
}
}
},
)
).output

Expand Down
2 changes: 1 addition & 1 deletion griptape/memory/task/task_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def summarize_namespace(self, namespace: str) -> TextArtifact | InfoArtifact:
else:
return InfoArtifact("Can't find memory content")

def query_namespace(self, namespace: str, query: str) -> TextArtifact | InfoArtifact:
def query_namespace(self, namespace: str, query: str) -> BaseArtifact:
storage = self.namespace_storage.get(namespace)

if storage:
Expand Down
4 changes: 2 additions & 2 deletions griptape/tasks/rag_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from attrs import define, field
from griptape.artifacts import TextArtifact, ErrorArtifact
from griptape.artifacts import ErrorArtifact, BaseArtifact
from griptape.engines.rag import RagEngine
from griptape.tasks import BaseTextInputTask

Expand All @@ -22,7 +22,7 @@ def rag_engine(self) -> RagEngine:
def rag_engine(self, value: RagEngine) -> None:
self._rag_engine = value

def run(self) -> TextArtifact | ErrorArtifact:
def run(self) -> BaseArtifact:
result = self.rag_engine.process_query(self.input.to_text()).output

if result is None:
Expand Down
4 changes: 2 additions & 2 deletions griptape/tools/task_memory_client/tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from attrs import define
from schema import Schema, Literal
from griptape.artifacts import TextArtifact, ErrorArtifact, InfoArtifact
from griptape.artifacts import TextArtifact, ErrorArtifact, InfoArtifact, BaseArtifact
from griptape.tools import BaseTool
from griptape.utils.decorators import activity

Expand Down Expand Up @@ -39,7 +39,7 @@ def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact
),
}
)
def query(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact:
def query(self, params: dict) -> BaseArtifact:
memory = self.find_input_memory(params["values"]["memory_name"])
artifact_namespace = params["values"]["artifact_namespace"]
query = params["values"]["query"]
Expand Down
3 changes: 1 addition & 2 deletions tests/mocks/mock_rag_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from griptape.engines.rag.modules import BaseRagModule


class MockRagModule(BaseRagModule):
...
class MockRagModule(BaseRagModule): ...
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,14 @@ class TestMetadataBeforeResponseRagModule:
def test_run(self):
module = MetadataBeforeResponseRagModule(name="foo")

assert "foo" in module.run(
RagContext(
module_params={"foo": {"metadata": "foo"}},
query="test"
)
).before_query[0]
assert "foo" in module.run(RagContext(module_params={"foo": {"metadata": "foo"}}, query="test")).before_query[0]

def test_run_with_override(self):
module = MetadataBeforeResponseRagModule(name="foo", metadata="bar")

assert "bar" in module.run(
RagContext(
module_params={"foo": {"query_params": {"metadata": "foo"}}},
query="test"
)
).before_query[0]
assert (
"bar"
in module.run(
RagContext(module_params={"foo": {"query_params": {"metadata": "foo"}}}, query="test")
).before_query[0]
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,12 @@
class TestVectorStoreRetrievalRagModule:
def test_run_without_namespace(self):
vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver())
module = VectorStoreRetrievalRagModule(
vector_store_driver=vector_store_driver
)
module = VectorStoreRetrievalRagModule(vector_store_driver=vector_store_driver)

vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test")
vector_store_driver.upsert_text_artifact(TextArtifact("foobar2"), namespace="test")

result = module.run(
RagContext(
query="test",
)
)
result = module.run(RagContext(query="test"))

assert len(result) == 2
assert result[0].value == "foobar1"
Expand All @@ -28,18 +22,13 @@ def test_run_without_namespace(self):
def test_run_with_namespace(self):
vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver())
module = VectorStoreRetrievalRagModule(
vector_store_driver=vector_store_driver,
query_params={"namespace": "test"}
vector_store_driver=vector_store_driver, query_params={"namespace": "test"}
)

vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test")
vector_store_driver.upsert_text_artifact(TextArtifact("foobar2"), namespace="test")

result = module.run(
RagContext(
query="test",
)
)
result = module.run(RagContext(query="test"))

assert len(result) == 2
assert result[0].value == "foobar1"
Expand All @@ -48,26 +37,23 @@ def test_run_with_namespace(self):
def test_run_with_namespace_overrides(self):
vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver())
module = VectorStoreRetrievalRagModule(
vector_store_driver=vector_store_driver,
query_params={"namespace": "test"}
vector_store_driver=vector_store_driver, query_params={"namespace": "test"}
)

vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test")
vector_store_driver.upsert_text_artifact(TextArtifact("foobar2"), namespace="test")

result1 = module.run(
RagContext(
query="test",
module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "empty"}}}
query="test", module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "empty"}}}
)
)

result2 = module.run(
RagContext(
query="test",
module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "test"}}}
query="test", module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "test"}}}
)
)

assert len(result1) == 0
assert len(result2) == 2
assert len(result2) == 2
24 changes: 5 additions & 19 deletions tests/unit/engines/rag/test_rag_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ def engine(self):
)
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver())
),
response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver())),
)

def test_module_name_uniqueness(self):
Expand All @@ -30,29 +28,17 @@ def test_module_name_uniqueness(self):
RagEngine(
retrieval_stage=RetrievalRagStage(
retrieval_modules=[
VectorStoreRetrievalRagModule(
name="test",
vector_store_driver=vector_store_driver
),
VectorStoreRetrievalRagModule(
name="test",
vector_store_driver=vector_store_driver
)
VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver),
VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver),
]
)
)

assert RagEngine(
retrieval_stage=RetrievalRagStage(
retrieval_modules=[
VectorStoreRetrievalRagModule(
name="test1",
vector_store_driver=vector_store_driver
),
VectorStoreRetrievalRagModule(
name="test2",
vector_store_driver=vector_store_driver
)
VectorStoreRetrievalRagModule(name="test1", vector_store_driver=vector_store_driver),
VectorStoreRetrievalRagModule(name="test2", vector_store_driver=vector_store_driver),
]
)
)
Expand Down

0 comments on commit f228503

Please sign in to comment.