Skip to content

Commit

Permalink
feat: use duckdb for the cache
Browse files (browse the repository at this point in the history)
  • Loading branch information
gventuri committed Sep 14, 2023
1 parent aad466d commit 0191615
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 406 deletions.
2 changes: 1 addition & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def last_prompt(self) -> str:

def clear_cache(filename: str = None):
"""Clear the cache"""
cache = Cache(filename or "cache")
cache = Cache(filename or "cache_db")
cache.clear()


Expand Down
41 changes: 21 additions & 20 deletions pandasai/helpers/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Cache module for caching queries."""
import glob
import os
import shelve
import glob
import duckdb
from .path import find_project_root


Expand All @@ -13,17 +12,20 @@ class Cache:
filename (str): filename to store the cache.
"""

def __init__(self, filename="cache"):
# define cache directory and create directory if it does not exist
def __init__(self, filename="cache_db"):
# Define cache directory and create directory if it does not exist
try:
cache_dir = os.path.join((find_project_root()), "cache")
cache_dir = os.path.join(find_project_root(), "cache")
except ValueError:
cache_dir = os.path.join(os.getcwd(), "cache")

os.makedirs(cache_dir, mode=0o777, exist_ok=True)

self.filepath = os.path.join(cache_dir, filename)
self.cache = shelve.open(self.filepath)
self.filepath = os.path.join(cache_dir, filename + ".db")
self.connection = duckdb.connect(self.filepath)
self.connection.execute(
"CREATE TABLE IF NOT EXISTS cache (key STRING, value STRING)"
)

def set(self, key: str, value: str) -> None:
"""Set a key value pair in the cache.
Expand All @@ -32,8 +34,7 @@ def set(self, key: str, value: str) -> None:
key (str): key to store the value.
value (str): value to store in the cache.
"""

self.cache[key] = value
self.connection.execute("INSERT INTO cache VALUES (?, ?)", [key, value])

def get(self, key: str) -> str:
"""Get a value from the cache.
Expand All @@ -44,31 +45,31 @@ def get(self, key: str) -> str:
Returns:
str: value from the cache.
"""

return self.cache.get(key)
result = self.connection.execute("SELECT value FROM cache WHERE key=?", [key])
row = result.fetchone()
if row:
return row[0]
else:
return None

def delete(self, key: str) -> None:
"""Delete a key value pair from the cache.
Args:
key (str): key to delete the value from the cache.
"""

if key in self.cache:
del self.cache[key]
self.connection.execute("DELETE FROM cache WHERE key=?", [key])

def close(self) -> None:
"""Close the cache."""

self.cache.close()
self.connection.close()

def clear(self) -> None:
"""Clean the cache."""

self.cache.clear()
self.connection.execute("DELETE FROM cache")

def destroy(self) -> None:
"""Destroy the cache."""
self.cache.close()
self.connection.close()
for cache_file in glob.glob(self.filepath + ".*"):
os.remove(cache_file)
Loading

0 comments on commit 0191615

Please sign in to comment.