diff --git a/nomenklatura/resolver/resolver.py b/nomenklatura/resolver/resolver.py index b3767a58..cb8e3c6c 100644 --- a/nomenklatura/resolver/resolver.py +++ b/nomenklatura/resolver/resolver.py @@ -1,6 +1,7 @@ import logging import getpass from pathlib import Path +from threading import RLock from functools import lru_cache from collections import defaultdict from typing import Dict, Generator, List, Optional, Set, Tuple @@ -22,6 +23,7 @@ class Resolver(Linker[CE]): def __init__(self, path: Optional[Path] = None) -> None: self.path = path + self.lock = RLock() self.edges: Dict[Pair, Edge] = {} self.nodes: Dict[Identifier, Set[Edge]] = defaultdict(set) @@ -165,30 +167,31 @@ def decide( user: Optional[str] = None, score: Optional[float] = None, ) -> Identifier: - edge = self.get_edge(left_id, right_id) - if edge is None: - edge = Edge(left_id, right_id, judgement=judgement) - - # Canonicalise positive matches, i.e. make both identifiers refer to a - # canonical identifier, instead of making a direct link. - if judgement == Judgement.POSITIVE: - connected = set(self.connected(edge.target)) - connected.update(self.connected(edge.source)) - target = max(connected) - if not target.canonical: - canonical = Identifier.make() - self._remove_edge(edge) - self.decide(edge.source, canonical, judgement=judgement, user=user) - self.decide(edge.target, canonical, judgement=judgement, user=user) - return canonical - - edge.judgement = judgement - edge.timestamp = utc_now().isoformat()[:16] - edge.user = user or getpass.getuser() - edge.score = score or edge.score - self._register(edge) - self.connected.cache_clear() - return edge.target + with self.lock: + edge = self.get_edge(left_id, right_id) + if edge is None: + edge = Edge(left_id, right_id, judgement=judgement) + + # Canonicalise positive matches, i.e. make both identifiers refer to a + # canonical identifier, instead of making a direct link. + if judgement == Judgement.POSITIVE: + connected = set(self.connected(edge.target)) + connected.update(self.connected(edge.source)) + target = max(connected) + if not target.canonical: + canonical = Identifier.make() + self._remove_edge(edge) + self.decide(edge.source, canonical, judgement=judgement, user=user) + self.decide(edge.target, canonical, judgement=judgement, user=user) + return canonical + + edge.judgement = judgement + edge.timestamp = utc_now().isoformat()[:16] + edge.user = user or getpass.getuser() + edge.score = score or edge.score + self._register(edge) + self.connected.cache_clear() + return edge.target def _register(self, edge: Edge) -> None: if edge.judgement != Judgement.NO_JUDGEMENT: @@ -215,39 +218,43 @@ def _remove_node(self, node: Identifier) -> None: def remove(self, node_id: StrIdent) -> None: """Remove all edges linking to the given node from the graph.""" - node = Identifier.get(node_id) - self._remove_node(node) - self.connected.cache_clear() + with self.lock: + node = Identifier.get(node_id) + self._remove_node(node) + self.connected.cache_clear() def explode(self, node_id: StrIdent) -> Set[str]: """Dissolve all edges linked to the cluster to which the node belongs. This is the hard way to make sure we re-do context once we realise there's been a mistake.""" - node = Identifier.get(node_id) - affected: Set[str] = set() - for part in self.connected(node): - affected.add(str(part)) - self._remove_node(part) - self.connected.cache_clear() - return affected + with self.lock: + node = Identifier.get(node_id) + affected: Set[str] = set() + for part in self.connected(node): + affected.add(str(part)) + self._remove_node(part) + self.connected.cache_clear() + return affected def prune(self) -> None: """Remove suggested (i.e. NO_JUDGEMENT) edges, keep only the n with the highest score. This also checks if a transitive judgement has been established in the mean time and removes those candidates.""" - for edge in list(self.edges.values()): - if edge.judgement == Judgement.NO_JUDGEMENT: - self._remove_edge(edge) - self.connected.cache_clear() + with self.lock: + for edge in list(self.edges.values()): + if edge.judgement == Judgement.NO_JUDGEMENT: + self._remove_edge(edge) + self.connected.cache_clear() def save(self) -> None: """Store the resolver adjacency list to a plain text JSON list.""" - if self.path is None: - raise RuntimeError("Resolver has no path") - edges = sorted(self.edges.values()) - with open(self.path, "w") as fh: - for edge in edges: - fh.write(edge.to_line()) + with self.lock: + if self.path is None: + raise RuntimeError("Resolver has no path") + edges = sorted(self.edges.values()) + with open(self.path, "w") as fh: + for edge in edges: + fh.write(edge.to_line()) def merge(self, path: PathLike) -> None: with open(path, "r") as fh: @@ -278,8 +285,9 @@ def _load_edges(cls, path: Path) -> Generator[Edge, None, None]: @classmethod def load(cls, path: Path) -> "Resolver[CE]": resolver = cls(path=path) - for edge in cls._load_edges(path): - resolver._register(edge) + with resolver.lock: + for edge in cls._load_edges(path): + resolver._register(edge) return resolver @classmethod