Skip to content

Commit

Permalink
fix: update type hint for config parameter (#526)
Browse files Browse the repository at this point in the history
* (fix): add `dict` as a possible type of object to pass in methods and
  functions where `config` parameter occurs
* (feat): add logging of an exception to `load_config()` instead of
  silently suppressing an error
  • Loading branch information
nautics889 committed Sep 4, 2023
1 parent 167a15a commit 0db9bc7
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 41 deletions.
58 changes: 30 additions & 28 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,22 @@ class PandasAI:
"""

_dl: SmartDatalake = None
_config: Config
_config: [Config | dict]

def __init__(
self,
llm=None,
conversational=False,
verbose=False,
enforce_privacy=False,
save_charts=False,
save_charts_path="",
enable_cache=True,
middlewares=None,
custom_whitelisted_dependencies=None,
enable_logging=True,
non_default_prompts: Optional[Dict[str, Type[Prompt]]] = None,
callback: Optional[BaseCallback] = None,
self,
llm=None,
conversational=False,
verbose=False,
enforce_privacy=False,
save_charts=False,
save_charts_path="",
enable_cache=True,
middlewares=None,
custom_whitelisted_dependencies=None,
enable_logging=True,
non_default_prompts: Optional[Dict[str, Type[Prompt]]] = None,
callback: Optional[BaseCallback] = None,
):
"""
__init__ method of the Class PandasAI
Expand Down Expand Up @@ -142,8 +142,10 @@ def __init__(
# noinspection PyArgumentList
# https://stackoverflow.com/questions/61226587/pycharm-does-not-recognize-logging-basicconfig-handlers-argument

warnings.warn("`PandasAI` (class) is deprecated since v1.0 and will be removed "
"in a future release. Please use `SmartDataframe` instead.")
warnings.warn(
"`PandasAI` (class) is deprecated since v1.0 and will be removed "
"in a future release. Please use `SmartDataframe` instead."
)

self._config = Config(
conversational=conversational,
Expand All @@ -161,12 +163,12 @@ def __init__(
)

def run(
self,
data_frame: Union[pd.DataFrame, List[pd.DataFrame]],
prompt: str,
show_code: bool = False,
anonymize_df: bool = True,
use_error_correction_framework: bool = True,
self,
data_frame: Union[pd.DataFrame, List[pd.DataFrame]],
prompt: str,
show_code: bool = False,
anonymize_df: bool = True,
use_error_correction_framework: bool = True,
) -> Union[str, pd.DataFrame]:
"""
Run the PandasAI to make Dataframes Conversational.
Expand Down Expand Up @@ -198,12 +200,12 @@ def run(
return self._dl.chat(prompt)

def __call__(
self,
data_frame: Union[pd.DataFrame, List[pd.DataFrame]],
prompt: str,
show_code: bool = False,
anonymize_df: bool = True,
use_error_correction_framework: bool = True,
self,
data_frame: Union[pd.DataFrame, List[pd.DataFrame]],
prompt: str,
show_code: bool = False,
anonymize_df: bool = True,
use_error_correction_framework: bool = True,
) -> Union[str, pd.DataFrame]:
"""
__call__ method of PandasAI class. It calls the `run` method.
Expand Down
9 changes: 7 additions & 2 deletions pandasai/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import json
import logging
from typing import Optional

from . import llm, middlewares, callbacks
from .helpers.path import find_closest
from .schemas.df_config import Config

logger = logging.getLogger(__name__)


def load_config(override_config: Config = None):
def load_config(override_config: Optional[Config | dict] = None):
config = {}

if override_config is None:
Expand All @@ -27,7 +32,7 @@ def load_config(override_config: Config = None):
if config.get("callback") and not override_config.get("callback"):
config["callback"] = getattr(callbacks, config["callback"])()
except Exception:
pass
logger.error("Could not load configuration", exc_info=True)

if override_config:
config.update(override_config)
Expand Down
6 changes: 3 additions & 3 deletions pandasai/helpers/code_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class CodeManager:
_dfs: List
_middlewares: List[Middleware] = [ChartsMiddleware()]
_config: Config
_config: [Config | dict]
_logger: Logger = None
_additional_dependencies: List[dict] = []

Expand All @@ -30,12 +30,12 @@ class CodeManager:
def __init__(
self,
dfs: List,
config: Config,
config: [Config | dict],
logger: Logger,
):
"""
Args:
config (Config, optional): Config to be used. Defaults to None.
config ([Config | dict], optional): Config to be used. Defaults to None.
logger (Logger, optional): Logger to be used. Defaults to None.
"""

Expand Down
6 changes: 3 additions & 3 deletions pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..helpers.logger import Logger
from ..helpers.df_config_manager import DfConfigManager
from ..helpers.from_google_sheets import from_google_sheets
from typing import List, Union
from typing import List, Union, Optional
from ..middlewares.base import Middleware
from ..helpers.df_info import DataFrameType, df_type
from .abstract_df import DataframeAbstract
Expand All @@ -53,7 +53,7 @@ def __init__(
df: DataFrameType,
name: str = None,
description: str = None,
config: Config = None,
config: Optional[Config | dict] = None,
sample_head: pd.DataFrame = None,
logger: Logger = None,
):
Expand All @@ -62,7 +62,7 @@ def __init__(
df (Union[pd.DataFrame, pl.DataFrame]): Pandas or Polars dataframe
name (str, optional): Name of the dataframe. Defaults to None.
description (str, optional): Description of the dataframe. Defaults to "".
config (Config, optional): Config to be used. Defaults to None.
config ([Config | dict], optional): Config to be used. Defaults to None.
logger (Logger, optional): Logger to be used. Defaults to None.
"""
self._original_import = df
Expand Down
10 changes: 5 additions & 5 deletions pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

class SmartDatalake:
_dfs: List[DataFrameType]
_config: Config
_config: [Config | dict]
_llm: LLM
_cache: Cache = None
_logger: Logger
Expand All @@ -60,14 +60,14 @@ class SmartDatalake:
def __init__(
self,
dfs: List[Union[DataFrameType, Any]],
config: Config = None,
config: Optional[Config | dict] = None,
logger: Logger = None,
memory: Memory = None,
):
"""
Args:
dfs (List[Union[DataFrameType, Any]]): List of dataframes to be used
config (Config, optional): Config to be used. Defaults to None.
config ([Config | dict], optional): Config to be used. Defaults to None.
logger (Logger, optional): Logger to be used. Defaults to None.
"""

Expand Down Expand Up @@ -135,12 +135,12 @@ def _load_dfs(self, dfs: List[Union[DataFrameType, Any]]):
smart_dfs.append(df)
self._dfs = smart_dfs

def _load_config(self, config: Config):
def _load_config(self, config: [Config | dict]):
"""
Load a config to be used to run the queries.
Args:
config (Config): Config to be used
config ([Config | dict]): Config to be used
"""

config = load_config(config)
Expand Down

0 comments on commit 0db9bc7

Please sign in to comment.