Skip to content

Commit

Permalink
add support for overflow storage
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Aug 29, 2024
1 parent a09b31a commit e6422b1
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 12 deletions.
20 changes: 18 additions & 2 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
CachedDictStorage,
SQLiteDictStorage,
CachedCallStorage,
JoblibDictStorage,
transaction
)


class Storage:
def __init__(self, db_path: str = ":memory:",
def __init__(self,
db_path: str = ":memory:",
overflow_dir: Optional[str] = None,
overflow_threshold_MB: Optional[Union[int, float]] = 50.0,
deps_path: Optional[Union[str, Path]] = None,
tracer_impl: Optional[type] = None,
strict_tracing: bool = False,
Expand All @@ -40,8 +44,18 @@ def __init__(self, db_path: str = ":memory:",
self.calls = CachedCallStorage(persistent=self.call_storage)
self.call_cache = self.calls.cache

self.overflow_dir = overflow_dir
self.overflow_threshold_MB = overflow_threshold_MB
if self.overflow_dir is not None:
self.overflow_storage = JoblibDictStorage(root=self.overflow_dir)
else:
self.overflow_storage = None

self.atoms = CachedDictStorage(
persistent=SQLiteDictStorage(self.db, table="atoms")
persistent=SQLiteDictStorage(self.db, table="atoms",
overflow_storage=self.overflow_storage,
overflow_threshold_MB=self.overflow_threshold_MB
)
)
self.shapes = CachedDictStorage(
persistent=SQLiteDictStorage(self.db, table="shapes")
Expand Down Expand Up @@ -743,6 +757,8 @@ def call_internal(
"""
Main function to call an op, operating on the representations used
internally by the storage.
NOTE: this function does NOT save the call to the storage.
"""
### wrap the inputs
if not op.__structural__: logger.debug(f"Calling {op.name} with args {bound_arguments}.")
Expand Down
99 changes: 89 additions & 10 deletions mandala/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,61 +108,135 @@ def __len__(self) -> int:
return len(self.keys())


class JoblibDictStorage(DictStorage):
    """
    A write-once dictionary storage that uses joblib to store the data on disk.

    The value for `key` lives in the file `<root>/<key>.joblib`. Keys are used
    verbatim as file names, so they must be valid file names on the host file
    system.
    """
    def __init__(self, root: str):
        """
        Args:
            root: directory that holds the `.joblib` files; it is created if
                it does not exist yet.
        """
        self.root = root
        os.makedirs(root, exist_ok=True)

    def get_path_for_key(self, key: str) -> str:
        """Return the path of the file that stores the value for `key`."""
        return os.path.join(self.root, f'{key}.joblib')

    def get(self, key: str) -> Any:
        """
        Load and return the value stored under `key`.

        Raises:
            KeyError: if no file exists for `key`. A missing key is reported
                the same way `SQLiteDictStorage.get` reports it, so callers
                that fall back to this storage see a single exception type.
        """
        try:
            return joblib.load(self.get_path_for_key(key))
        except FileNotFoundError as e:
            raise KeyError(f"Key {key} not found") from e

    def exists(self, key: str) -> bool:
        """Return whether a file for `key` is present under `root`."""
        return os.path.exists(self.get_path_for_key(key))

    def set(self, key: str, value: Any) -> None:
        """
        Store `value` under `key`; a no-op if `key` is already stored.

        The value is dumped to a temporary file next to the destination and
        then renamed into place. Without this, an interrupted dump would leave
        a truncated `<key>.joblib` that `exists` reports as present and that
        the write-once guard below would never overwrite.
        """
        if self.exists(key):
            return # this is a write-once storage
        path = self.get_path_for_key(key)
        # the temp name does not end in '.joblib', so `keys` never lists it
        tmp_path = f'{path}.{os.getpid()}.tmp'
        try:
            joblib.dump(value, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # after a successful rename the temp file no longer exists; it is
            # only left behind (and removed here) when the dump failed midway
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def drop(self, key: str) -> None:
        """
        Delete the file that stores `key`.

        Raises:
            FileNotFoundError: if no file exists for `key` (from `os.remove`).
        """
        os.remove(self.get_path_for_key(key))

    def load_all(self) -> Dict[str, Any]:
        """Load every stored value and return a `{key: value}` dictionary."""
        return {key: self.get(key) for key in self.keys()}

    def keys(self) -> List[str]:
        """Return the keys of all `.joblib` files found directly in `root`."""
        return [os.path.splitext(f)[0] for f in os.listdir(self.root) if f.endswith('.joblib')]

    def values(self) -> List[Any]:
        """Load and return every stored value, in the order given by `keys`."""
        return [self.get(key) for key in self.keys()]




class SQLiteDictStorage(DictStorage):
def __init__(self, db: DBAdapter, table: str):
def __init__(self, db: DBAdapter,
table: str,
overflow_storage: Optional[JoblibDictStorage] = None,
overflow_threshold_MB: Optional[Union[int, float]] = 50,
):
"""
"""
self.db = db
self.table = table
with self.conn() as conn:
conn.execute(
f"CREATE TABLE IF NOT EXISTS {table} (key TEXT PRIMARY KEY, value BLOB)"
)
self.overflow_storage = overflow_storage
self.overflow_threshold_MB = overflow_threshold_MB

def conn(self) -> sqlite3.Connection:
return self.db.conn()

def load_all(self) -> Dict[str, Any]:
with self.conn() as conn:
cursor = conn.execute(f"SELECT key, value FROM {self.table}")
return {row[0]: deserialize(row[1]) for row in cursor.fetchall()}
res = {row[0]: deserialize(row[1]) for row in cursor.fetchall()}
if self.overflow_storage is not None:
res.update(self.overflow_storage.load_all())
return res

@transaction
def get(self, key: str, conn: Optional[sqlite3.Connection] = None) -> Any:
cursor = conn.execute(f"SELECT value FROM {self.table} WHERE key = ?", (key,))
result = cursor.fetchone()
if result is None:
raise KeyError(f"Key {key} not found")
if self.overflow_storage is not None:
return self.overflow_storage.get(key)
else:
raise KeyError(f"Key {key} not found")
return deserialize(result[0])

@transaction
def set(
self, key: str, value: Any, conn: Optional[sqlite3.Connection] = None
) -> None:
conn.execute(
f"INSERT OR REPLACE INTO {self.table} (key, value) VALUES (?, ?)",
(key, serialize(value)),
)
serialized_value = serialize(value)
# compute the space this string would take up in bytes
size_MB = len(serialized_value) / 1024 / 1024
if size_MB > self.overflow_threshold_MB:
if self.overflow_storage is not None:
self.overflow_storage.set(key, value)
else:
raise ValueError(
f"Value for key {key} is too large ({size_MB:.2f} MB) and no overflow storage is provided"
)
else:
conn.execute(
f"INSERT OR REPLACE INTO {self.table} (key, value) VALUES (?, ?)",
(key, serialized_value),
)

@transaction
def drop(self, key: str, conn: Optional[sqlite3.Connection] = None) -> None:
conn.execute(f"DELETE FROM {self.table} WHERE key = ?", (key,))
if self.overflow_storage is not None and self.overflow_storage.exists(key):
self.overflow_storage.drop(key)

@transaction
def exists(self, key: str, conn: Optional[sqlite3.Connection] = None) -> bool:
cursor = conn.execute(
f"SELECT COUNT(*) FROM {self.table} WHERE key = ?", (key,)
)
count = cursor.fetchone()[0]
return count > 0
if self.overflow_storage is None:
return count > 0
else:
return count > 0 or self.overflow_storage.exists(key)

@transaction
def keys(self, conn: Optional[sqlite3.Connection] = None) -> List[str]:
cursor = conn.execute(f"SELECT key FROM {self.table}")
return [row[0] for row in cursor.fetchall()]
res = [row[0] for row in cursor.fetchall()]
if self.overflow_storage is not None:
res.extend(self.overflow_storage.keys())
return res

@transaction
def values(self, conn: Optional[sqlite3.Connection] = None) -> List[Any]:
cursor = conn.execute(f"SELECT value FROM {self.table}")
return [deserialize(row[0]) for row in cursor.fetchall()]
res = [deserialize(row[0]) for row in cursor.fetchall()]
if self.overflow_storage is not None:
res.extend(self.overflow_storage.values())
return res


class CachedDictStorage(DictStorage):
Expand Down Expand Up @@ -247,6 +321,9 @@ def save(self, call: Call):
# if call.hid in self.df.index.levels[0]:
if call.hid in self.call_hids:
return
if not hasattr(sess, '_times'):
sess._times = []
start = time.time()
for k, v in call.inputs.items():
self.df.loc[(call.hid, k), :] = ("in", call.cid, v.cid, v.hid, call.op.name, call.semantic_version, call.content_version)
for k, v in call.outputs.items():
Expand All @@ -259,6 +336,8 @@ def save(self, call: Call):
call.semantic_version,
call.content_version,
)
end = time.time()
sess._times.append(end - start)
self.call_hids.add(call.hid)

def drop(self, hid: str):
Expand Down

0 comments on commit e6422b1

Please sign in to comment.