Skip to content

Commit

Permalink
feat: add output_type parameter (#519) (#562)
Browse files Browse the repository at this point in the history
* fix(connectors): make sqlalchemy non-optional and check yfinance imports (#561)

* feat: use duckdb for the cache

* Release v1.2.2

* feat: `output_type` parameter (#519)

* (feat): update prompt template in `GeneratePythonCodePrompt`, add
  `output_type_hint` variable to be interpolated
* (feat): update `.chat()` method for `SmartDataframe` and
  `SmartDatalake`, add optional `output_type` parameter
* (feat): add `get_output_type_hint()` in `GeneratePythonCodePrompt`
  class
* (feat): add "output_type_hint" to `default_values` when forming
  prompt's template context
* (tests): update tests in `TestGeneratePythonCodePrompt`
* (tests): add tests for checking `output_type` interpolation into a
  prompt

* refactor: `output_type` parameter (#519)

* (refactor): update setting value for `{output_type_hint}` in prompt
  class

* fix: `output_type` parameter (#519)

* (tests): fix error in `TestGeneratePythonCodePrompt` where the actual
  prompt content and the expected prompt content were mixed up (which
  caused the tests to fail)
* (refactor): update test method `test_str_with_args()` with
  `parametrize` decorator, remove duplication of code (DRY)

* tests: `output_type` parameter (#519)

* (tests): parametrizing `test_run_passing_output_type()` with different
  output types.

* refactor: `output_type` parameter (#519)

* (refactor): move output types and their hints to a separate python
  module
* (feat): validation for inappropriate output type and value
* (tests): update tests accordingly, add a new test method to check if
  the message about incorrect output type is added to logs

* chore: `output_type` parameter (#519)

* (chore): remove unused lines

* refactor: `output_type` parameter (#519)

* (refactor): pass `output_type_hint` in the `default_values`, don't
  bother with setting this template variable in prompt's `__init__()`
* (fix): correct templates examples, add them as a part of
  `output_type_hint`
* (refactor): use `df_type()` (utility function) to get rid of
  third-party package imports in _output_types.py
* (refactor): add logging to the factory-function
  `output_type_factory()` to enhance verbosity for behaviour inspection

---------
  • Loading branch information
nautics889 authored Sep 18, 2023
1 parent b489c9c commit 69e2b14
Show file tree
Hide file tree
Showing 14 changed files with 939 additions and 499 deletions.
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ nav:
- Documents Building: building_docs.md
- License: license.md
extra:
version: "1.2.1"
version: "1.2.2"
plugins:
- search
- mkdocstrings:
Expand Down
2 changes: 1 addition & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def last_prompt(self) -> str:

def clear_cache(filename: str = None):
"""Clear the cache"""
cache = Cache(filename or "cache")
cache = Cache(filename or "cache_db")
cache.clear()


Expand Down
15 changes: 10 additions & 5 deletions pandasai/connectors/yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import yfinance as yf
import pandas as pd
from .base import ConnectorConfig, BaseConnector
import time
Expand All @@ -15,6 +14,13 @@ class YahooFinanceConnector(BaseConnector):
_cache_interval: int = 600 # 10 minutes

def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
try:
import yfinance
except ImportError:
raise ImportError(
"Could not import yfinance python package. "
"Please install it with `pip install yfinance`."
)
yahoo_finance_config = ConnectorConfig(
dialect="yahoo_finance",
username="",
Expand All @@ -27,6 +33,7 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
)
self._cache_interval = cache_interval
super().__init__(yahoo_finance_config)
self.ticker = yfinance.Ticker(self._config.table)

def head(self):
"""
Expand All @@ -36,8 +43,7 @@ def head(self):
DataFrameType: The head of the data source that the connector is
connected to.
"""
ticker = yf.Ticker(self._config.table)
head_data = ticker.history(period="5d")
head_data = self.ticker.history(period="5d")
return head_data

def _get_cache_path(self, include_additional_filters: bool = False):
Expand Down Expand Up @@ -105,8 +111,7 @@ def execute(self):
return pd.read_csv(cached_path)

# Use yfinance to retrieve historical stock data
ticker = yf.Ticker(self._config.table)
stock_data = ticker.history(period="max")
stock_data = self.ticker.history(period="max")

# Save the result to the cache
stock_data.to_csv(self._get_cache_path(), index=False)
Expand Down
41 changes: 21 additions & 20 deletions pandasai/helpers/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Cache module for caching queries."""
import glob
import os
import shelve
import glob
import duckdb
from .path import find_project_root


Expand All @@ -13,17 +12,20 @@ class Cache:
filename (str): filename to store the cache.
"""

def __init__(self, filename="cache"):
# define cache directory and create directory if it does not exist
def __init__(self, filename="cache_db"):
# Define cache directory and create directory if it does not exist
try:
cache_dir = os.path.join((find_project_root()), "cache")
cache_dir = os.path.join(find_project_root(), "cache")
except ValueError:
cache_dir = os.path.join(os.getcwd(), "cache")

os.makedirs(cache_dir, mode=0o777, exist_ok=True)

self.filepath = os.path.join(cache_dir, filename)
self.cache = shelve.open(self.filepath)
self.filepath = os.path.join(cache_dir, filename + ".db")
self.connection = duckdb.connect(self.filepath)
self.connection.execute(
"CREATE TABLE IF NOT EXISTS cache (key STRING, value STRING)"
)

def set(self, key: str, value: str) -> None:
"""Set a key value pair in the cache.
Expand All @@ -32,8 +34,7 @@ def set(self, key: str, value: str) -> None:
key (str): key to store the value.
value (str): value to store in the cache.
"""

self.cache[key] = value
self.connection.execute("INSERT INTO cache VALUES (?, ?)", [key, value])

def get(self, key: str) -> str:
"""Get a value from the cache.
Expand All @@ -44,31 +45,31 @@ def get(self, key: str) -> str:
Returns:
str: value from the cache.
"""

return self.cache.get(key)
result = self.connection.execute("SELECT value FROM cache WHERE key=?", [key])
row = result.fetchone()
if row:
return row[0]
else:
return None

def delete(self, key: str) -> None:
"""Delete a key value pair from the cache.
Args:
key (str): key to delete the value from the cache.
"""

if key in self.cache:
del self.cache[key]
self.connection.execute("DELETE FROM cache WHERE key=?", [key])

def close(self) -> None:
"""Close the cache."""

self.cache.close()
self.connection.close()

def clear(self) -> None:
"""Clean the cache."""

self.cache.clear()
self.connection.execute("DELETE FROM cache")

def destroy(self) -> None:
"""Destroy the cache."""
self.cache.close()
self.connection.close()
for cache_file in glob.glob(self.filepath + ".*"):
os.remove(cache_file)
69 changes: 69 additions & 0 deletions pandasai/helpers/output_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import logging
from typing import Union, Optional

from ._output_types import (
NumberOutputType,
DataFrameOutputType,
PlotOutputType,
StringOutputType,
DefaultOutputType,
)
from .. import Logger


# Maps each supported `output_type` name to the class implementing it.
# `output_type_factory()` looks names up here and falls back to
# `DefaultOutputType` for anything that is not a key of this map.
output_types_map = {
    "number": NumberOutputType,
    "dataframe": DataFrameOutputType,
    "plot": PlotOutputType,
    "string": StringOutputType,
}


def output_type_factory(
    output_type: Optional[str] = None, logger: Optional[Logger] = None
) -> Union[
    NumberOutputType,
    DataFrameOutputType,
    PlotOutputType,
    StringOutputType,
    DefaultOutputType,
]:
    """
    Factory function to get appropriate instance for output type.

    Uses `output_types_map` to determine the output type class. A name that
    is not a key of `output_types_map` (including None) yields an instance
    of `DefaultOutputType`; no exception is raised for unknown names.

    Args:
        output_type (Optional[str]): A name of the output type.
            Defaults to None, an instance of `DefaultOutputType` will be
            returned.
        logger (Optional[Logger]): If passed, collects logs about
            correctness of the `output_type` argument and what kind of
            OutputType is created.

    Returns:
        (Union[
            NumberOutputType,
            DataFrameOutputType,
            PlotOutputType,
            StringOutputType,
            DefaultOutputType
        ]): An instance of the output type.
    """
    # None is the documented way to ask for the default type, so only a
    # non-None name that is missing from the map deserves a warning.
    if output_type is not None and output_type not in output_types_map and logger:
        possible_types_msg = ", ".join(f"'{type_}'" for type_ in output_types_map)
        # The trailing space after the first sentence keeps the adjacent
        # f-strings from being glued together as "...'foo'.Possible ...".
        logger.log(
            f"Unknown value for the parameter `output_type`: '{output_type}'. "
            f"Possible values are: {possible_types_msg} and None for default "
            f"output type (miscellaneous).",
            level=logging.WARNING,
        )

    output_type_helper = output_types_map.get(output_type, DefaultOutputType)()

    if logger:
        logger.log(
            f"{output_type_helper.__class__} is going to be used.", level=logging.DEBUG
        )

    return output_type_helper
169 changes: 169 additions & 0 deletions pandasai/helpers/output_types/_output_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import re
from decimal import Decimal
from abc import abstractmethod, ABC
from typing import Any, Iterable

from ..df_info import df_type


class BaseOutputType(ABC):
    """
    Abstract description of one kind of result a generated script may return.

    Subclasses supply a `name`, a `template_hint` and a value check;
    `validate()` combines the type check and the value check for a result
    dict.
    """

    @property
    @abstractmethod
    def template_hint(self) -> str:
        """Text describing the expected result dict for this output type."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier compared against the 'type' entry of a result dict."""

    def _validate_type(self, actual_type: str) -> bool:
        """Return True when `actual_type` equals this output type's name."""
        return not (actual_type != self.name)

    @abstractmethod
    def _validate_value(self, actual_value):
        """Return a truthy value when `actual_value` fits this output type."""

    def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]:
        """
        Validate 'type' and 'value' from the result dict.

        Args:
            result (dict[str, Any]): The result of code execution in
                dict representation. Should have the following schema:
                {
                    "type": <output_type_name>,
                    "value": <generated_value>
                }
                Missing keys are read as None.

        Returns:
            (tuple(bool, Iterable(str)):
                Boolean value whether the result matches output type
                and collection of logs containing messages about
                'type' or 'value' mismatches.
        """
        actual_type = result.get("type")
        actual_value = result.get("value")
        mismatch_messages = []

        # Both checks always run, so a single call can report a type
        # mismatch and a value mismatch together.
        type_matches = self._validate_type(actual_type)
        if not type_matches:
            mismatch_messages.append(
                f"The result dict contains inappropriate 'type'. "
                f"Expected '{self.name}', actual '{actual_type}'."
            )

        value_matches = self._validate_value(actual_value)
        if not value_matches:
            mismatch_messages.append(
                f"Actual value {repr(actual_value)} seems to be inappropriate "
                f"for the type '{self.name}'."
            )

        return bool(type_matches and value_matches), mismatch_messages


class NumberOutputType(BaseOutputType):
    """Output type for results whose value is a single real number."""

    @property
    def template_hint(self):
        """Prompt text describing the expected 'number' result dict."""
        # NOTE(review): continuation lines carry no leading whitespace, as
        # shown in the reviewed source; confirm against the prompt template.
        return """- type (must be "number")
- value (must be a number)
Example output: { "type": "number", "value": 125 }"""

    @property
    def name(self):
        """Name matched against the 'type' entry of a result dict."""
        return "number"

    def _validate_value(self, actual_value: Any) -> bool:
        """
        Return True when `actual_value` is a real number.

        Accepts everything the previous `(int, float, Decimal)` check
        accepted, plus any other type registered as `numbers.Real`, such
        as NumPy integer scalars, which are not subclasses of `int`.
        Complex numbers and numeric strings are rejected.
        """
        # Imported locally so the module's import block stays untouched.
        import numbers

        # Decimal is registered only as numbers.Number, not numbers.Real,
        # so it has to be listed explicitly.
        return isinstance(actual_value, (numbers.Real, Decimal))


class DataFrameOutputType(BaseOutputType):
    """Output type for results whose value is a dataframe."""

    @property
    def template_hint(self):
        """Prompt text describing the expected 'dataframe' result dict."""
        # NOTE(review): continuation lines carry no leading whitespace, as
        # shown in the reviewed source; confirm against the prompt template.
        return (
            '- type (must be "dataframe")\n'
            "- value (must be a pandas dataframe)\n"
            'Example output: { "type": "dataframe", "value": pd.DataFrame({...}) }'
        )

    @property
    def name(self):
        """Name matched against the 'type' entry of a result dict."""
        return "dataframe"

    def _validate_value(self, actual_value: Any) -> bool:
        """Return True when `df_type()` gives a truthy result for the value."""
        # presumably df_type() returns a falsy value for objects that are
        # not a supported dataframe -- verify against helpers/df_info.py
        detected_kind = df_type(actual_value)
        return bool(detected_kind)


class PlotOutputType(BaseOutputType):
    """Output type for results whose value is a path to a saved plot image."""

    @property
    def template_hint(self):
        """Prompt text describing the expected 'plot' result dict."""
        # NOTE(review): continuation lines carry no leading whitespace, as
        # shown in the reviewed source; confirm against the prompt template.
        return (
            '- type (must be "plot")\n'
            "- value (must be a string containing the path of the plot image)\n"
            'Example output: { "type": "plot", "value": "export/charts/temp.png" }'
        )

    @property
    def name(self):
        """Name matched against the 'type' entry of a result dict."""
        return "plot"

    def _validate_value(self, actual_value: Any) -> bool:
        """Return True when the value is a string matching the path pattern."""
        if not isinstance(actual_value, str):
            return False

        # Two alternatives: a path starting with "/" made of [\w.-]
        # segments, or a first segment without whitespace or "/" followed
        # by optional "/segment" parts.
        path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$"
        return re.match(path_to_plot_pattern, actual_value) is not None


class StringOutputType(BaseOutputType):
    """Output type for results whose value is a conversational string."""

    @property
    def template_hint(self):
        """Prompt text describing the expected 'string' result dict."""
        # NOTE(review): continuation lines carry no leading whitespace, as
        # shown in the reviewed source; confirm against the prompt template.
        return (
            '- type (must be "string")\n'
            "- value (must be a conversational answer, as a string)\n"
            'Example output: { "type": "string", '
            '"value": "The highest salary is $9,000." }'
        )

    @property
    def name(self):
        """Name matched against the 'type' entry of a result dict."""
        return "string"

    def _validate_value(self, actual_value: Any) -> bool:
        """Return True when the value is a `str` instance."""
        return isinstance(actual_value, str)


class DefaultOutputType(BaseOutputType):
    """Catch-all output type whose validation accepts every result."""

    @property
    def template_hint(self):
        """Prompt text listing every result dict shape that is allowed."""
        # NOTE(review): continuation lines carry no leading whitespace, as
        # shown in the reviewed source; confirm against the prompt template.
        hint_lines = (
            '- type (possible values "string", "number", "dataframe", "plot")',
            "- value (can be a string, a dataframe or the path of the plot, "
            "NOT a dictionary)",
            "Examples:",
            '{ "type": "string", "value": "The highest salary is $9,000." }',
            "or",
            '{ "type": "number", "value": 125 }',
            "or",
            '{ "type": "dataframe", "value": pd.DataFrame({...}) }',
            "or",
            '{ "type": "plot", "value": "export/charts/temp.png" }',
        )
        return "\n".join(hint_lines)

    @property
    def name(self):
        """Name of this output type."""
        return "default"

    def _validate_type(self, actual_type: str) -> bool:
        """Accept any 'type' entry."""
        return True

    def _validate_value(self, actual_value: Any) -> bool:
        """Accept any 'value' entry."""
        return True

    def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable]:
        """
        Validate 'type' and 'value' from the result dict.

        Returns:
            (tuple[bool, Iterable]): Always `(True, ())`, since the
                `DefaultOutputType` is supposed to have no validation.
        """
        no_mismatches = ()
        return True, no_mismatches
Loading

0 comments on commit 69e2b14

Please sign in to comment.