Skip to content

Commit

Permalink
FIx formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
vasinov committed Aug 6, 2024
1 parent 882992b commit f2ce5f5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 15 deletions.
8 changes: 2 additions & 6 deletions griptape/engines/rag/modules/base_rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> Pro
messages = []

if system_prompt is not None:
messages.append(
Message(system_prompt, role=Message.SYSTEM_ROLE)
)
messages.append(Message(system_prompt, role=Message.SYSTEM_ROLE))

messages.append(
Message(query, role=Message.USER_ROLE)
)
messages.append(Message(query, role=Message.USER_ROLE))

return PromptStack(messages=messages)

Expand Down
12 changes: 6 additions & 6 deletions griptape/engines/rag/modules/query/translate_query_rag_module.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable
from attrs import field, define, Factory
from griptape.engines.rag import RagContext

from attrs import Factory, define, field

from griptape.engines.rag.modules import BaseQueryRagModule
from griptape.utils import J2

if TYPE_CHECKING:
from griptape.drivers import BasePromptDriver
from griptape.engines.rag import RagContext


@define(kw_only=True)
Expand All @@ -26,7 +29,4 @@ def run(self, context: RagContext) -> RagContext:
return context

def default_user_template_generator(self, query: str, language: str) -> str:
return J2("engines/rag/modules/query/translate/user.j2").render(
query=query,
language=language
)
return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language)
4 changes: 1 addition & 3 deletions tests/unit/engines/query/test_translate_query_rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

class TestTranslateQueryRagModule:
def test_run(self):
module = TranslateQueryRagModule(
prompt_driver=MockPromptDriver()
)
module = TranslateQueryRagModule(prompt_driver=MockPromptDriver())

assert module.run(RagContext(query="foo")).query == "bar"

0 comments on commit f2ce5f5

Please sign in to comment.