Skip to content

Commit

Permalink
WIP independent pair sets
Browse files Browse the repository at this point in the history
  • Loading branch information
jbothma committed Aug 19, 2024
1 parent 9d7cbf6 commit 74719f4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
19 changes: 19 additions & 0 deletions nomenklatura/matching/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def to_dict(self) -> Dict[str, Any]:
"judgement": self.judgement.value,
}

def __hash__(self):
return hash((self.left.id, self.right.id, self.judgement.value))


def read_pairs(pairs_file: PathLike) -> Generator[JudgedPair, None, None]:
"""Read judgement pairs (training data) from a JSON file."""
Expand All @@ -39,3 +42,19 @@ def read_pairs(pairs_file: PathLike) -> Generator[JudgedPair, None, None]:
if judgement not in (Judgement.POSITIVE, Judgement.NEGATIVE):
continue
yield JudgedPair(left_entity, right_entity, judgement)


def read_pair_sets(pairs_file: PathLike) -> Generator[Set[JudgedPair], None, None]:
with open(pairs_file, "r") as fh:
while line := fh.readline():
pair_array = json.loads(line)
pair_set: Set[JudgedPair] = set()
for pair_dict in data:
left_entity = EntityProxy.from_dict(model, pair_dict["left"])
right_entity = EntityProxy.from_dict(model, pair_dict["right"])
judgement = Judgement(pair_dict["judgement"])
if judgement not in (Judgement.POSITIVE, Judgement.NEGATIVE):
continue
pair_set.add(JudgedPair(left_entity, right_entity, judgement))
yield pair_set

11 changes: 4 additions & 7 deletions nomenklatura/matching/randomforest_v1/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from sklearn.ensemble import RandomForestClassifier # type: ignore
from sklearn import metrics # type: ignore
from concurrent.futures import ThreadPoolExecutor
from random import shuffle

from nomenklatura.judgement import Judgement
from nomenklatura.resolver import Resolver
from nomenklatura.matching.pairs import read_pairs, JudgedPair
from nomenklatura.matching.randomforest_v1.model import RandomForestV1
from nomenklatura.util import PathLike
Expand Down Expand Up @@ -48,16 +50,11 @@ def pairs_to_arrays(
def train_matcher(pairs_file: PathLike) -> None:
pairs = []
for pair in read_pairs(pairs_file):
# HACK: support more eventually:
# if not pair.left.schema.is_a("LegalEntity"):
# continue
if pair.judgement == Judgement.UNSURE:
pair.judgement = Judgement.NEGATIVE
# randomize_entity(pair.left)
# randomize_entity(pair.right)
pairs.append(pair)
# random.shuffle(pairs)
# pairs = pairs[:30000]
resolver = Resolver.load("../operations/etl/data/resolve.ijson")
pairs = shuffle(pairs)
positive = len([p for p in pairs if p.judgement == Judgement.POSITIVE])
negative = len([p for p in pairs if p.judgement == Judgement.NEGATIVE])
log.info("Total pairs loaded: %d (%d pos/%d neg)", len(pairs), positive, negative)
Expand Down

0 comments on commit 74719f4

Please sign in to comment.