From 0c0f5fdfcf2858a20565f2f88f4ff89811a835f8 Mon Sep 17 00:00:00 2001
From: Matt Dancho
Date: Thu, 9 Jan 2025 21:53:05 -0500
Subject: [PATCH] #16 OOP SQLDatabaseAgent Class

---
 .../agents/sql_database_agent.py              | 308 +++++++++++++++++-
 1 file changed, 305 insertions(+), 3 deletions(-)

diff --git a/ai_data_science_team/agents/sql_database_agent.py b/ai_data_science_team/agents/sql_database_agent.py
index 9e595b3..0467f3a 100644
--- a/ai_data_science_team/agents/sql_database_agent.py
+++ b/ai_data_science_team/agents/sql_database_agent.py
@@ -10,16 +10,18 @@ from langgraph.checkpoint.memory import MemorySaver
 
 
 import os
-import io
 import pandas as pd
 import sqlalchemy as sql
 
+from IPython.display import Markdown
+
 from ai_data_science_team.templates import(
     node_func_execute_agent_from_sql_connection,
     node_func_human_review,
     node_func_fix_agent_code,
     node_func_explain_agent_code,
-    create_coding_agent_graph
+    create_coding_agent_graph,
+    BaseAgent,
 )
 from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
 from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
@@ -30,6 +32,306 @@ AGENT_NAME = "sql_database_agent"
 LOG_PATH = os.path.join(os.getcwd(), "logs/")
 
 
