From 237b76532b5756290edd8e5e276df90c648eb56a Mon Sep 17 00:00:00 2001 From: Matt Dancho Date: Wed, 8 Jan 2025 20:27:42 -0500 Subject: [PATCH] #16 experimental DataCleaningAgent class --- ai_data_science_team/agents/__init__.py | 2 +- .../agents/data_cleaning_agent.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/ai_data_science_team/agents/__init__.py b/ai_data_science_team/agents/__init__.py index 10ea842..1e00b8a 100644 --- a/ai_data_science_team/agents/__init__.py +++ b/ai_data_science_team/agents/__init__.py @@ -1,4 +1,4 @@ -from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent +from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent, DataCleaningAgent from ai_data_science_team.agents.feature_engineering_agent import make_feature_engineering_agent from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent diff --git a/ai_data_science_team/agents/data_cleaning_agent.py b/ai_data_science_team/agents/data_cleaning_agent.py index 7da1130..a7fc447 100644 --- a/ai_data_science_team/agents/data_cleaning_agent.py +++ b/ai_data_science_team/agents/data_cleaning_agent.py @@ -13,6 +13,8 @@ from langgraph.types import Command from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph.state import CompiledStateGraph + import os import io import pandas as pd @@ -33,6 +35,76 @@ AGENT_NAME = "data_cleaning_agent" LOG_PATH = os.path.join(os.getcwd(), "logs/") + +# Class +class DataCleaningAgent(CompiledStateGraph): + """ + Wraps a compiled data cleaning agent (CompiledStateGraph) and extends + its functionality for data cleaning purposes. All methods not found on + this class automatically delegate to self._compiled_graph via __getattr__. + """ + def __init__( + self, + model, + n_samples=30, + log=False, + log_path=None, + file_name="data_cleaner.py", + overwrite=True, + human_in_the_loop=False, + bypass_recommended_steps=False, + bypass_explain_code=False + ): + self._params = { + "model": model, + "n_samples": n_samples, + "log": log, + "log_path": log_path, + "file_name": file_name, + "overwrite": overwrite, + "human_in_the_loop": human_in_the_loop, + "bypass_recommended_steps": bypass_recommended_steps, + "bypass_explain_code": bypass_explain_code, + } + self._compiled_graph = self._make_compiled_graph() + + def _make_compiled_graph(self): + return make_data_cleaning_agent(**self._params) + + def update_params(self, **kwargs): + """ + Update one or more parameters at once, then rebuild the compiled graph. + e.g. agent.update_params(model=new_llm, n_samples=100) + """ + self._params.update(kwargs) + self._compiled_graph = self._make_compiled_graph() + + def __getattr__(self, name: str): + """ + Delegate attribute access to `_compiled_graph` if `name` is not + found in this instance. This 'inherits' methods from the compiled graph. + """ + return getattr(self._compiled_graph, name) + + # def __dir__(self): + # """ + # Combine this class’s attributes with those of _compiled_graph + # for improved autocompletion in some IDEs. + # """ + # return sorted( + # set( + # dir(type(self)) + # + super().__dir__() + # + list(self.__dict__.keys()) + # + dir(self._compiled_graph) + # ) + # ) + # # return super().__dir__() + [str(k) for k in self.__dict__.keys()] + + + + + # Agent def make_data_cleaning_agent(