Skip to content

Commit

Permalink
#16 OOP DataVisualizationAgent Class
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Jan 10, 2025
1 parent bbaf682 commit 77f55b4
Show file tree
Hide file tree
Showing 2 changed files with 331 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ai_data_science_team/agents/data_cleaning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import io
import pandas as pd

from IPython.display import Image, display, Markdown
from IPython.display import Markdown

from ai_data_science_team.templates import(
node_func_execute_agent_code_on_data,
Expand Down
332 changes: 330 additions & 2 deletions ai_data_science_team/agents/data_visualization_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
from langgraph.checkpoint.memory import MemorySaver

import os
import io
import pandas as pd

from IPython.display import Markdown

from ai_data_science_team.templates import(
node_func_execute_agent_code_on_data,
node_func_human_review,
node_func_fix_agent_code,
node_func_explain_agent_code,
create_coding_agent_graph
create_coding_agent_graph,
BaseAgent,
)
from ai_data_science_team.tools.parsers import PythonOutputParser
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
Expand All @@ -36,6 +38,332 @@
# Name identifying this agent.
AGENT_NAME = "data_visualization_agent"
# Default log directory: a "logs/" folder under the current working directory.
# Evaluated once, when this module is imported.
LOG_PATH = os.path.join(os.getcwd(), "logs/")

# Class

class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined
    instructions or default visualization steps (if any). The agent generates a Python function
    to produce the visualization, executes it, and logs the process, including code and errors.
    It is designed to facilitate reproducible and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise,
    such as:

    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the
    visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions.
        Defaults to 30. Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response.
        Defaults to "data_visualization.py".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is
        created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke(user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0)
        Asynchronously generates a visualization based on user instructions (must be awaited).
    invoke(user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0)
        Synchronously generates a visualization based on user instructions.
    explain_visualization_steps()
        Returns the messages recorded in the last response.
    get_log_summary()
        Retrieves a summary of logged operations if logging is enabled.
    get_plotly_graph()
        Retrieves the Plotly graph (as a dictionary) produced by the agent.
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function()
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps()
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    plotly_graph_dict = data_visualization_agent.get_plotly_graph()
    # You can render plotly_graph_dict with plotly.io.from_json or
    # something similar in a Jupyter Notebook.

    response = data_visualization_agent.get_response()
    ```

    Returns
    --------
    DataVisualizationAgent : langchain.graphs.CompiledStateGraph
        A data visualization agent implemented as a compiled state graph.
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Every constructor argument is kept in one dict so that update_params()
        # can change any of them and rebuild the graph with the same keyword set.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        # Result of the most recent invoke()/ainvoke(); None until one has run.
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.

        Running this method will reset the response to None.
        """
        # A response produced by a previous graph no longer matches the new
        # configuration, so discard it.
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Parameters
        ----------
        **kwargs
            Parameter names and new values. They are merged into the stored
            parameters, which are then passed to `make_data_visualization_agent`.
        """
        self._params.update(kwargs)
        # Rebuilding also clears any stored response (see _make_compiled_graph).
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke(self, user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0):
        """
        Asynchronously invokes the agent to generate a visualization.

        This is a coroutine and must be awaited. The response is stored in the
        'response' attribute.

        Parameters
        ----------
        user_instructions : str
            Instructions for data visualization.
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.

        Returns
        -------
        None
        """
        # The graph's ainvoke() is awaited so that `response` holds the finished
        # result rather than an un-awaited coroutine object.
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        })
        self.response = response
        return None

    def invoke(self, user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0):
        """
        Synchronously invokes the agent to generate a visualization.

        The response is stored in the 'response' attribute.

        Parameters
        ----------
        user_instructions : str
            Instructions for data visualization.
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.

        Returns
        -------
        None
        """
        # The DataFrame is passed into the graph state as a plain dict.
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        })
        self.response = response
        return None

    def explain_visualization_steps(self):
        """
        Returns the messages recorded in the last response.

        Returns
        -------
        list
            The value stored under the "messages" key of the response, or an
            empty list if there is no response or no such key.
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the agent's log location, if one was recorded.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns Markdown-formatted output.

        Returns
        -------
        str, Markdown, or None
            A "Log Path: ..." summary, or None if there is no response or the
            response has no 'data_visualization_function_path' value.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph (in dictionary form) produced by the agent.

        Returns
        -------
        dict or None
            The Plotly graph dictionary if available, otherwise None.
        """
        if self.response:
            return self.response.get("plotly_graph", None)
        return None

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            # The graph state holds the dataset as a dict (see invoke/ainvoke),
            # so it is converted back to a DataFrame here.
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown, or None
            The Python function code (an empty string if the response has no
            "data_visualization_function" key), or None if there is no response.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown, or None
            The recommended steps (an empty string if the response has no
            "recommended_steps" key), or None if there is no response.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        # NOTE(review): relies on the compiled graph object exposing a show()
        # method -- confirm against what make_data_visualization_agent returns.
        return self._compiled_graph.show()


# Agent

def make_data_visualization_agent(
Expand Down

0 comments on commit 77f55b4

Please sign in to comment.