
Commit 1fd67db
refactor: use parquet for caching dfs
gventuri committed Sep 20, 2023
1 parent 58a18cb commit 1fd67db
Showing 6 changed files with 30 additions and 20 deletions.
6 changes: 3 additions & 3 deletions pandasai/connectors/sql.py
@@ -181,7 +181,7 @@ def _get_cache_path(self, include_additional_filters: bool = False):
 
         filename = (
             self._get_column_hash(include_additional_filters=include_additional_filters)
-            + ".csv"
+            + ".parquet"
         )
         path = os.path.join(cache_dir, filename)
 
@@ -225,7 +225,7 @@ def _save_cache(self, df):
             include_additional_filters=self._additional_filters is not None
             and len(self._additional_filters) > 0
         )
-        df.to_csv(filename, index=False)
+        df.to_parquet(filename)
 
     def execute(self):
         """
@@ -239,7 +239,7 @@ def execute(self):
         # filters as a fallback
         cached = self._cached() or self._cached(include_additional_filters=True)
         if cached:
-            return pd.read_csv(cached)
+            return pd.read_parquet(cached)
 
         if self.logger:
             self.logger.log(
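Beyond the extension swap, parquet preserves dtypes and stores the index natively, which is why the old index=False argument to to_csv disappears. A minimal sketch of the new cache round-trip, assuming pyarrow or fastparquet is installed; the hash-style filename is hypothetical:

import os

import pandas as pd

# Hypothetical cache entry; real filenames come from _get_column_hash().
os.makedirs("cache", exist_ok=True)
cache_path = os.path.join("cache", "0a1b2c3d.parquet")

df = pd.DataFrame({"country": ["US", "DE"], "gdp": [21.4, 4.2]})

df.to_parquet(cache_path)               # the index is stored in the file itself
restored = pd.read_parquet(cache_path)

# Unlike the CSV cache, dtypes and the index survive the round-trip intact.
pd.testing.assert_frame_equal(df, restored)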
8 changes: 4 additions & 4 deletions pandasai/connectors/yahoo_finance.py
@@ -73,7 +73,7 @@ def _get_cache_path(self, include_additional_filters: bool = False):
         except ValueError:
             cache_dir = os.path.join(os.getcwd(), "cache")
 
-        return os.path.join(cache_dir, f"{self._config.table}_data.csv")
+        return os.path.join(cache_dir, f"{self._config.table}_data.parquet")
 
     def _get_cache_path(self):
         """
@@ -86,7 +86,7 @@ def _get_cache_path(self):
 
         os.makedirs(cache_dir, mode=0o777, exist_ok=True)
 
-        return os.path.join(cache_dir, f"{self._config.table}_data.csv")
+        return os.path.join(cache_dir, f"{self._config.table}_data.parquet")
 
     def _cached(self):
         """
@@ -122,13 +122,13 @@ def execute(self):
         """
         cached_path = self._cached()
         if cached_path:
-            return pd.read_csv(cached_path)
+            return pd.read_parquet(cached_path)
 
         # Use yfinance to retrieve historical stock data
         stock_data = self.ticker.history(period="max")
 
         # Save the result to the cache
-        stock_data.to_csv(self._get_cache_path(), index=False)
+        stock_data.to_parquet(self._get_cache_path())
 
         return stock_data
 
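A behavioral note on this connector: ticker.history(period="max") returns a date-indexed frame, and the old to_csv(..., index=False) dropped that index when caching. With to_parquet() the DatetimeIndex now survives. A sketch with synthetic data standing in for real yfinance output:

import pandas as pd

# Synthetic stand-in for yfinance's history() output, which is indexed by date.
stock_data = pd.DataFrame(
    {"Open": [188.3, 190.1], "Close": [189.5, 191.2]},
    index=pd.DatetimeIndex(["2023-09-18", "2023-09-19"], name="Date"),
)

stock_data.to_parquet("AAPL_data.parquet")
cached = pd.read_parquet("AAPL_data.parquet")

# The date index round-trips; the CSV cache used to flatten it away.
assert cached.index.equals(stock_data.index)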
14 changes: 7 additions & 7 deletions pandasai/helpers/df_config_manager.py
@@ -29,14 +29,14 @@ def __init__(self, sdf):
         else:
             raise TypeError("Expected instance of type 'SmartDataFrame'")
 
-    def _create_csv_save_path(self):
+    def _create_save_path(self):
         """
         Creates the path for the csv file to be saved
         """
 
         directory_path = os.path.join(find_project_root(), "cache")
         create_directory(directory_path)
-        csv_file_path = os.path.join(directory_path, f"{self._sdf.table_name}.csv")
+        csv_file_path = os.path.join(directory_path, f"{self._sdf.table_name}.parquet")
         return csv_file_path
 
     def _check_for_duplicates(self, saved_dfs, name: str):
@@ -78,16 +78,16 @@ def _get_import_path(self):
         # Save df if pandas or polar
         dataframe_type = df_type(self.original_import)
         if dataframe_type == "pandas":
-            csv_file_path = self._create_csv_save_path()
-            self._sdf.dataframe.to_csv(csv_file_path)
+            file_path = self._create_save_path()
+            self._sdf.dataframe.to_parquet(file_path)
         elif dataframe_type == "polars":
-            csv_file_path = self._create_csv_save_path()
-            with open(csv_file_path, "w") as f:
+            file_path = self._create_save_path()
+            with open(file_path, "w") as f:
                 self._sdf.dataframe.write_csv(f)
         else:
             raise ValueError("Unknown dataframe type")
 
-        return csv_file_path
+        return file_path
 
     def save(self, name: str = None):
         """
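Note that the polars branch above still routes through write_csv even though the target path now ends in .parquet. A sketch of how that branch could emit real parquet instead, assuming polars' write_parquet, which takes a path directly rather than an open file handle:

import os

import polars as pl

os.makedirs("cache", exist_ok=True)
file_path = os.path.join("cache", "df_test_polars.parquet")

df = pl.DataFrame({"column1": [1, 2, 3], "column2": [4, 5, 6]})
df.write_parquet(file_path)  # no open(file_path, "w") wrapper needed

restored = pl.read_parquet(file_path)
assert df.frame_equal(restored)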
6 changes: 6 additions & 0 deletions pandasai/smart_dataframe/abstract_df.py
@@ -318,6 +318,12 @@ def to_markdown(self):
         """
         return self.dataframe.to_markdown()
 
+    def to_parquet(self):
+        """
+        A proxy-call to the dataframe's `.to_parquet()`.
+        """
+        return self.dataframe.to_parquet()
+
     # Query
     def query(self, expr):
         """
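The new proxy simply forwards to pandas; called with no path argument, pandas' to_parquet() returns the serialized bytes instead of writing a file. A trimmed-down, hypothetical stand-in for the pattern:

import pandas as pd

class DataFrameProxy:
    """Hypothetical stand-in for SmartDataframe's proxy layer."""

    def __init__(self, dataframe: pd.DataFrame):
        self.dataframe = dataframe

    def to_parquet(self):
        # Forwards to pandas; with no path, this returns parquet-encoded bytes.
        return self.dataframe.to_parquet()

blob = DataFrameProxy(pd.DataFrame({"a": [1, 2]})).to_parquet()
print(type(blob))  # <class 'bytes'>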
2 changes: 1 addition & 1 deletion tests/connectors/test_yahoo_finance.py
@@ -50,7 +50,7 @@ def test_head(yahoo_finance_connector):
 
 def test_get_cache_path(yahoo_finance_connector):
     with patch("os.path.join") as mock_join:
-        expected_result = "../AAPL_data.csv"
+        expected_result = "../AAPL_data.parquet"
         mock_join.return_value = expected_result
         assert yahoo_finance_connector._get_cache_path() == expected_result
 
14 changes: 9 additions & 5 deletions tests/test_smartdataframe.py
@@ -27,7 +27,11 @@ class TestSmartDataframe:
     """Unit tests for the SmartDatalake class"""
 
     def tearDown(self):
-        for filename in ["df_test.csv", "df_test_polars.csv", "df_duplicate.csv"]:
+        for filename in [
+            "df_test.parquet",
+            "df_test_polars.parquet",
+            "df_duplicate.parquet",
+        ]:
             if os.path.exists("cache/" + filename):
                 os.remove("cache/" + filename)
 
@@ -92,7 +96,7 @@ def sample_saved_dfs(self):
                 "name": "photo",
                 "description": "Dataframe containing photo metadata",
                 "sample": "filename,format,size\n1.jpg,JPEG,1240KB\n2.png,PNG,320KB",
-                "import_path": "path/to/photo_data.csv",
+                "import_path": "path/to/photo_data.parquet",
             }
         ]
 
@@ -564,7 +568,7 @@ def test_load_dataframe_from_saved_dfs(self, sample_saved_dfs, mocker):
                 "size": ["1240KB", "320KB"],
            }
        )
-        mocker.patch.object(pd, "read_csv", return_value=expected_df)
+        mocker.patch.object(pd, "read_parquet", return_value=expected_df)
 
         mocker.patch.object(
             json,
@@ -593,11 +597,11 @@ def test_load_dataframe_from_other_dataframe_type(self, smart_dataframe):
     def test_import_csv_file(self, smart_dataframe, mocker):
         mocker.patch.object(
             pd,
-            "read_csv",
+            "read_parquet",
             return_value=pd.DataFrame({"column1": [1, 2, 3], "column2": [4, 5, 6]}),
         )
 
-        file_path = "sample.csv"
+        file_path = "sample.parquet"
 
         df = smart_dataframe._import_from_file(file_path)
 
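The last test implies that _import_from_file dispatches on the file extension; its body isn't part of this diff, so the following is a hedged sketch of that kind of dispatch, not PandasAI's actual implementation:

import pandas as pd

def import_from_file(file_path: str) -> pd.DataFrame:
    # Hypothetical extension-based loader mirroring what the test exercises.
    if file_path.endswith(".parquet"):
        return pd.read_parquet(file_path)
    if file_path.endswith(".csv"):
        return pd.read_csv(file_path)
    raise ValueError(f"Unsupported file format: {file_path}")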