+# Class
+
+class SQLDatabaseAgent(BaseAgent):
+    """
+    Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
+    The agent can:
+    - Propose recommended steps to answer a user's query or instructions.
+    - Generate a SQL query based on the recommended steps and user instructions.
+    - Execute that SQL query against the provided database connection.
+    - Return the resulting data as a dictionary, suitable for conversion to a DataFrame or other structures.
+    - Log generated code and errors if enabled.
+
+    Parameters
+    ----------
+    model : ChatOpenAI or langchain.llms.base.LLM
+        The language model used to generate the SQL code.
+    connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
+        The SQLAlchemy connection (or engine) to the database.
+    n_samples : int, optional
+        Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to 10.
+    log : bool, optional
+        Whether to log the generated code and errors. Defaults to False.
+    log_path : str, optional
+        Directory path for storing log files. Defaults to None.
+    file_name : str, optional
+        Name of the file for saving the generated response. Defaults to "sql_database.py".
+    overwrite : bool, optional
+        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
+    human_in_the_loop : bool, optional
+        Enables user review of the recommended steps before generating code. Defaults to False.
+    bypass_recommended_steps : bool, optional
+        If True, skips the step that generates recommended SQL steps. Defaults to False.
+    bypass_explain_code : bool, optional
+        If True, skips the step that provides code explanations. Defaults to False.
+
+    Methods
+    -------
+    update_params(**kwargs)
+        Updates the agent's parameters and rebuilds the compiled state graph.
+    ainvoke(user_instructions: str, max_retries=3, retry_count=0)
+        Coroutine (must be awaited) that runs the agent to generate and execute a SQL query based on user instructions.
+    invoke(user_instructions: str, max_retries=3, retry_count=0)
+        Synchronously runs the agent to generate and execute a SQL query based on user instructions.
+    explain_sql_steps()
+        Returns an explanation of the SQL steps performed by the agent.
+    get_log_summary()
+        Retrieves a summary of logged operations if logging is enabled.
+    get_data_sql()
+        Retrieves the resulting data from the SQL query as a dictionary.
+        (You can convert this to a DataFrame if desired.)
+    get_sql_query_code()
+        Retrieves the exact SQL query generated by the agent.
+    get_sql_database_function()
+        Retrieves the Python function that executes the SQL query.
+    get_recommended_sql_steps()
+        Retrieves the recommended steps for querying the SQL database.
+    get_response()
+        Returns the full response dictionary from the agent.
+    show()
+        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
+
+    Examples
+    --------
+    ```python
+    import sqlalchemy as sql
+    from langchain_openai import ChatOpenAI
+    from ai_data_science_team.agents import SQLDatabaseAgent
+
+    # Create the engine/connection
+    sql_engine = sql.create_engine("sqlite:///data/my_database.db")
+    conn = sql_engine.connect()
+
+    llm = ChatOpenAI(model="gpt-4o-mini")
+
+    sql_database_agent = SQLDatabaseAgent(
+        model=llm,
+        connection=conn,
+        n_samples=10,
+        log=True,
+        log_path="logs",
+        human_in_the_loop=True
+    )
+
+    # Example usage
+    sql_database_agent.invoke(
+        user_instructions="List all the tables in the database.",
+        max_retries=3,
+        retry_count=0
+    )
+
+    data_result = sql_database_agent.get_data_sql() # dictionary of rows returned
+    sql_code = sql_database_agent.get_sql_query_code()
+    response = sql_database_agent.get_response()
+    ```
+
+    Returns
+    -------
+    SQLDatabaseAgent : langchain.graphs.CompiledStateGraph
+        A SQL database agent implemented as a compiled state graph.
+    """
+
+    def __init__(
+        self,
+        model,
+        connection,
+        n_samples=10,
+        log=False,
+        log_path=None,
+        file_name="sql_database.py",
+        overwrite=True,
+        human_in_the_loop=False,
+        bypass_recommended_steps=False,
+        bypass_explain_code=False
+    ):
+        self._params = {
+            "model": model,
+            "connection": connection,
+            "n_samples": n_samples,
+            "log": log,
+            "log_path": log_path,
+            "file_name": file_name,
+            "overwrite": overwrite,
+            "human_in_the_loop": human_in_the_loop,
+            "bypass_recommended_steps": bypass_recommended_steps,
+            "bypass_explain_code": bypass_explain_code
+        }
+        self._compiled_graph = self._make_compiled_graph()
+        self.response = None
+
+    def _make_compiled_graph(self):
+        """
+        Create or rebuild the compiled graph for the SQL Database Agent.
+        Running this method resets the response to None.
+        """
+        self.response = None
+        return make_sql_database_agent(**self._params)
+
+    def update_params(self, **kwargs):
+        """
+        Updates the agent's parameters (e.g. connection, n_samples, log, etc.)
+        and rebuilds the compiled graph.
+        """
+        for k, v in kwargs.items():
+            self._params[k] = v
+        self._compiled_graph = self._make_compiled_graph()
+
+    async def ainvoke(self, user_instructions: str, max_retries=3, retry_count=0):
+        """
+        Asynchronously runs the SQL Database Agent based on user instructions (coroutine; must be awaited).
+
+        Parameters
+        ----------
+        user_instructions : str
+            Instructions for the SQL query or metadata request.
+        max_retries : int, optional
+            Maximum retry attempts. Defaults to 3.
+        retry_count : int, optional
+            Current retry count. Defaults to 0.
+
+        Returns
+        -------
+        None
+        """
+        response = await self._compiled_graph.ainvoke({
+            "user_instructions": user_instructions,
+            "max_retries": max_retries,
+            "retry_count": retry_count
+        })
+        self.response = response
+
+    def invoke(self, user_instructions: str, max_retries=3, retry_count=0):
+        """
+        Synchronously runs the SQL Database Agent based on user instructions.
+
+        Parameters
+        ----------
+        user_instructions : str
+            Instructions for the SQL query or metadata request.
+        max_retries : int, optional
+            Maximum retry attempts. Defaults to 3.
+        retry_count : int, optional
+            Current retry count. Defaults to 0.
+
+        Returns
+        -------
+        None
+        """
+        response = self._compiled_graph.invoke({
+            "user_instructions": user_instructions,
+            "max_retries": max_retries,
+            "retry_count": retry_count
+        })
+        self.response = response
+
+    def explain_sql_steps(self):
+        """
+        Provides an explanation of the SQL steps performed by the agent
+        if the explain step is not bypassed.
+
+        Returns
+        -------
+        str or list
+            An explanation of the SQL steps.
+        """
+        if self.response:
+            return self.response.get("messages", [])
+        return []
+
+    def get_log_summary(self, markdown=False):
+        """
+        Retrieves a summary of the logging details if logging is enabled.
+
+        Parameters
+        ----------
+        markdown : bool, optional
+            If True, returns the summary in Markdown format.
+
+        Returns
+        -------
+        str or None
+            Log details or None if logging is not used or data is unavailable.
+        """
+        if self.response and self.response.get("sql_database_function_path"):
+            log_details = f"Log Path: {self.response['sql_database_function_path']}"
+            if markdown:
+                return Markdown(log_details)
+            return log_details
+        return None
+
+    def get_data_sql(self):
+        """
+        Retrieves the SQL query result from the agent's response.
+
+        Returns
+        -------
+        dict or None
+            The returned data as a dictionary of column -> list_of_values,
+            or None if no data is found.
+        """
+        if self.response and "data_sql" in self.response:
+            return self.response["data_sql"]
+        return None
+
+    def get_sql_query_code(self):
+        """
+        Retrieves the raw SQL query code generated by the agent (if available).
+
+        Returns
+        -------
+        str or None
+            The SQL query as a string, or None if not available.
+        """
+        if self.response and "sql_query_code" in self.response:
+            return self.response["sql_query_code"]
+        return None
+
+    def get_sql_database_function(self, markdown=False):
+        """
+        Retrieves the Python function code used to execute the SQL query.
+
+        Parameters
+        ----------
+        markdown : bool, optional
+            If True, returns the code in a Markdown code block.
+
+        Returns
+        -------
+        str or None
+            The function code if available, otherwise None.
+        """
+        if self.response and "sql_database_function" in self.response:
+            code = self.response["sql_database_function"]
+            if markdown:
+                return Markdown(f"```python\n{code}\n```")
+            return code
+        return None
+
+    def get_recommended_sql_steps(self, markdown=False):
+        """
+        Retrieves the recommended SQL steps from the agent's response.
+
+        Parameters
+        ----------
+        markdown : bool, optional
+            If True, returns the steps in Markdown format.
+
+        Returns
+        -------
+        str or None
+            Recommended steps or None if not available.
+        """
+        if self.response and "recommended_steps" in self.response:
+            if markdown:
+                return Markdown(self.response["recommended_steps"])
+            return self.response["recommended_steps"]
+        return None
+
+
+
+# Function
 def make_sql_database_agent(
     model,
     connection,
@@ -71,7 +373,7 @@ def make_sql_database_agent(
 
     Returns
     -------
-    app : langchain.graphs.StateGraph
+    app : langchain.graphs.CompiledStateGraph
         The data cleaning agent as a state graph.
 
     Examples