Skip to content

Commit

Permalink
make the resolver thread-safe
Browse files (browse the repository at this point in the history)
  • Loading branch information
pudo committed Aug 2, 2024
1 parent ed78bb4 commit 5fa94c2
Showing 1 changed file with 54 additions and 46 deletions.
100 changes: 54 additions & 46 deletions nomenklatura/resolver/resolver.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
import logging
import getpass
from pathlib import Path
from threading import RLock
from functools import lru_cache
from collections import defaultdict
from typing import Dict, Generator, List, Optional, Set, Tuple
Expand All @@ -22,6 +23,7 @@ class Resolver(Linker[CE]):

def __init__(self, path: Optional[Path] = None) -> None:
self.path = path
self.lock = RLock()
self.edges: Dict[Pair, Edge] = {}
self.nodes: Dict[Identifier, Set[Edge]] = defaultdict(set)

Expand Down Expand Up @@ -165,30 +167,31 @@ def decide(
user: Optional[str] = None,
score: Optional[float] = None,
) -> Identifier:
edge = self.get_edge(left_id, right_id)
if edge is None:
edge = Edge(left_id, right_id, judgement=judgement)

# Canonicalise positive matches, i.e. make both identifiers refer to a
# canonical identifier, instead of making a direct link.
if judgement == Judgement.POSITIVE:
connected = set(self.connected(edge.target))
connected.update(self.connected(edge.source))
target = max(connected)
if not target.canonical:
canonical = Identifier.make()
self._remove_edge(edge)
self.decide(edge.source, canonical, judgement=judgement, user=user)
self.decide(edge.target, canonical, judgement=judgement, user=user)
return canonical

edge.judgement = judgement
edge.timestamp = utc_now().isoformat()[:16]
edge.user = user or getpass.getuser()
edge.score = score or edge.score
self._register(edge)
self.connected.cache_clear()
return edge.target
with self.lock:
edge = self.get_edge(left_id, right_id)
if edge is None:
edge = Edge(left_id, right_id, judgement=judgement)

# Canonicalise positive matches, i.e. make both identifiers refer to a
# canonical identifier, instead of making a direct link.
if judgement == Judgement.POSITIVE:
connected = set(self.connected(edge.target))
connected.update(self.connected(edge.source))
target = max(connected)
if not target.canonical:
canonical = Identifier.make()
self._remove_edge(edge)
self.decide(edge.source, canonical, judgement=judgement, user=user)
self.decide(edge.target, canonical, judgement=judgement, user=user)
return canonical

edge.judgement = judgement
edge.timestamp = utc_now().isoformat()[:16]
edge.user = user or getpass.getuser()
edge.score = score or edge.score
self._register(edge)
self.connected.cache_clear()
return edge.target

def _register(self, edge: Edge) -> None:
if edge.judgement != Judgement.NO_JUDGEMENT:
Expand All @@ -215,39 +218,43 @@ def _remove_node(self, node: Identifier) -> None:

def remove(self, node_id: StrIdent) -> None:
"""Remove all edges linking to the given node from the graph."""
node = Identifier.get(node_id)
self._remove_node(node)
self.connected.cache_clear()
with self.lock:
node = Identifier.get(node_id)
self._remove_node(node)
self.connected.cache_clear()

def explode(self, node_id: StrIdent) -> Set[str]:
"""Dissolve all edges linked to the cluster to which the node belongs.
This is the hard way to make sure we re-do context once we realise
there's been a mistake."""
node = Identifier.get(node_id)
affected: Set[str] = set()
for part in self.connected(node):
affected.add(str(part))
self._remove_node(part)
self.connected.cache_clear()
return affected
with self.lock:
node = Identifier.get(node_id)
affected: Set[str] = set()
for part in self.connected(node):
affected.add(str(part))
self._remove_node(part)
self.connected.cache_clear()
return affected

def prune(self) -> None:
"""Remove suggested (i.e. NO_JUDGEMENT) edges, keep only the n with the
highest score. This also checks if a transitive judgement has been
established in the mean time and removes those candidates."""
for edge in list(self.edges.values()):
if edge.judgement == Judgement.NO_JUDGEMENT:
self._remove_edge(edge)
self.connected.cache_clear()
with self.lock:
for edge in list(self.edges.values()):
if edge.judgement == Judgement.NO_JUDGEMENT:
self._remove_edge(edge)
self.connected.cache_clear()

def save(self) -> None:
"""Store the resolver adjacency list to a plain text JSON list."""
if self.path is None:
raise RuntimeError("Resolver has no path")
edges = sorted(self.edges.values())
with open(self.path, "w") as fh:
for edge in edges:
fh.write(edge.to_line())
with self.lock:
if self.path is None:
raise RuntimeError("Resolver has no path")
edges = sorted(self.edges.values())
with open(self.path, "w") as fh:
for edge in edges:
fh.write(edge.to_line())

def merge(self, path: PathLike) -> None:
with open(path, "r") as fh:
Expand Down Expand Up @@ -278,8 +285,9 @@ def _load_edges(cls, path: Path) -> Generator[Edge, None, None]:
@classmethod
def load(cls, path: Path) -> "Resolver[CE]":
resolver = cls(path=path)
for edge in cls._load_edges(path):
resolver._register(edge)
with resolver.lock:
for edge in cls._load_edges(path):
resolver._register(edge)
return resolver

@classmethod
Expand Down

0 comments on commit 5fa94c2

Please sign in to comment.