From 198af20a129b863de5162bc5339fbf404fab9d21 Mon Sep 17 00:00:00 2001
From: Arslan Saleem
Date: Tue, 10 Dec 2024 13:33:26 +0100
Subject: [PATCH] feat(unified_config): keep global config and use throughout
 the appli… (#1467)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat(unified_config): keep global config and use throughout the application unless changed explicitly

* remove extra comments

* feat: improve deprecation warning

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Gabriele Venturi
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---
 pandasai/__init__.py                 |   8 +
 pandasai/agent/base.py               |  21 ++-
 pandasai/agent/state.py              |  23 ++-
 pandasai/config.py                   |  49 ++++++
 pandasai/smart_dataframe/__init__.py | 223 +++++++++++++++++++++++++++
 pandasai/smart_datalake/__init__.py  | 197 +++++++++++++++++++++++
 6 files changed, 514 insertions(+), 7 deletions(-)
 create mode 100644 pandasai/smart_dataframe/__init__.py
 create mode 100644 pandasai/smart_datalake/__init__.py

diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index 8ccd2fe21..31ecf943c 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -10,6 +10,7 @@
 
 import pandas as pd
 
+from pandasai.config import ConfigManager
 from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
 from pandasai.helpers.path import find_project_root
 from pandasai.helpers.request import get_pandaai_session
@@ -17,11 +18,16 @@
 from .helpers.cache import Cache
 from .dataframe.base import DataFrame
 from .data_loader.loader import DatasetLoader
+from .smart_dataframe import SmartDataframe
+from .smart_datalake import SmartDatalake
 
 # Global variable to store the current agent
 _current_agent = None
 
+config = ConfigManager()
+
+
 def clear_cache(filename: str = None):
     """Clear the cache"""
     cache = Cache(filename) if filename else Cache()
@@ -119,4 +125,6 @@ def read_csv(filepath: str) -> DataFrame:
     "chat",
     "follow_up",
     "load",
+    "SmartDataframe",
+    "SmartDatalake",
 ]
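As a sketch of how the new module-level `config` singleton is meant to be used (BambooLLM is the same default that ConfigManager falls back to; this assumes PANDASAI_API_KEY is set so it can authenticate):

    import pandasai as pai
    from pandasai.llm.bamboo_llm import BambooLLM

    # Set the global config once; every Agent created afterwards falls back
    # to it unless given an explicit (now deprecated) per-agent config.
    pai.config.set({"llm": BambooLLM(), "verbose": True})

    # Merge a single key into the existing config instead of replacing it.
    pai.config.update({"enable_cache": False})
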
" + "Please use the global configuration instead.", + DeprecationWarning, + stacklevel=2, + ) self._state = AgentState() @@ -72,9 +80,6 @@ def __init__( # Instantiate the config self._state.config = self._get_config(config) - # Set llm in state - self._state.llm = self._get_llm(self._state.config.llm) - # Validate df input with configurations self._validate_input() @@ -86,6 +91,10 @@ def __init__( save_logs=self._state.config.save_logs, verbose=self._state.config.verbose ) + # If user provided config but not llm but have setup the env for BambooLLM, will be deprecated in future + if config: + self._state.config.llm = self._get_llm(self._state.config.llm) + # Initiate VectorStore self._state.vectorstore = vectorstore @@ -298,6 +307,10 @@ def _process_query(self, query: str, output_type: Optional[str] = None): try: self._assign_prompt_id() + # To ensure the cache is set properly if config is changed in between + if self._state.config.enable_cache and self._state.cache is None: + self._state.cache = Cache() + # Generate code code, additional_dependencies = self.generate_code(query) @@ -354,7 +367,7 @@ def _get_config(self, config: Union[Config, dict]): """ config = load_config_from_json(config) - return Config(**config) + return Config(**config) if config else None def _get_llm(self, llm: Optional[LLM] = None) -> LLM: """ diff --git a/pandasai/agent/state.py b/pandasai/agent/state.py index f9bfefc5c..693832103 100644 --- a/pandasai/agent/state.py +++ b/pandasai/agent/state.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from pandasai.dataframe import DataFrame from pandasai.dataframe import VirtualDataFrame - from pandasai.llm.base import LLM @dataclass @@ -22,10 +21,9 @@ class AgentState: """ dfs: List[Union[DataFrame, VirtualDataFrame]] = field(default_factory=list) - config: Union[Config, dict] = field(default_factory=dict) + _config: Union[Config, dict] = field(default_factory=dict) memory: Memory = field(default_factory=Memory) cache: Optional[Cache] = None - llm: LLM = None vectorstore: Optional[VectorStore] = None intermediate_values: Dict[str, Any] = field(default_factory=dict) logger: Optional[Logger] = None @@ -58,3 +56,22 @@ def add_many(self, values: Dict[str, Any]): def get(self, key: str, default: Any = "") -> Any: """Fetches a value from intermediate values or returns a default.""" return self.intermediate_values.get(key, default) + + @property + def config(self): + """ + Returns the local config if set, otherwise fetches the global config. + """ + if self._config is not None: + return self._config + + import pandasai as pai + + return pai.config.get() + + @config.setter + def config(self, value: Union[Config, dict, None]): + """ + Allows setting a new config value. 
+ """ + self._config = Config(**value) if isinstance(value, dict) else value diff --git a/pandasai/config.py b/pandasai/config.py index c7edd5e18..cab856108 100644 --- a/pandasai/config.py +++ b/pandasai/config.py @@ -1,4 +1,6 @@ +from importlib.util import find_spec import json +import os import pandasai.llm as llm from pandasai.llm.base import LLM @@ -33,6 +35,53 @@ def from_dict(cls, config: Dict[str, Any]) -> "Config": return cls(**config) +class ConfigManager: + """A singleton class to manage the global configuration.""" + + _config: Config = Config() + + @classmethod + def set(cls, config_dict: Dict[str, Any]) -> None: + """Set the global configuration.""" + cls._config = Config.from_dict(config_dict) + cls.validate_llm() + + @classmethod + def get(cls) -> Config: + """Get the global configuration.""" + if cls._config.llm is None and os.environ.get("PANDASAI_API_KEY"): + from pandasai.llm.bamboo_llm import BambooLLM + + cls._config.llm = BambooLLM() + + return cls._config + + @classmethod + def update(cls, config_dict: Dict[str, Any]) -> None: + """Update the existing configuration with new values.""" + current_config = cls._config.model_dump() + current_config.update(config_dict) + cls._config = Config.from_dict(current_config) + + @classmethod + def validate_llm(cls): + """ + Initializes a default LLM if not provided. + """ + if cls._config.llm is None and os.environ.get("PANDASAI_API_KEY"): + from pandasai.llm.bamboo_llm import BambooLLM + + cls._config.llm = BambooLLM() + return + + # Check if pandasai_langchain is installed + if find_spec("pandasai_langchain") is not None: + from pandasai_langchain.langchain import LangchainLLM, is_langchain_llm + + if is_langchain_llm(cls._config.llm): + cls._config.llm = LangchainLLM(cls._config.llm) + + def load_config_from_json( override_config: Optional[Union[Config, dict]] = None, ): diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py new file mode 100644 index 000000000..438a12b80 --- /dev/null +++ b/pandasai/smart_dataframe/__init__.py @@ -0,0 +1,223 @@ +import uuid +from functools import cached_property +from io import StringIO +from typing import Any, List, Optional, Union +import warnings +import pandas as pd +from pandasai.agent import Agent +from pandasai.dataframe.base import DataFrame +from ..helpers.logger import Logger +from ..config import Config + + +class SmartDataframe: + _table_name: str + _table_description: str + _custom_head: str = None + _original_import: any + + def __init__( + self, + df: pd.DataFrame, + name: str = None, + description: str = None, + custom_head: pd.DataFrame = None, + config: Config = None, + ): + warnings.warn( + "\n" + + "*" * 80 + + "\n" + + "\033[1;33mDEPRECATION WARNING:\033[0m\n" + SmartDataframe will soon be deprecated. Please use df.chat() instead. + + "*" * 80 + + "\n", + DeprecationWarning, + stacklevel=2, + ) + + self._original_import = df + self.dataframe = self.load_df(df, name, description, custom_head) + self._agent = Agent([self.dataframe], config=config) + self._table_description = description + self._table_name = name + if custom_head is not None: + self._custom_head = custom_head.to_csv(index=False) + + def load_df(self, df, name: str, description: str, custom_head: pd.DataFrame): + if isinstance(df, pd.DataFrame): + df = DataFrame( + df, + name=name, + description=description, + ) + else: + raise ValueError("Invalid input data. 
diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py
new file mode 100644
index 000000000..438a12b80
--- /dev/null
+++ b/pandasai/smart_dataframe/__init__.py
@@ -0,0 +1,223 @@
+import uuid
+from functools import cached_property
+from io import StringIO
+from typing import Any, List, Optional, Union
+import warnings
+
+import pandas as pd
+
+from pandasai.agent import Agent
+from pandasai.dataframe.base import DataFrame
+
+from ..helpers.logger import Logger
+from ..config import Config
+
+
+class SmartDataframe:
+    _table_name: str
+    _table_description: str
+    _custom_head: str = None
+    _original_import: any
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        name: str = None,
+        description: str = None,
+        custom_head: pd.DataFrame = None,
+        config: Config = None,
+    ):
+        warnings.warn(
+            "\n"
+            + "*" * 80
+            + "\n"
+            + "\033[1;33mDEPRECATION WARNING:\033[0m\n"
+            + "SmartDataframe will soon be deprecated. Please use df.chat() instead.\n"
+            + "*" * 80
+            + "\n",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        self._original_import = df
+        self.dataframe = self.load_df(df, name, description, custom_head)
+
+        self._agent = Agent([self.dataframe], config=config)
+
+        self._table_description = description
+        self._table_name = name
+
+        if custom_head is not None:
+            self._custom_head = custom_head.to_csv(index=False)
+
+    def load_df(self, df, name: str, description: str, custom_head: pd.DataFrame):
+        if isinstance(df, pd.DataFrame):
+            df = DataFrame(
+                df,
+                name=name,
+                description=description,
+            )
+        else:
+            raise ValueError("Invalid input data. We cannot convert it to a dataframe.")
+
+        return df
+
+    def chat(self, query: str, output_type: Optional[str] = None):
+        """
+        Run a query on the dataframe.
+
+        Args:
+            query (str): Query to run on the dataframe
+            output_type (Optional[str]): Add a hint for the LLM of which
+                type should be returned by `analyze_data()` in generated
+                code. Possible values: "number", "dataframe", "plot", "string":
+                    * number - specifies that user expects to get a number
+                        as a response object
+                    * dataframe - specifies that user expects to get
+                        pandas dataframe as a response object
+                    * plot - specifies that user expects LLM to build
+                        a plot
+                    * string - specifies that user expects to get text
+                        as a response object
+
+        Raises:
+            ValueError: If the query is empty
+        """
+        return self._agent.chat(query, output_type)
+
+    @cached_property
+    def head_df(self):
+        """
+        Get the head of the dataframe as a dataframe.
+
+        Returns:
+            pd.DataFrame: Pandas dataframe
+        """
+        return self.dataframe.get_head()
+
+    @cached_property
+    def head_csv(self):
+        """
+        Get the head of the dataframe as a CSV string.
+
+        Returns:
+            str: CSV string
+        """
+        df_head = self.dataframe.get_head()
+        return df_head.to_csv(index=False)
+
+    @property
+    def last_prompt(self):
+        return self._agent.last_prompt
+
+    @property
+    def last_prompt_id(self) -> uuid.UUID:
+        return self._agent.last_prompt_id
+
+    @property
+    def last_code_generated(self):
+        return self._agent.last_code_executed
+
+    @property
+    def last_code_executed(self):
+        return self._agent.last_code_executed
+
+    def original_import(self):
+        return self._original_import
+
+    @property
+    def logger(self):
+        return self._agent.logger
+
+    @logger.setter
+    def logger(self, logger: Logger):
+        self._agent.logger = logger
+
+    @property
+    def logs(self):
+        return self._agent.context.config.logs
+
+    @property
+    def verbose(self):
+        return self._agent.context.config.verbose
+
+    @verbose.setter
+    def verbose(self, verbose: bool):
+        self._agent.context.config.verbose = verbose
+
+    @property
+    def save_logs(self):
+        return self._agent.context.config.save_logs
+
+    @save_logs.setter
+    def save_logs(self, save_logs: bool):
+        self._agent.context.config.save_logs = save_logs
+
+    @property
+    def enforce_privacy(self):
+        return self._agent.context.config.enforce_privacy
+
+    @enforce_privacy.setter
+    def enforce_privacy(self, enforce_privacy: bool):
+        self._agent.context.config.enforce_privacy = enforce_privacy
+
+    @property
+    def enable_cache(self):
+        return self._agent.context.config.enable_cache
+
+    @enable_cache.setter
+    def enable_cache(self, enable_cache: bool):
+        self._agent.context.config.enable_cache = enable_cache
+
+    @property
+    def save_charts(self):
+        return self._agent.context.config.save_charts
+
+    @save_charts.setter
+    def save_charts(self, save_charts: bool):
+        self._agent.context.config.save_charts = save_charts
+
+    @property
+    def save_charts_path(self):
+        return self._agent.context.config.save_charts_path
+
+    @save_charts_path.setter
+    def save_charts_path(self, save_charts_path: str):
+        self._agent.context.config.save_charts_path = save_charts_path
+
+    @property
+    def table_name(self):
+        return self._table_name
+
+    @property
+    def table_description(self):
+        return self._table_description
+
+    @property
+    def custom_head(self):
+        data = StringIO(self._custom_head)
+        return pd.read_csv(data)
+
+    def __len__(self):
+        return len(self.dataframe)
+
+    def __eq__(self, other):
+        return self.dataframe.equals(other.dataframe)
+
+    def __getattr__(self, name):
+        if name in self.dataframe.__dir__():
+            return getattr(self.dataframe, name)
+        else:
+            return self.__getattribute__(name)
+
+    def __getitem__(self, key):
+        return self.dataframe.__getitem__(key)
+
+    def __setitem__(self, key, value):
+        return self.dataframe.__setitem__(key, value)
+
+
+def load_smartdataframes(
+    dfs: List[Union[pd.DataFrame, Any]], config: Config
+) -> List[SmartDataframe]:
+    """
+    Load all the dataframes to be used in the smart datalake.
+
+    Args:
+        dfs (List[Union[pd.DataFrame, Any]]): List of dataframes to be used
+    """
+    smart_dfs = []
+    for df in dfs:
+        if not isinstance(df, SmartDataframe):
+            smart_dfs.append(SmartDataframe(df, config=config))
+        else:
+            smart_dfs.append(df)
+
+    return smart_dfs
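The shim keeps legacy v2-style call sites working while routing everything through Agent; a sketch of the usage it preserves (it triggers the DeprecationWarning above and assumes an LLM is configured globally or via PANDASAI_API_KEY):

    import pandas as pd
    from pandasai import SmartDataframe

    raw = pd.DataFrame({"country": ["ES", "DE"], "gdp": [1.4, 4.2]})
    sdf = SmartDataframe(raw, name="gdp", description="GDP by country")

    # Delegates to Agent.chat() under the hood.
    response = sdf.chat("Which country has the highest gdp?")
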
diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py
new file mode 100644
index 000000000..1d1a90937
--- /dev/null
+++ b/pandasai/smart_datalake/__init__.py
@@ -0,0 +1,197 @@
+import uuid
+import warnings
+
+import pandas as pd
+from typing import List, Optional, Union
+
+from pandasai.agent import Agent
+from pandasai.dataframe.base import DataFrame
+
+from ..helpers.cache import Cache
+from ..config import Config
+
+
+class SmartDatalake:
+    def __init__(
+        self,
+        dfs: List[pd.DataFrame],
+        config: Optional[Union[Config, dict]] = None,
+    ):
+        warnings.warn(
+            "\n"
+            + "*" * 80
+            + "\n"
+            + "\033[1;33mDEPRECATION WARNING:\033[0m\n"
+            + "SmartDatalake will be deprecated soon. Use df.chat() instead.\n"
+            + "*" * 80
+            + "\n",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        dfs = self.load_dfs(dfs)
+        self._agent = Agent(dfs, config=config)
+
+    def load_dfs(self, dfs: List[pd.DataFrame]):
+        load_dfs = []
+        for df in dfs:
+            if isinstance(df, pd.DataFrame):
+                load_dfs.append(DataFrame(df))
+            else:
+                raise ValueError(
+                    "Invalid input data. We cannot convert it to a dataframe."
+                )
+        return load_dfs
+
+    def chat(self, query: str, output_type: Optional[str] = None):
+        """
+        Run a query on the dataframes.
+
+        Args:
+            query (str): Query to run on the dataframe
+            output_type (Optional[str]): Add a hint for the LLM of which
+                type should be returned by `analyze_data()` in generated
+                code. Possible values: "number", "dataframe", "plot", "string":
+                    * number - specifies that user expects to get a number
+                        as a response object
+                    * dataframe - specifies that user expects to get
+                        pandas dataframe as a response object
+                    * plot - specifies that user expects LLM to build
+                        a plot
+                    * string - specifies that user expects to get text
+                        as a response object
+                If no `output_type` is specified, the type can be any
+                of the above or "text".
+
+        Raises:
+            ValueError: If the query is empty
+        """
+        return self._agent.chat(query, output_type)
+
+    def clear_memory(self):
+        """
+        Clears the memory
+        """
+        self._agent.clear_memory()
+
+    @property
+    def last_prompt(self):
+        return self._agent.last_prompt
+
+    @property
+    def last_prompt_id(self) -> uuid.UUID:
+        """Return the id of the last prompt that was run."""
+        if self._agent.last_prompt_id is None:
+            raise ValueError("Pandas AI has not been run yet.")
+        return self._agent.last_prompt_id
+
+    @property
+    def logs(self):
+        return self._agent.logger.logs
+
+    @property
+    def logger(self):
+        return self._agent.logger
+
+    @logger.setter
+    def logger(self, logger):
+        self._agent.logger = logger
+
+    @property
+    def config(self):
+        return self._agent.context.config
+
+    @property
+    def cache(self):
+        return self._agent.context.cache
+
+    @property
+    def verbose(self):
+        return self._agent.context.config.verbose
+
+    @verbose.setter
+    def verbose(self, verbose: bool):
+        self._agent.context.config.verbose = verbose
+        self._agent.logger.verbose = verbose
+
+    @property
+    def save_logs(self):
+        return self._agent.context.config.save_logs
+
+    @save_logs.setter
+    def save_logs(self, save_logs: bool):
+        self._agent.context.config.save_logs = save_logs
+        self._agent.logger.save_logs = save_logs
+
+    @property
+    def enforce_privacy(self):
+        return self._agent.context.config.enforce_privacy
+
+    @enforce_privacy.setter
+    def enforce_privacy(self, enforce_privacy: bool):
+        self._agent.context.config.enforce_privacy = enforce_privacy
+
+    @property
+    def enable_cache(self):
+        return self._agent.context.config.enable_cache
+
+    @enable_cache.setter
+    def enable_cache(self, enable_cache: bool):
+        self._agent.context.config.enable_cache = enable_cache
+        if enable_cache:
+            if self.cache is None:
+                self._cache = Cache()
+        else:
+            self._cache = None
+
+    @property
+    def use_error_correction_framework(self):
+        return self._agent.context.config.use_error_correction_framework
+
+    @use_error_correction_framework.setter
+    def use_error_correction_framework(self, use_error_correction_framework: bool):
+        self._agent.context.config.use_error_correction_framework = (
+            use_error_correction_framework
+        )
+
+    @property
+    def custom_prompts(self):
+        return self._agent.context.config.custom_prompts
+
+    @custom_prompts.setter
+    def custom_prompts(self, custom_prompts: dict):
+        self._agent.context.config.custom_prompts = custom_prompts
+
+    @property
+    def save_charts(self):
+        return self._agent.context.config.save_charts
+
+    @save_charts.setter
+    def save_charts(self, save_charts: bool):
+        self._agent.context.config.save_charts = save_charts
+
+    @property
+    def save_charts_path(self):
+        return self._agent.context.config.save_charts_path
+
+    @save_charts_path.setter
+    def save_charts_path(self, save_charts_path: str):
+        self._agent.context.config.save_charts_path = save_charts_path
+
+    @property
+    def last_code_generated(self):
+        return self._agent.last_code_generated
+
+    @property
+    def last_code_executed(self):
+        return self._agent.last_code_executed
+
+    @property
+    def last_result(self):
+        return self._agent.last_result
+
+    @property
+    def last_error(self):
+        return self._agent.last_error
+
+    @property
+    def dfs(self):
+        return self._agent.context.dfs
+
+    @property
+    def memory(self):
+        return self._agent.context.memory
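And the equivalent legacy multi-dataframe entry point preserved by SmartDatalake, a sketch under the same assumptions as the SmartDataframe example above:

    import pandas as pd
    from pandasai import SmartDatalake

    employees = pd.DataFrame({"id": [1, 2], "dept": ["eng", "ops"]})
    salaries = pd.DataFrame({"id": [1, 2], "salary": [95000, 80000]})

    # Each plain pandas frame is wrapped in DataFrame by load_dfs().
    lake = SmartDatalake([employees, salaries])
    answer = lake.chat("What is the average salary in the eng dept?")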