Skip to content

Commit

Permalink
refactor: DRY enhancement prompt classes (#808)
Browse files Browse the repository at this point in the history
* (refactor): `DirectSQLPrompt` extends `GeneratePythonCodePrompt` class
* (refactor): reduce duplication of logic in `DirectSQLPrompt`
* (refactor): use `kwargs.pop()` in setup method
  • Loading branch information
nautics889 committed Dec 9, 2023
1 parent 26857ef commit 55c5c74
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 58 deletions.
44 changes: 6 additions & 38 deletions pandasai/prompts/direct_sql_prompt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
""" Prompt to explain code generation by the LLM"""
from .file_based_prompt import FileBasedPrompt
from .generate_python_code import (
CurrentCodePrompt,
SimpleReasoningPrompt,
)
from .generate_python_code import CurrentCodePrompt, GeneratePythonCodePrompt


class DirectSQLPrompt(FileBasedPrompt):
class DirectSQLPrompt(GeneratePythonCodePrompt):
"""Prompt to explain code generation by the LLM"""

_path_to_template = "assets/prompt_templates/direct_sql_connector.tmpl"
Expand All @@ -27,35 +23,7 @@ def _prepare_tables_data(self, tables):
def setup(self, tables, **kwargs) -> None:
self.set_var("tables", self._prepare_tables_data(tables))

if "custom_instructions" in kwargs:
self.set_var("instructions", kwargs["custom_instructions"])
else:
self.set_var("instructions", "")

if "current_code" in kwargs:
self.set_var("current_code", kwargs["current_code"])
else:
self.set_var("current_code", CurrentCodePrompt())

if "code_description" in kwargs:
self.set_var("code_description", kwargs["code_description"])
else:
self.set_var("code_description", "Update this initial code:")

if "last_message" in kwargs:
self.set_var("last_message", kwargs["last_message"])
else:
self.set_var("last_message", "")

if "prev_conversation" in kwargs:
self.set_var("prev_conversation", kwargs["prev_conversation"])
else:
self.set_var("prev_conversation", "")

def on_prompt_generation(self) -> None:
default_import = "import pandas as pd"
engine_df_name = "pd.DataFrame"

self.set_var("default_import", default_import)
self.set_var("engine_df_name", engine_df_name)
self.set_var("reasoning", SimpleReasoningPrompt())
super(DirectSQLPrompt, self).setup(**kwargs)

self.set_var("current_code", kwargs.pop("current_code", CurrentCodePrompt()))

31 changes: 11 additions & 20 deletions pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,21 @@ class GeneratePythonCodePrompt(FileBasedPrompt):
_path_to_template = "assets/prompt_templates/generate_python_code.tmpl"

def setup(self, **kwargs) -> None:
if "custom_instructions" in kwargs:
self.set_var("instructions", kwargs["custom_instructions"])
else:
self.set_var("instructions", "")
self.set_var("instructions", kwargs.pop("custom_instructions", ""))

if "current_code" in kwargs:
self.set_var("current_code", kwargs["current_code"])
else:
self.set_var("current_code", CurrentCodePrompt(dfs_declared=True))
self.set_var(
"current_code",
kwargs.pop("current_code", CurrentCodePrompt(dfs_declared=True)),
)

if "code_description" in kwargs:
self.set_var("code_description", kwargs["code_description"])
else:
self.set_var("code_description", "Update this initial code:")
self.set_var(
"code_description",
kwargs.pop("code_description", "Update this initial code:"),
)

if "last_message" in kwargs:
self.set_var("last_message", kwargs["last_message"])
else:
self.set_var("last_message", "")
self.set_var("last_message", kwargs.pop("last_message", ""))

if "prev_conversation" in kwargs:
self.set_var("prev_conversation", kwargs["prev_conversation"])
else:
self.set_var("prev_conversation", "")
self.set_var("prev_conversation", kwargs.pop("prev_conversation", ""))

def on_prompt_generation(self) -> None:
default_import = "import pandas as pd"
Expand Down

0 comments on commit 55c5c74

Please sign in to comment.