diff --git a/griptape/engines/rag/modules/base_rag_module.py b/griptape/engines/rag/modules/base_rag_module.py index 1bf1ad91a..829a24565 100644 --- a/griptape/engines/rag/modules/base_rag_module.py +++ b/griptape/engines/rag/modules/base_rag_module.py @@ -23,13 +23,9 @@ def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> Pro messages = [] if system_prompt is not None: - messages.append( - Message(system_prompt, role=Message.SYSTEM_ROLE) - ) + messages.append(Message(system_prompt, role=Message.SYSTEM_ROLE)) - messages.append( - Message(query, role=Message.USER_ROLE) - ) + messages.append(Message(query, role=Message.USER_ROLE)) return PromptStack(messages=messages) diff --git a/griptape/engines/rag/modules/query/translate_query_rag_module.py b/griptape/engines/rag/modules/query/translate_query_rag_module.py index 509c42388..f1f9ca0ec 100644 --- a/griptape/engines/rag/modules/query/translate_query_rag_module.py +++ b/griptape/engines/rag/modules/query/translate_query_rag_module.py @@ -1,12 +1,15 @@ from __future__ import annotations + from typing import TYPE_CHECKING, Callable -from attrs import field, define, Factory -from griptape.engines.rag import RagContext + +from attrs import Factory, define, field + from griptape.engines.rag.modules import BaseQueryRagModule from griptape.utils import J2 if TYPE_CHECKING: from griptape.drivers import BasePromptDriver + from griptape.engines.rag import RagContext @define(kw_only=True) @@ -26,7 +29,4 @@ def run(self, context: RagContext) -> RagContext: return context def default_user_template_generator(self, query: str, language: str) -> str: - return J2("engines/rag/modules/query/translate/user.j2").render( - query=query, - language=language - ) + return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language) diff --git a/tests/unit/engines/query/test_translate_query_rag_module.py b/tests/unit/engines/query/test_translate_query_rag_module.py index ecd9c5015..a1114deb3 100644 --- a/tests/unit/engines/query/test_translate_query_rag_module.py +++ b/tests/unit/engines/query/test_translate_query_rag_module.py @@ -5,8 +5,6 @@ class TestTranslateQueryRagModule: def test_run(self): - module = TranslateQueryRagModule( - prompt_driver=MockPromptDriver() - ) + module = TranslateQueryRagModule(prompt_driver=MockPromptDriver()) assert module.run(RagContext(query="foo")).query == "bar"