Skip to content

Commit

Permalink
feat(unified_config): keep global config and use throughout the appli… (
Browse files Browse the repository at this point in the history
#1467)

* feat(unified_config): keep a global config and use it throughout the application unless changed explicitly

* remove extra comments

* feat: improve deprecation warning

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Gabriele Venturi <[email protected]>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent b12cb49 commit 198af20
Show file tree
Hide file tree
Showing 6 changed files with 514 additions and 7 deletions.
8 changes: 8 additions & 0 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@

import pandas as pd

from pandasai.config import ConfigManager
from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
from pandasai.helpers.path import find_project_root
from pandasai.helpers.request import get_pandaai_session
from .agent import Agent
from .helpers.cache import Cache
from .dataframe.base import DataFrame
from .data_loader.loader import DatasetLoader
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake

# Global variable to store the current agent
_current_agent = None


config = ConfigManager()


def clear_cache(filename: str = None):
"""Clear the cache"""
cache = Cache(filename) if filename else Cache()
Expand Down Expand Up @@ -119,4 +125,6 @@ def read_csv(filepath: str) -> DataFrame:
"chat",
"follow_up",
"load",
"SmartDataframe",
"SmartDatalake",
]
21 changes: 17 additions & 4 deletions pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..llm.base import LLM
from importlib.util import find_spec
from ..config import Config
import warnings


class Agent:
Expand All @@ -59,6 +60,13 @@ def __init__(
memory_size (int, optional): Conversation history to use during chat.
Defaults to 1.
"""
if config is not None:
warnings.warn(
"The 'config' parameter is deprecated and will be removed in a future version. "
"Please use the global configuration instead.",
DeprecationWarning,
stacklevel=2,
)

self._state = AgentState()

Expand All @@ -72,9 +80,6 @@ def __init__(
# Instantiate the config
self._state.config = self._get_config(config)

# Set llm in state
self._state.llm = self._get_llm(self._state.config.llm)

# Validate df input with configurations
self._validate_input()

Expand All @@ -86,6 +91,10 @@ def __init__(
save_logs=self._state.config.save_logs, verbose=self._state.config.verbose
)

# If the user provided a config without an LLM but has set up the environment for BambooLLM; this fallback will be deprecated in a future version
if config:
self._state.config.llm = self._get_llm(self._state.config.llm)

# Initiate VectorStore
self._state.vectorstore = vectorstore

Expand Down Expand Up @@ -298,6 +307,10 @@ def _process_query(self, query: str, output_type: Optional[str] = None):
try:
self._assign_prompt_id()

# To ensure the cache is set properly if config is changed in between
if self._state.config.enable_cache and self._state.cache is None:
self._state.cache = Cache()

# Generate code
code, additional_dependencies = self.generate_code(query)

Expand Down Expand Up @@ -354,7 +367,7 @@ def _get_config(self, config: Union[Config, dict]):
"""

config = load_config_from_json(config)
return Config(**config)
return Config(**config) if config else None

def _get_llm(self, llm: Optional[LLM] = None) -> LLM:
"""
Expand Down
23 changes: 20 additions & 3 deletions pandasai/agent/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
if TYPE_CHECKING:
from pandasai.dataframe import DataFrame
from pandasai.dataframe import VirtualDataFrame
from pandasai.llm.base import LLM


@dataclass
Expand All @@ -22,10 +21,9 @@ class AgentState:
"""

dfs: List[Union[DataFrame, VirtualDataFrame]] = field(default_factory=list)
config: Union[Config, dict] = field(default_factory=dict)
_config: Union[Config, dict] = field(default_factory=dict)
memory: Memory = field(default_factory=Memory)
cache: Optional[Cache] = None
llm: LLM = None
vectorstore: Optional[VectorStore] = None
intermediate_values: Dict[str, Any] = field(default_factory=dict)
logger: Optional[Logger] = None
Expand Down Expand Up @@ -58,3 +56,22 @@ def add_many(self, values: Dict[str, Any]):
def get(self, key: str, default: Any = "") -> Any:
"""Fetches a value from intermediate values or returns a default."""
return self.intermediate_values.get(key, default)

@property
def config(self):
"""
Returns the local config if set, otherwise fetches the global config.
"""
if self._config is not None:
return self._config

import pandasai as pai

return pai.config.get()

@config.setter
def config(self, value: Union[Config, dict, None]):
"""
Allows setting a new config value.
"""
self._config = Config(**value) if isinstance(value, dict) else value
49 changes: 49 additions & 0 deletions pandasai/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from importlib.util import find_spec
import json
import os

import pandasai.llm as llm
from pandasai.llm.base import LLM
Expand Down Expand Up @@ -33,6 +35,53 @@ def from_dict(cls, config: Dict[str, Any]) -> "Config":
return cls(**config)


class ConfigManager:
    """A singleton-style holder for the global PandasAI configuration.

    All state lives on the class itself (``_config``); the classmethods
    below read or mutate that single shared ``Config`` instance.
    """

    # The one shared Config instance used application-wide.
    _config: Config = Config()

    @classmethod
    def set(cls, config_dict: Dict[str, Any]) -> None:
        """Replace the global configuration with values from *config_dict*."""
        cls._config = Config.from_dict(config_dict)
        cls.validate_llm()

    @classmethod
    def get(cls) -> Config:
        """Return the global configuration.

        Lazily installs BambooLLM as the default LLM when none has been
        configured but the ``PANDASAI_API_KEY`` environment variable is set.
        """
        cls._ensure_default_llm()
        return cls._config

    @classmethod
    def update(cls, config_dict: Dict[str, Any]) -> None:
        """Merge *config_dict* into the existing configuration.

        NOTE(review): unlike :meth:`set`, this does not re-run
        :meth:`validate_llm` afterwards — confirm that is intentional.
        """
        current_config = cls._config.model_dump()
        current_config.update(config_dict)
        cls._config = Config.from_dict(current_config)

    @classmethod
    def validate_llm(cls):
        """Initialize a default LLM if none was provided.

        Also wraps a raw LangChain LLM in ``LangchainLLM`` when the optional
        ``pandasai_langchain`` integration is installed.
        """
        if cls._ensure_default_llm():
            # A default LLM was just installed; nothing further to wrap.
            return

        # Check if pandasai_langchain is installed
        if find_spec("pandasai_langchain") is not None:
            from pandasai_langchain.langchain import LangchainLLM, is_langchain_llm

            if is_langchain_llm(cls._config.llm):
                cls._config.llm = LangchainLLM(cls._config.llm)

    @classmethod
    def _ensure_default_llm(cls) -> bool:
        """Install BambooLLM as the default LLM when possible.

        Shared by :meth:`get` and :meth:`validate_llm` to avoid duplicating
        the fallback logic. Returns True when a default LLM was created on
        this call, False otherwise.
        """
        if cls._config.llm is None and os.environ.get("PANDASAI_API_KEY"):
            # Imported lazily so the API-key-based LLM is only constructed
            # when actually needed.
            from pandasai.llm.bamboo_llm import BambooLLM

            cls._config.llm = BambooLLM()
            return True
        return False


def load_config_from_json(
override_config: Optional[Union[Config, dict]] = None,
):
Expand Down
Loading

0 comments on commit 198af20

Please sign in to comment.