Skip to content

Commit

Permalink
#16 OOP SQLDatabaseAgent Class
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Jan 10, 2025
1 parent bf46c00 commit 0c0f5fd
Showing 1 changed file with 305 additions and 3 deletions.
308 changes: 305 additions & 3 deletions ai_data_science_team/agents/sql_database_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@
from langgraph.checkpoint.memory import MemorySaver

import os
import io
import pandas as pd
import sqlalchemy as sql

from IPython.display import Markdown

from ai_data_science_team.templates import(
node_func_execute_agent_from_sql_connection,
node_func_human_review,
node_func_fix_agent_code,
node_func_explain_agent_code,
create_coding_agent_graph
create_coding_agent_graph,
BaseAgent,
)
from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
Expand All @@ -30,6 +32,306 @@
AGENT_NAME = "sql_database_agent"
LOG_PATH = os.path.join(os.getcwd(), "logs/")

# Class

class SQLDatabaseAgent(BaseAgent):
    """
    Object-oriented wrapper around the graph built by ``make_sql_database_agent``.

    The agent can:

    - Propose recommended steps to answer a user's query or instructions.
    - Generate a SQL query based on the recommended steps and user instructions.
    - Execute that SQL query against the provided database connection.
    - Return the resulting data as a dictionary, suitable for conversion to a
      DataFrame or other structures.
    - Log generated code and errors if enabled.

    Parameters
    ----------
    model : ChatOpenAI or langchain.llms.base.LLM
        The language model used to generate the SQL code.
    connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
        The SQLAlchemy connection (or engine) to the database.
    n_samples : int, optional
        Number of sample rows (per column) to retrieve when summarizing
        database metadata. Defaults to 10.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response.
        Defaults to "sql_database.py".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique
        file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of the recommended steps before generating code.
        Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended SQL steps.
        Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations.
        Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke(user_instructions: str, max_retries=3, retry_count=0)
        Coroutine; must be awaited. Runs the agent asynchronously and stores
        the result on ``self.response``.
    invoke(user_instructions: str, max_retries=3, retry_count=0)
        Runs the agent synchronously and stores the result on ``self.response``.
    explain_sql_steps()
        Returns the ``"messages"`` entry of the last response.
    get_log_summary()
        Returns the logged function path of the last response, if any.
    get_data_sql()
        Returns the ``"data_sql"`` entry of the last response.
        (You can convert this to a DataFrame if desired.)
    get_sql_query_code()
        Returns the ``"sql_query_code"`` entry of the last response.
    get_sql_database_function()
        Returns the ``"sql_database_function"`` entry of the last response.
    get_recommended_sql_steps()
        Returns the ``"recommended_steps"`` entry of the last response.

    Examples
    --------
    ```python
    import sqlalchemy as sql
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import SQLDatabaseAgent

    # Create the engine/connection
    sql_engine = sql.create_engine("sqlite:///data/my_database.db")
    conn = sql_engine.connect()

    llm = ChatOpenAI(model="gpt-4o-mini")

    sql_database_agent = SQLDatabaseAgent(
        model=llm,
        connection=conn,
        n_samples=10,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    # Example usage
    sql_database_agent.invoke(
        user_instructions="List all the tables in the database.",
        max_retries=3,
        retry_count=0
    )

    data_result = sql_database_agent.get_data_sql()  # dictionary of rows returned
    sql_code = sql_database_agent.get_sql_query_code()
    ```

    Returns
    -------
    SQLDatabaseAgent
        An agent instance holding the compiled state graph produced by
        ``make_sql_database_agent`` and the most recent response.
    """

    def __init__(
        self,
        model,
        connection,
        n_samples=10,
        log=False,
        log_path=None,
        file_name="sql_database.py",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Keep the constructor arguments in one dict so update_params() can
        # patch individual entries and rebuild the graph from the same source.
        self._params = {
            "model": model,
            "connection": connection,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Database Agent.

        Running this method resets the response to None, because a response
        produced by a previous graph no longer matches the current parameters.
        """
        self.response = None
        return make_sql_database_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Update the agent's parameters (e.g. connection, n_samples, log, etc.)
        and rebuild the compiled graph.

        Parameters
        ----------
        **kwargs
            Parameter names and their new values. They are merged into the
            stored parameters and forwarded to ``make_sql_database_agent``.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke(self, user_instructions: str, max_retries=3, retry_count=0):
        """
        Asynchronously run the SQL Database Agent based on user instructions.

        This is a coroutine and must be awaited (``await agent.ainvoke(...)``).
        The awaited result of the compiled graph is stored on ``self.response``.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.

        Returns
        -------
        None
        """
        # The graph's ainvoke() returns a coroutine; it has to be awaited so
        # that self.response holds the result rather than the coroutine object.
        self.response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        })

    def invoke(self, user_instructions: str, max_retries=3, retry_count=0):
        """
        Synchronously run the SQL Database Agent based on user instructions.

        The result of the compiled graph is stored on ``self.response``.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.

        Returns
        -------
        None
        """
        self.response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        })

    def explain_sql_steps(self):
        """
        Return the ``"messages"`` entry of the last response.

        Returns
        -------
        str or list
            The value stored under ``"messages"``, or an empty list when the
            agent has not produced a response or the key is absent.
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown: bool = False):
        """
        Return a one-line summary of where the generated function was logged.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the summary wrapped in ``IPython.display.Markdown``.

        Returns
        -------
        str, Markdown or None
            Log details, or None if there is no response or it carries no
            ``"sql_database_function_path"`` value.
        """
        if self.response and self.response.get("sql_database_function_path"):
            log_details = f"Log Path: {self.response['sql_database_function_path']}"
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_data_sql(self):
        """
        Return the SQL query result from the agent's response.

        Returns
        -------
        dict or None
            The value stored under ``"data_sql"``, or None if there is no
            response or the key is absent.
        """
        if self.response and "data_sql" in self.response:
            return self.response["data_sql"]
        return None

    def get_sql_query_code(self):
        """
        Return the raw SQL query code generated by the agent (if available).

        Returns
        -------
        str or None
            The value stored under ``"sql_query_code"``, or None if not available.
        """
        if self.response and "sql_query_code" in self.response:
            return self.response["sql_query_code"]
        return None

    def get_sql_database_function(self, markdown: bool = False):
        """
        Return the Python function code used to execute the SQL query.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the code in a Markdown python code block.

        Returns
        -------
        str, Markdown or None
            The value stored under ``"sql_database_function"`` if available,
            otherwise None.
        """
        if self.response and "sql_database_function" in self.response:
            code = self.response["sql_database_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_sql_steps(self, markdown: bool = False):
        """
        Return the recommended SQL steps from the agent's response.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps wrapped in ``IPython.display.Markdown``.

        Returns
        -------
        str, Markdown or None
            The value stored under ``"recommended_steps"``, or None if not
            available.
        """
        if self.response and "recommended_steps" in self.response:
            steps = self.response["recommended_steps"]
            if markdown:
                return Markdown(steps)
            return steps
        return None



# Function

def make_sql_database_agent(
model, connection,
Expand Down Expand Up @@ -71,7 +373,7 @@ def make_sql_database_agent(
Returns
-------
app : langchain.graphs.StateGraph
app : langchain.graphs.CompiledStateGraph
The SQL database agent as a compiled state graph.
Examples
Expand Down

0 comments on commit 0c0f5fd

Please sign in to comment.