From dcbfb5bc7a5ae3f06e1e42022394222b3346291c Mon Sep 17 00:00:00 2001 From: Ihor <31508183+nautics889@users.noreply.github.com> Date: Wed, 13 Dec 2023 02:24:58 +0200 Subject: [PATCH] refactor: DRY enhancement prompt classes (#808) (#809) * refactor: DRY enhancement prompt classes (#808) * (refactor): `DirectSQLPrompt` extends `GeneratePythonCodePrompt` class * (refactor): reduce duplication of logic in `DirectSQLPrompt` * (refactor): use `kwargs.pop()` in setup method * refactor: DRY enhancement prompt classes (#808) * (fix): linter issue --- pandasai/prompts/direct_sql_prompt.py | 41 +++--------------------- pandasai/prompts/generate_python_code.py | 31 +++++++----------- 2 files changed, 15 insertions(+), 57 deletions(-) diff --git a/pandasai/prompts/direct_sql_prompt.py b/pandasai/prompts/direct_sql_prompt.py index 37c28d5b4..d48c8afdb 100644 --- a/pandasai/prompts/direct_sql_prompt.py +++ b/pandasai/prompts/direct_sql_prompt.py @@ -1,12 +1,8 @@ """ Prompt to explain code generation by the LLM""" -from .file_based_prompt import FileBasedPrompt -from .generate_python_code import ( - CurrentCodePrompt, - SimpleReasoningPrompt, -) +from .generate_python_code import CurrentCodePrompt, GeneratePythonCodePrompt -class DirectSQLPrompt(FileBasedPrompt): +class DirectSQLPrompt(GeneratePythonCodePrompt): """Prompt to explain code generation by the LLM""" _path_to_template = "assets/prompt_templates/direct_sql_connector.tmpl" @@ -27,35 +23,6 @@ def _prepare_tables_data(self, tables): def setup(self, tables, **kwargs) -> None: self.set_var("tables", self._prepare_tables_data(tables)) - if "custom_instructions" in kwargs: - self.set_var("instructions", kwargs["custom_instructions"]) - else: - self.set_var("instructions", "") + super(DirectSQLPrompt, self).setup(**kwargs) - if "current_code" in kwargs: - self.set_var("current_code", kwargs["current_code"]) - else: - self.set_var("current_code", CurrentCodePrompt()) - - if "code_description" in kwargs: - self.set_var("code_description", kwargs["code_description"]) - else: - self.set_var("code_description", "Update this initial code:") - - if "last_message" in kwargs: - self.set_var("last_message", kwargs["last_message"]) - else: - self.set_var("last_message", "") - - if "prev_conversation" in kwargs: - self.set_var("prev_conversation", kwargs["prev_conversation"]) - else: - self.set_var("prev_conversation", "") - - def on_prompt_generation(self) -> None: - default_import = "import pandas as pd" - engine_df_name = "pd.DataFrame" - - self.set_var("default_import", default_import) - self.set_var("engine_df_name", engine_df_name) - self.set_var("reasoning", SimpleReasoningPrompt()) + self.set_var("current_code", kwargs.pop("current_code", CurrentCodePrompt())) diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index a41930974..13a5b81a0 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -53,30 +53,21 @@ class GeneratePythonCodePrompt(FileBasedPrompt): _path_to_template = "assets/prompt_templates/generate_python_code.tmpl" def setup(self, **kwargs) -> None: - if "custom_instructions" in kwargs: - self.set_var("instructions", kwargs["custom_instructions"]) - else: - self.set_var("instructions", "") + self.set_var("instructions", kwargs.pop("custom_instructions", "")) - if "current_code" in kwargs: - self.set_var("current_code", kwargs["current_code"]) - else: - self.set_var("current_code", CurrentCodePrompt(dfs_declared=True)) + self.set_var( + "current_code", + kwargs.pop("current_code", CurrentCodePrompt(dfs_declared=True)), + ) - if "code_description" in kwargs: - self.set_var("code_description", kwargs["code_description"]) - else: - self.set_var("code_description", "Update this initial code:") + self.set_var( + "code_description", + kwargs.pop("code_description", "Update this initial code:"), + ) - if "last_message" in kwargs: - self.set_var("last_message", kwargs["last_message"]) - else: - self.set_var("last_message", "") + self.set_var("last_message", kwargs.pop("last_message", "")) - if "prev_conversation" in kwargs: - self.set_var("prev_conversation", kwargs["prev_conversation"]) - else: - self.set_var("prev_conversation", "") + self.set_var("prev_conversation", kwargs.pop("prev_conversation", "")) def on_prompt_generation(self) -> None: default_import = "import pandas as pd"