diff --git a/mandala/storage.py b/mandala/storage.py
index 228c3ed..5c457db 100644
--- a/mandala/storage.py
+++ b/mandala/storage.py
@@ -18,12 +18,16 @@
     CachedDictStorage,
     SQLiteDictStorage,
     CachedCallStorage,
+    JoblibDictStorage,
     transaction
 )
 
 
 class Storage:
-    def __init__(self, db_path: str = ":memory:",
+    def __init__(self,
+                 db_path: str = ":memory:",
+                 overflow_dir: Optional[str] = None,
+                 overflow_threshold_MB: Optional[Union[int, float]] = 50.0,
                  deps_path: Optional[Union[str, Path]] = None,
                  tracer_impl: Optional[type] = None,
                  strict_tracing: bool = False,
@@ -40,8 +44,18 @@ def __init__(self, db_path: str = ":memory:",
         self.calls = CachedCallStorage(persistent=self.call_storage)
         self.call_cache = self.calls.cache
 
+        self.overflow_dir = overflow_dir
+        self.overflow_threshold_MB = overflow_threshold_MB
+        if self.overflow_dir is not None:
+            self.overflow_storage = JoblibDictStorage(root=self.overflow_dir)
+        else:
+            self.overflow_storage = None
+
         self.atoms = CachedDictStorage(
-            persistent=SQLiteDictStorage(self.db, table="atoms")
+            persistent=SQLiteDictStorage(self.db, table="atoms",
+                                         overflow_storage=self.overflow_storage,
+                                         overflow_threshold_MB=self.overflow_threshold_MB
+                                         )
         )
         self.shapes = CachedDictStorage(
             persistent=SQLiteDictStorage(self.db, table="shapes")
@@ -743,6 +757,8 @@ def call_internal(
         """
         Main function to call an op, operating on the representations used
         internally by the storage.
+
+        NOTE: this function does NOT save the call to the storage.
         """
         ### wrap the inputs
         if not op.__structural__: logger.debug(f"Calling {op.name} with args {bound_arguments}.")
diff --git a/mandala/storage_utils.py b/mandala/storage_utils.py
index f2a4355..2d8acfa 100644
--- a/mandala/storage_utils.py
+++ b/mandala/storage_utils.py
@@ -108,14 +108,65 @@ def __len__(self) -> int:
         return len(self.keys())
 
 
+class JoblibDictStorage(DictStorage):
+    """
+    A write-once dictionary storage that uses joblib to store the data on disk.
+    """
+    def __init__(self, root: str):
+        self.root = root
+        os.makedirs(root, exist_ok=True)
+
+    def get_path_for_key(self, key: str) -> str:
+        return os.path.join(self.root, f'{key}.joblib')
+
+    def get(self, key: str) -> Any:
+        try:
+            return joblib.load(self.get_path_for_key(key))
+        except FileNotFoundError:
+            raise KeyError(f"Key {key} not found") from None
+
+    def exists(self, key: str) -> bool:
+        return os.path.exists(self.get_path_for_key(key))
+
+    def set(self, key: str, value: Any) -> None:
+        if self.exists(key):
+            return  # this is a write-once storage
+        joblib.dump(value, self.get_path_for_key(key))
+
+    def drop(self, key: str) -> None:
+        os.remove(self.get_path_for_key(key))
+
+    def load_all(self) -> Dict[str, Any]:
+        return {key: self.get(key) for key in self.keys()}
+
+    def keys(self) -> List[str]:
+        return [os.path.splitext(f)[0] for f in os.listdir(self.root) if f.endswith('.joblib')]
+
+    def values(self) -> List[Any]:
+        return [self.get(key) for key in self.keys()]
+
+
 class SQLiteDictStorage(DictStorage):
-    def __init__(self, db: DBAdapter, table: str):
+    def __init__(self, db: DBAdapter,
+                 table: str,
+                 overflow_storage: Optional[JoblibDictStorage] = None,
+                 overflow_threshold_MB: Optional[Union[int, float]] = 50,
+                 ):
+        """
+        A key-value store backed by a table in a SQLite database.
+
+        If `overflow_storage` is given, values whose serialized size exceeds
+        `overflow_threshold_MB` are written there instead of to the table.
+        A threshold of `None` disables the size check.
+        """
         self.db = db
         self.table = table
         with self.conn() as conn:
             conn.execute(
                 f"CREATE TABLE IF NOT EXISTS {table} (key TEXT PRIMARY KEY, value BLOB)"
             )
+        self.overflow_storage = overflow_storage
+        self.overflow_threshold_MB = overflow_threshold_MB
 
     def conn(self) -> sqlite3.Connection:
         return self.db.conn()
@@ -123,28 +174,49 @@ def conn(self) -> sqlite3.Connection:
     def load_all(self) -> Dict[str, Any]:
         with self.conn() as conn:
             cursor = conn.execute(f"SELECT key, value FROM {self.table}")
-            return {row[0]: deserialize(row[1]) for row in cursor.fetchall()}
+            res = {row[0]: deserialize(row[1]) for row in cursor.fetchall()}
+        if self.overflow_storage is not None:
+            res.update(self.overflow_storage.load_all())
+        return res
 
     @transaction
     def get(self, key: str, conn: Optional[sqlite3.Connection] = None) -> Any:
         cursor = conn.execute(f"SELECT value FROM {self.table} WHERE key = ?", (key,))
         result = cursor.fetchone()
         if result is None:
-            raise KeyError(f"Key {key} not found")
+            if self.overflow_storage is not None:
+                return self.overflow_storage.get(key)
+            else:
+                raise KeyError(f"Key {key} not found")
         return deserialize(result[0])
 
     @transaction
     def set(
         self, key: str, value: Any, conn: Optional[sqlite3.Connection] = None
     ) -> None:
-        conn.execute(
-            f"INSERT OR REPLACE INTO {self.table} (key, value) VALUES (?, ?)",
-            (key, serialize(value)),
-        )
+        serialized_value = serialize(value)
+        # size of the serialized payload in MB; a threshold of None disables the check
+        size_MB = len(serialized_value) / 1024 / 1024
+        threshold = self.overflow_threshold_MB
+        if threshold is not None and size_MB > threshold:
+            if self.overflow_storage is None:
+                raise ValueError(
+                    f"Value for key {key} is too large ({size_MB:.2f} MB) and no overflow storage is provided"
+                )
+            self.overflow_storage.set(key, value)
+            # remove any inline copy so that `get` cannot return a stale value
+            conn.execute(f"DELETE FROM {self.table} WHERE key = ?", (key,))
+        else:
+            conn.execute(
+                f"INSERT OR REPLACE INTO {self.table} (key, value) VALUES (?, ?)",
+                (key, serialized_value),
+            )
 
     @transaction
     def drop(self, key: str, conn: Optional[sqlite3.Connection] = None) -> None:
         conn.execute(f"DELETE FROM {self.table} WHERE key = ?", (key,))
+        if self.overflow_storage is not None and self.overflow_storage.exists(key):
+            self.overflow_storage.drop(key)
 
     @transaction
     def exists(self, key: str, conn: Optional[sqlite3.Connection] = None) -> bool:
@@ -152,17 +224,26 @@ def exists(self, key: str, conn: Optional[sqlite3.Connection] = None) -> bool:
             f"SELECT COUNT(*) FROM {self.table} WHERE key = ?", (key,)
         )
         count = cursor.fetchone()[0]
-        return count > 0
+        if self.overflow_storage is None:
+            return count > 0
+        else:
+            return count > 0 or self.overflow_storage.exists(key)
 
     @transaction
     def keys(self, conn: Optional[sqlite3.Connection] = None) -> List[str]:
         cursor = conn.execute(f"SELECT key FROM {self.table}")
-        return [row[0] for row in cursor.fetchall()]
+        res = [row[0] for row in cursor.fetchall()]
+        if self.overflow_storage is not None:
+            res.extend(self.overflow_storage.keys())
+        return res
 
     @transaction
     def values(self, conn: Optional[sqlite3.Connection] = None) -> List[Any]:
         cursor = conn.execute(f"SELECT value FROM {self.table}")
-        return [deserialize(row[0]) for row in cursor.fetchall()]
+        res = [deserialize(row[0]) for row in cursor.fetchall()]
+        if self.overflow_storage is not None:
+            res.extend(self.overflow_storage.values())
+        return res
 
 
 class CachedDictStorage(DictStorage):
@@ -247,6 +328,9 @@ def save(self, call: Call):
         # if call.hid in self.df.index.levels[0]:
         if call.hid in self.call_hids:
             return
+        if not hasattr(sess, '_times'):
+            sess._times = []
+        start = time.time()
         for k, v in call.inputs.items():
             self.df.loc[(call.hid, k), :] = ("in", call.cid, v.cid, v.hid, call.op.name, call.semantic_version, call.content_version)
         for k, v in call.outputs.items():
@@ -259,6 +343,8 @@ def save(self, call: Call):
                 call.semantic_version,
                 call.content_version,
             )
+        end = time.time()
+        sess._times.append(end - start)
         self.call_hids.add(call.hid)
 
     def drop(self, hid: str):