diff --git a/src/numbat/api.py b/src/numbat/api.py index e947ed4..cc2f636 100644 --- a/src/numbat/api.py +++ b/src/numbat/api.py @@ -59,6 +59,7 @@ def __init__(self, database: sqlite3.Connection, path: Path, logger: logging.Log self.logger = logging.getLogger() else: self.logger = logger + self.name_cache = dict() # ------------------------------------------------------------------------ # # Database file management functions # @@ -80,6 +81,19 @@ def __uniformize_path(cls, path: Path | str) -> Path: path = path.with_suffix(cls.SOURCETRAIL_DB_EXT) return path.absolute() + @classmethod + def exists(cls, path: Path | str) -> bool: + """ + This method check if there is a Sourcetrail db with the given path. + If the provided path does not end with the sourcetrail db correct + suffix. It will be added. + + :param path: The path to test + :return: a bool + """ + path = cls.__uniformize_path(path) + return path.exists() + @classmethod def open(cls, path: Path | str, clear: bool = False) -> 'SourcetrailDB': """ @@ -229,8 +243,7 @@ def __add_if_not_existing(self, name: str, type_: NodeType) -> int: the existing one """ - node = NodeDAO.get_by_name(self.database, name) - if not node: + if name not in self.name_cache: elem = Element() elem.id = ElementDAO.new(self.database, elem) @@ -240,9 +253,10 @@ def __add_if_not_existing(self, name: str, type_: NodeType) -> int: name )) + self.name_cache[name] = elem.id return elem.id else: - return node.id + return self.name_cache[name] def _record_symbol(self, hierarchy: NameHierarchy) -> int: """