From 79816318ee46c6cf1522a905f66fa724eaf196b7 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 6 Sep 2024 10:51:43 +0100 Subject: [PATCH 1/7] Add regression v3 crawler based on v1 with disjoint cluster training data --- nomenklatura/cli.py | 13 +- nomenklatura/matching/__init__.py | 5 + nomenklatura/matching/pairs.py | 15 +- .../matching/regression_v3/__init__.py | 0 nomenklatura/matching/regression_v3/misc.py | 63 +++++++++ nomenklatura/matching/regression_v3/model.py | 105 ++++++++++++++ nomenklatura/matching/regression_v3/names.py | 63 +++++++++ nomenklatura/matching/regression_v3/train.py | 89 ++++++++++++ nomenklatura/matching/regression_v3/util.py | 31 +++++ tests/matching/test_regression_v3.py | 131 ++++++++++++++++++ 10 files changed, 509 insertions(+), 6 deletions(-) create mode 100644 nomenklatura/matching/regression_v3/__init__.py create mode 100644 nomenklatura/matching/regression_v3/misc.py create mode 100644 nomenklatura/matching/regression_v3/model.py create mode 100644 nomenklatura/matching/regression_v3/names.py create mode 100644 nomenklatura/matching/regression_v3/train.py create mode 100644 nomenklatura/matching/regression_v3/util.py create mode 100644 tests/matching/test_regression_v3.py diff --git a/nomenklatura/cli.py b/nomenklatura/cli.py index 51b2b3e8..03f1303e 100644 --- a/nomenklatura/cli.py +++ b/nomenklatura/cli.py @@ -11,7 +11,11 @@ from nomenklatura.cache import Cache from nomenklatura.index import Index, INDEX_TYPES -from nomenklatura.matching import train_v2_matcher, train_v1_matcher +from nomenklatura.matching import ( + train_v3_matcher, + train_v2_matcher, + train_v1_matcher, +) from nomenklatura.store import load_entity_file_store from nomenklatura.resolver import Resolver from nomenklatura.dataset import Dataset, DefaultDataset @@ -193,6 +197,13 @@ def train_v2_matcher_(pairs_file: Path) -> None: train_v2_matcher(pairs_file) +@cli.command("train-v3-matcher", help="Train a matching model from judgement pairs") 
+@click.argument("pairs_file", type=InPath) +@click.option("-s", "--splits", type=int, default=1) +def train_v3_matcher_(pairs_file: Path, splits: int = 1) -> None: + train_v3_matcher(pairs_file, splits) + + @cli.command("match", help="Generate matches from an enrichment source") @click.argument("config", type=InPath) @click.argument("entities", type=InPath) diff --git a/nomenklatura/matching/__init__.py b/nomenklatura/matching/__init__.py index c84fa59d..0ee5763c 100644 --- a/nomenklatura/matching/__init__.py +++ b/nomenklatura/matching/__init__.py @@ -3,6 +3,8 @@ from nomenklatura.matching.regression_v1.train import train_matcher as train_v1_matcher from nomenklatura.matching.regression_v2.model import RegressionV2 from nomenklatura.matching.regression_v2.train import train_matcher as train_v2_matcher +from nomenklatura.matching.regression_v3.model import RegressionV3 +from nomenklatura.matching.regression_v3.train import train_matcher as train_v3_matcher from nomenklatura.matching.name_based import NameMatcher, NameQualifiedMatcher from nomenklatura.matching.logic import LogicV1 from nomenklatura.matching.types import ScoringAlgorithm @@ -13,6 +15,7 @@ NameQualifiedMatcher, RegressionV1, RegressionV2, + RegressionV3, ] DefaultAlgorithm = RegressionV2 @@ -31,6 +34,8 @@ def get_algorithm(name: str) -> Optional[Type[ScoringAlgorithm]]: "train_v1_matcher", "RegressionV2", "train_v2_matcher", + "RegressionV3", + "train_v3_matcher", "DefaultAlgorithm", "ScoringAlgorithm", "NameMatcher", diff --git a/nomenklatura/matching/pairs.py b/nomenklatura/matching/pairs.py index c2c00b0c..59440cd1 100644 --- a/nomenklatura/matching/pairs.py +++ b/nomenklatura/matching/pairs.py @@ -8,23 +8,27 @@ class JudgedPair(object): - """A pair of two entities which have been judged to be the same - (or not) by a user.""" + """ + A pair of two entities which have been judged to be the same + (or not) by a user. 
+ """ - __slots__ = ("left", "right", "judgement") + __slots__ = ("left", "right", "judgement", "group") def __init__( - self, left: EntityProxy, right: EntityProxy, judgement: Judgement + self, left: EntityProxy, right: EntityProxy, judgement: Judgement, group: int ) -> None: self.left = left self.right = right self.judgement = judgement + self.group = group def to_dict(self) -> Dict[str, Any]: return { "left": self.left.to_dict(), "right": self.right.to_dict(), "judgement": self.judgement.value, + "group": self.group, } @@ -38,4 +42,5 @@ def read_pairs(pairs_file: PathLike) -> Generator[JudgedPair, None, None]: judgement = Judgement(data["judgement"]) if judgement not in (Judgement.POSITIVE, Judgement.NEGATIVE): continue - yield JudgedPair(left_entity, right_entity, judgement) + group = data.get("group", None) + yield JudgedPair(left_entity, right_entity, judgement, group) diff --git a/nomenklatura/matching/regression_v3/__init__.py b/nomenklatura/matching/regression_v3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nomenklatura/matching/regression_v3/misc.py b/nomenklatura/matching/regression_v3/misc.py new file mode 100644 index 00000000..5908b50a --- /dev/null +++ b/nomenklatura/matching/regression_v3/misc.py @@ -0,0 +1,63 @@ +from followthemoney.proxy import E +from followthemoney.types import registry + +from nomenklatura.matching.regression_v1.util import tokenize_pair, compare_levenshtein +from nomenklatura.matching.compare.util import has_overlap, extract_numbers +from nomenklatura.matching.util import props_pair, type_pair +from nomenklatura.matching.util import max_in_sets, has_schema +from nomenklatura.util import normalize_name + + +def birth_place(query: E, result: E) -> float: + """Same place of birth.""" + lv, rv = tokenize_pair(props_pair(query, result, ["birthPlace"])) + tokens = min(len(lv), len(rv)) + return float(len(lv.intersection(rv))) / float(max(2.0, tokens)) + + +def address_match(query: E, result: E) -> float: + 
"""Text similarity between addresses.""" + lv, rv = type_pair(query, result, registry.address) + lvn = [normalize_name(v) for v in lv] + rvn = [normalize_name(v) for v in rv] + return max_in_sets(lvn, rvn, compare_levenshtein) + + +def address_numbers(query: E, result: E) -> float: + """Numbers shared by the addresses, minus those found only on the query.""" + lv, rv = type_pair(query, result, registry.address) + lvn = extract_numbers(lv) + rvn = extract_numbers(rv) + common = len(lvn.intersection(rvn)) + disjoint = len(lvn.difference(rvn)) + return common - disjoint + + +def phone_match(query: E, result: E) -> float: + """Matching phone numbers between the two entities.""" + lv, rv = type_pair(query, result, registry.phone) + return 1.0 if has_overlap(lv, rv) else 0.0 + + +def email_match(query: E, result: E) -> float: + """Matching email addresses between the two entities.""" + lv, rv = type_pair(query, result, registry.email) + return 1.0 if has_overlap(lv, rv) else 0.0 + + +def identifier_match(query: E, result: E) -> float: + """Matching identifiers (e.g. passports, national ID cards, registration or + tax numbers) between the two entities.""" + if has_schema(query, result, "Organization"): + return 0.0 + lv, rv = type_pair(query, result, registry.identifier) + return 1.0 if has_overlap(lv, rv) else 0.0 + + +def org_identifier_match(query: E, result: E) -> float: + """Matching identifiers (e.g. 
registration or tax numbers) between two + organizations or companies.""" + if not has_schema(query, result, "Organization"): + return 0.0 + lv, rv = type_pair(query, result, registry.identifier) + return 1.0 if has_overlap(lv, rv) else 0.0 diff --git a/nomenklatura/matching/regression_v3/model.py b/nomenklatura/matching/regression_v3/model.py new file mode 100644 index 00000000..e08a220d --- /dev/null +++ b/nomenklatura/matching/regression_v3/model.py @@ -0,0 +1,105 @@ +import pickle +import numpy as np +from typing import List, Dict, Tuple, cast +from functools import cache +from sklearn.pipeline import Pipeline # type: ignore +from followthemoney.proxy import E + +from nomenklatura.matching.regression_v3.names import first_name_match +from nomenklatura.matching.regression_v3.names import family_name_match +from nomenklatura.matching.regression_v3.names import name_levenshtein, name_match +from nomenklatura.matching.regression_v3.names import name_token_overlap, name_numbers +from nomenklatura.matching.regression_v3.misc import phone_match, email_match +from nomenklatura.matching.regression_v3.misc import address_match, address_numbers +from nomenklatura.matching.regression_v3.misc import identifier_match, birth_place +from nomenklatura.matching.regression_v3.misc import org_identifier_match +from nomenklatura.matching.compare.countries import country_mismatch +from nomenklatura.matching.compare.gender import gender_mismatch +from nomenklatura.matching.compare.dates import dob_matches, dob_year_matches +from nomenklatura.matching.compare.dates import dob_year_disjoint +from nomenklatura.matching.types import FeatureDocs, FeatureDoc, MatchingResult +from nomenklatura.matching.types import CompareFunction, Encoded, ScoringAlgorithm +from nomenklatura.matching.util import make_github_url +from nomenklatura.util import DATA_PATH + + +class RegressionV3(ScoringAlgorithm): + """A simple matching algorithm based on a regression model.""" + + NAME = "regression-v3" + 
MODEL_PATH = DATA_PATH.joinpath(f"{NAME}.pkl") + FEATURES: List[CompareFunction] = [ + name_match, + name_token_overlap, + name_numbers, + name_levenshtein, + phone_match, + email_match, + identifier_match, + dob_matches, + dob_year_matches, + dob_year_disjoint, + first_name_match, + family_name_match, + birth_place, + gender_mismatch, + country_mismatch, + org_identifier_match, + address_match, + address_numbers, + ] + + @classmethod + def save(cls, pipe: Pipeline, coefficients: Dict[str, float]) -> None: + """Store a classification pipeline after training.""" + mdl = pickle.dumps({"pipe": pipe, "coefficients": coefficients}) + with open(cls.MODEL_PATH, "wb") as fh: + fh.write(mdl) + cls.load.cache_clear() + cls.explain.cache_clear() + + @classmethod + @cache + def load(cls) -> Tuple[Pipeline, Dict[str, float]]: + """Load a pre-trained classification pipeline for ad-hoc use.""" + with open(cls.MODEL_PATH, "rb") as fh: + matcher = pickle.loads(fh.read()) + pipe = cast(Pipeline, matcher["pipe"]) + coefficients = cast(Dict[str, float], matcher["coefficients"]) + current = [f.__name__ for f in cls.FEATURES] + if list(coefficients.keys()) != current: + raise RuntimeError("Model was not trained on identical features!") + return pipe, coefficients + + @classmethod + @cache + def explain(cls) -> FeatureDocs: + """Return an explanation of the features and their coefficients.""" + features: FeatureDocs = {} + _, coefficients = cls.load() + for func in cls.FEATURES: + name = func.__name__ + features[name] = FeatureDoc( + description=func.__doc__, + coefficient=float(coefficients[name]), + url=make_github_url(func), + ) + return features + + @classmethod + def compare( + cls, query: E, match: E, override_weights: Dict[str, float] = {} + ) -> MatchingResult: + """Use a regression model to compare two entities.""" + pipe, _ = cls.load() + encoded = cls.encode_pair(query, match) + npfeat = np.array([encoded]) + pred = pipe.predict_proba(npfeat) + score = cast(float, pred[0][1]) 
+ features = {f.__name__: float(c) for f, c in zip(cls.FEATURES, encoded)} + return MatchingResult.make(score=score, features=features) + + @classmethod + def encode_pair(cls, left: E, right: E) -> Encoded: + """Encode the comparison between two entities as a set of feature values.""" + return [f(left, right) for f in cls.FEATURES] diff --git a/nomenklatura/matching/regression_v3/names.py b/nomenklatura/matching/regression_v3/names.py new file mode 100644 index 00000000..66ae6f79 --- /dev/null +++ b/nomenklatura/matching/regression_v3/names.py @@ -0,0 +1,63 @@ +from typing import Iterable, Set +from followthemoney.proxy import E +from followthemoney.types import registry + +from nomenklatura.matching.regression_v1.util import tokenize_pair, compare_levenshtein +from nomenklatura.matching.compare.util import is_disjoint, has_overlap, extract_numbers +from nomenklatura.matching.util import props_pair, type_pair +from nomenklatura.matching.util import max_in_sets +from nomenklatura.util import fingerprint_name + + +def normalize_names(raws: Iterable[str]) -> Set[str]: + names = set() + for raw in raws: + name = fingerprint_name(raw) + if name is not None: + names.add(name[:128]) + return names + + +def name_levenshtein(left: E, right: E) -> float: + """Consider the edit distance (as a fraction of name length) between the two most + similar names linked to both entities.""" + lv, rv = type_pair(left, right, registry.name) + lvn, rvn = normalize_names(lv), normalize_names(rv) + return max_in_sets(lvn, rvn, compare_levenshtein) + + +def first_name_match(left: E, right: E) -> float: + """Matching first/given name between the two entities.""" + lv, rv = tokenize_pair(props_pair(left, right, ["firstName"])) + return 1.0 if has_overlap(lv, rv) else 0.0 + + +def family_name_match(left: E, right: E) -> float: + """Matching family name between the two entities.""" + lv, rv = tokenize_pair(props_pair(left, right, ["lastName"])) + return 1.0 if has_overlap(lv, rv) else 0.0 + + 
+def name_match(left: E, right: E) -> float: + """Check for exact name matches between the two entities.""" + lv, rv = type_pair(left, right, registry.name) + lvn, rvn = normalize_names(lv), normalize_names(rv) + common = [len(n) for n in lvn.intersection(rvn)] + max_common = max(common, default=0) + if max_common == 0: + return 0.0 + return float(max_common) + + +def name_token_overlap(left: E, right: E) -> float: + """Evaluate the proportion of identical words in each name.""" + lv, rv = tokenize_pair(type_pair(left, right, registry.name)) + common = lv.intersection(rv) + tokens = min(len(lv), len(rv)) + return float(len(common)) / float(max(2.0, tokens)) + + +def name_numbers(left: E, right: E) -> float: + """Find if names contain numbers, score if the numbers are different.""" + lv, rv = type_pair(left, right, registry.name) + return 1.0 if is_disjoint(extract_numbers(lv), extract_numbers(rv)) else 0.0 diff --git a/nomenklatura/matching/regression_v3/train.py b/nomenklatura/matching/regression_v3/train.py new file mode 100644 index 00000000..56ccaa69 --- /dev/null +++ b/nomenklatura/matching/regression_v3/train.py @@ -0,0 +1,89 @@ +import logging +import numpy as np +import multiprocessing +from typing import List, Tuple +from pprint import pprint +from numpy.typing import NDArray +from sklearn.pipeline import make_pipeline # type: ignore +from sklearn.preprocessing import StandardScaler # type: ignore +from sklearn.model_selection import GroupShuffleSplit # type: ignore +from sklearn.linear_model import LogisticRegression # type: ignore +from sklearn import metrics # type: ignore +from concurrent.futures import ProcessPoolExecutor + +from nomenklatura.judgement import Judgement +from nomenklatura.matching.pairs import read_pairs, JudgedPair +from nomenklatura.matching.regression_v3.model import RegressionV3 +from nomenklatura.util import PathLike + +log = logging.getLogger(__name__) + + +def pair_convert(pair: JudgedPair) -> Tuple[List[float], int]: + 
"""Encode a pair of training data into features and target.""" + judgement = 1 if pair.judgement == Judgement.POSITIVE else 0 + features = RegressionV3.encode_pair(pair.left, pair.right) + return features, judgement + + +def pairs_to_arrays( + pairs: List[JudgedPair], +) -> Tuple[NDArray[np.float32], NDArray[np.float32]]: + """Parallelize feature computation for training data""" + xrows = [] + yrows = [] + threads = multiprocessing.cpu_count() + log.info("Compute threads: %d", threads) + with ProcessPoolExecutor(max_workers=threads) as executor: + results = executor.map(pair_convert, pairs, chunksize=1000) + for idx, (x, y) in enumerate(results): + if idx > 0 and idx % 10000 == 0: + log.info("Computing features: %s....", idx) + xrows.append(x) + yrows.append(y) + + return np.array(xrows), np.array(yrows) + + +def train_matcher(pairs_file: PathLike, splits: int = 1) -> None: + pairs = [] + for pair in read_pairs(pairs_file): + if pair.judgement == Judgement.UNSURE: + pair.judgement = Judgement.NEGATIVE + pairs.append(pair) + positive = len([p for p in pairs if p.judgement == Judgement.POSITIVE]) + negative = len([p for p in pairs if p.judgement == Judgement.NEGATIVE]) + log.info("Total pairs loaded: %d (%d pos/%d neg)", len(pairs), positive, negative) + + X, y = pairs_to_arrays(pairs) + groups = [p.group for p in pairs] + gss = GroupShuffleSplit(n_splits=splits, test_size=0.33) + for split, (train_indices, test_indices) in enumerate( + gss.split(X, y, groups=groups), 1 + ): + X_train = [X[i] for i in train_indices] + X_test = [X[i] for i in test_indices] + y_train = [y[i] for i in train_indices] + y_test = [y[i] for i in test_indices] + + print() + log.info("Training model...(split %d)" % split) + logreg = LogisticRegression(penalty="l2") + pipe = make_pipeline(StandardScaler(), logreg) + pipe.fit(X_train, y_train) + coef = logreg.coef_[0] + coefficients = {n.__name__: c for n, c in zip(RegressionV3.FEATURES, coef)} + RegressionV3.save(pipe, coefficients) + + 
print("Coefficients:") + pprint(coefficients) + y_pred = pipe.predict(X_test) + cnf_matrix = metrics.confusion_matrix(y_test, y_pred, normalize="all") * 100 + print("Confusion matrix (% of all):\n", cnf_matrix) + print("Accuracy:", metrics.accuracy_score(y_test, y_pred)) + print("Precision:", metrics.precision_score(y_test, y_pred)) + print("Recall:", metrics.recall_score(y_test, y_pred)) + + y_pred_proba = pipe.predict_proba(X_test)[::, 1] + auc = metrics.roc_auc_score(y_test, y_pred_proba) + print("Area under curve:", auc) diff --git a/nomenklatura/matching/regression_v3/util.py b/nomenklatura/matching/regression_v3/util.py new file mode 100644 index 00000000..078b8ee9 --- /dev/null +++ b/nomenklatura/matching/regression_v3/util.py @@ -0,0 +1,31 @@ +from normality.constants import WS +from typing import Iterable, Set, Tuple +from rigour.text.distance import levenshtein + +from nomenklatura.util import normalize_name + + +def tokenize(texts: Iterable[str]) -> Set[str]: + tokens: Set[str] = set() + for text in texts: + cleaned = normalize_name(text) + if cleaned is None: + continue + for token in cleaned.split(WS): + token = token.strip() + if len(token) > 2: + tokens.add(token) + return tokens + + +def tokenize_pair( + pair: Tuple[Iterable[str], Iterable[str]] +) -> Tuple[Set[str], Set[str]]: + return tokenize(pair[0]), tokenize(pair[1]) + + +def compare_levenshtein(left: str, right: str) -> float: + distance = levenshtein(left, right) + base = max((1, len(left), len(right))) + return 1.0 - (distance / float(base)) + # return math.sqrt(distance) diff --git a/tests/matching/test_regression_v3.py b/tests/matching/test_regression_v3.py new file mode 100644 index 00000000..9bbe3f17 --- /dev/null +++ b/tests/matching/test_regression_v3.py @@ -0,0 +1,131 @@ +from followthemoney import model + +from nomenklatura.entity import CompositeEntity as Entity +from nomenklatura.matching import RegressionV3 + +candidate = { + "id": "left-putin", + "schema": "Person", + 
"properties": { + "name": ["Vladimir Putin"], + "birthDate": ["1952-10-07"], + "country": ["ru"], + }, +} + +putin = { + "id": "right-putin", + "schema": "Person", + "properties": { + "name": ["Vladimir Vladimirovich Putin"], + "birthDate": ["1952-10-07"], + "nationality": ["ru"], + }, +} + +saddam = { + "id": "other-guy", + "schema": "Person", + "properties": { + "name": ["Saddam Hussein"], + "birthDate": ["1937"], + "nationality": ["iq"], + }, +} + + +POS_ET = { + "id": "et-1", + "schema": "Position", + "properties": { + "name": ["Minister of ABC"], + "country": ["et"], + }, +} + +POS_ET2 = { + "id": "et-2", + "schema": "Position", + "properties": { + "name": ["Their excellency the Minister of ABC"], + "country": ["et"], + }, +} + +POS_VU = { + "id": "vu-1", + "schema": "Position", + "properties": { + "name": ["Minister of ABC"], + "country": ["vu"], + }, +} + + +def test_explain_matcher(): + explanation = RegressionV3.explain() + assert len(explanation) > 3, explanation + for _, desc in explanation.items(): + assert len(desc.description) > 0, desc + assert desc.coefficient != 0.0, desc + assert "github" in desc.url, desc + + +def test_compare_entities(): + cand = Entity.from_dict(model, candidate) + match = Entity.from_dict(model, putin) + mismatch = Entity.from_dict(model, saddam) + + res_match = RegressionV3.compare(cand, match) + res_mismatch = RegressionV3.compare(cand, mismatch) + assert res_match.score > res_mismatch.score + assert res_match.score > 0.5 + assert res_mismatch.score < 0.5 + + +def test_compare_features(): + cand = Entity.from_dict(model, candidate) + match = Entity.from_dict(model, putin) + ref_match = RegressionV3.compare(cand, match) + ref_score = ref_match.score + + no_bday = match.clone() + no_bday.pop("birthDate") + bday_match = RegressionV3.compare(cand, no_bday) + assert ref_score > bday_match.score + + bela = match.clone() + bela.set("nationality", "by") + bela_match = RegressionV3.compare(cand, bela) + assert ref_score > 
bela_match.score + + +def test_position_country(): + et1 = Entity.from_dict(model, POS_ET) + et2 = Entity.from_dict(model, POS_ET2) + vu1 = Entity.from_dict(model, POS_VU) + + res_et1_et2 = RegressionV3.compare(et1, et2) + res_et1_vu1 = RegressionV3.compare(et1, vu1) + assert res_et1_et2.score > res_et1_vu1.score, (res_et1_et2, res_et1_vu1) + assert res_et1_et2.score > 0.5, res_et1_et2 + assert res_et1_vu1.score < 0.5, res_et1_vu1 + + +def test_name_country(): + """name and country together shouldn't be too strong""" + + data = { + "id": "mike1", + "schema": "Person", + "properties": { + "name": ["Mykhailov Hlib Leonidovych", "Михайлов Гліб Леонідович"], + "country": ["ru"], + }, + } + e1 = Entity.from_dict(model, data) + data["id"] = "mike2" + e2 = Entity.from_dict(model, data) + res = RegressionV3.compare(e1, e2) + assert 0.8 < res.score < 0.96, res + From 3238ed64f15ede02f863cd6b2d6b463fcdffe38a Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 6 Sep 2024 11:32:22 +0100 Subject: [PATCH 2/7] Replace country_mismatch with country_match with positive and negative scoring --- nomenklatura/matching/compare/countries.py | 11 ++++++++--- nomenklatura/matching/regression_v3/model.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/nomenklatura/matching/compare/countries.py b/nomenklatura/matching/compare/countries.py index 4e0b8900..65e73e72 100644 --- a/nomenklatura/matching/compare/countries.py +++ b/nomenklatura/matching/compare/countries.py @@ -5,7 +5,12 @@ from nomenklatura.matching.compare.util import is_disjoint -def country_mismatch(query: E, result: E) -> float: - """Both entities are linked to different countries.""" +def country_match(query: E, result: E) -> float: + """Both entities are linked to the same country.""" qv, rv = type_pair(query, result, registry.country) - return 1.0 if is_disjoint(qv, rv) else 0.0 + if qv and rv: + if has_overlap(qv, rv): + return 1.0 + elif is_disjoint(qv, rv): + return -1.0 + return 0.0 diff --git 
a/nomenklatura/matching/regression_v3/model.py b/nomenklatura/matching/regression_v3/model.py index e08a220d..c9db5794 100644 --- a/nomenklatura/matching/regression_v3/model.py +++ b/nomenklatura/matching/regression_v3/model.py @@ -13,7 +13,7 @@ from nomenklatura.matching.regression_v3.misc import address_match, address_numbers from nomenklatura.matching.regression_v3.misc import identifier_match, birth_place from nomenklatura.matching.regression_v3.misc import org_identifier_match -from nomenklatura.matching.compare.countries import country_mismatch +from nomenklatura.matching.compare.countries import country_match from nomenklatura.matching.compare.gender import gender_mismatch from nomenklatura.matching.compare.dates import dob_matches, dob_year_matches from nomenklatura.matching.compare.dates import dob_year_disjoint @@ -43,7 +43,7 @@ class RegressionV3(ScoringAlgorithm): family_name_match, birth_place, gender_mismatch, - country_mismatch, + country_match, org_identifier_match, address_match, address_numbers, From 13ea0455fcc39e62f39dc0048e9065732cafbd7c Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 6 Sep 2024 11:36:58 +0100 Subject: [PATCH 3/7] Replace distinct date features with date similarity feature --- nomenklatura/matching/compare/dates.py | 54 ++++++++++++++++++++ nomenklatura/matching/regression_v3/model.py | 6 +-- tests/matching/test_dates.py | 14 ++++- 3 files changed, 69 insertions(+), 5 deletions(-) diff --git a/nomenklatura/matching/compare/dates.py b/nomenklatura/matching/compare/dates.py index 8f7f3165..4a5b655e 100644 --- a/nomenklatura/matching/compare/dates.py +++ b/nomenklatura/matching/compare/dates.py @@ -1,11 +1,16 @@ from typing import Iterable, Set from prefixdate import Precision from followthemoney.proxy import E +from rigour.text.distance import dam_levenshtein +from itertools import product from nomenklatura.matching.compare.util import has_overlap, is_disjoint from nomenklatura.matching.util import props_pair +MAX_YEARS = 2 + 
+ def _dates_precision(values: Iterable[str], precision: Precision) -> Set[str]: dates = set() for value in values: @@ -70,3 +75,52 @@ def dob_year_disjoint(query: E, result: E) -> float: if is_disjoint(query_years, result_years): return 1.0 return 0.0 + + +def dob_similarity(query: E, result: E) -> float: + """ + Provide a similarity score for the birth dates of two entities taking + date precision into account. + + 1.0: precise dates match + 0.75: years match + 0.5: dates within 1 edit from each other + 0.25: years within 2 years from each other + -0.2: imprecise dates are disjoint + -0.3: precise dates are disjoint + """ + query_dates, result_dates = props_pair(query, result, ["birthDate"]) + + # missing data + if len(query_dates) == 0 or len(result_dates) == 0: + return 0.0 + + # exact match on precise dates + result_days = _dates_precision(result_dates, Precision.DAY) + query_days = _dates_precision(query_dates, Precision.DAY) + if has_overlap(query_days, result_days): + return 1.0 + + # precise dates available but have no common values + if is_disjoint(query_days, result_days): + return -0.3 + + # clerical errors on precise dates + for qd, rd in product(query_days, result_days): + if dam_levenshtein(qd, rd) <= 1: + return 0.5 + + # years overlap + query_years = _dates_precision(query_dates, Precision.YEAR) + result_years = _dates_precision(result_dates, Precision.YEAR) + if has_overlap(query_years, result_years): + return 0.75 + + # years are close + for qy, ry in product(query_years, result_years): + years_difference = abs(int(qy) - int(ry)) + if years_difference <= MAX_YEARS: + return 0.25 + + # dates exist but are disjoint other than above options + return -0.2 diff --git a/nomenklatura/matching/regression_v3/model.py b/nomenklatura/matching/regression_v3/model.py index c9db5794..2fb3159a 100644 --- a/nomenklatura/matching/regression_v3/model.py +++ b/nomenklatura/matching/regression_v3/model.py @@ -16,7 +16,7 @@ from 
nomenklatura.matching.compare.countries import country_match from nomenklatura.matching.compare.gender import gender_mismatch from nomenklatura.matching.compare.dates import dob_matches, dob_year_matches -from nomenklatura.matching.compare.dates import dob_year_disjoint +from nomenklatura.matching.compare.dates import dob_year_disjoint, dob_similarity from nomenklatura.matching.types import FeatureDocs, FeatureDoc, MatchingResult from nomenklatura.matching.types import CompareFunction, Encoded, ScoringAlgorithm from nomenklatura.matching.util import make_github_url @@ -36,9 +36,7 @@ class RegressionV3(ScoringAlgorithm): phone_match, email_match, identifier_match, - dob_matches, - dob_year_matches, - dob_year_disjoint, + dob_similarity, first_name_match, family_name_match, birth_place, diff --git a/tests/matching/test_dates.py b/tests/matching/test_dates.py index 9ff1fe85..10f97997 100644 --- a/tests/matching/test_dates.py +++ b/tests/matching/test_dates.py @@ -1,4 +1,4 @@ -from nomenklatura.matching.compare.dates import dob_matches, dob_year_matches +from nomenklatura.matching.compare.dates import dob_matches, dob_similarity, dob_year_matches from nomenklatura.matching.compare.dates import dob_day_disjoint, dob_year_disjoint from .util import e @@ -11,23 +11,35 @@ def test_dob_matches(): assert dob_year_matches(left, right) == 1.0 assert dob_day_disjoint(left, right) == 0.0 assert dob_year_disjoint(left, right) == 0.0 + assert dob_similarity(left, right) == 1.0 + right = e("Person", birthDate="1980-04-15") + assert dob_similarity(left, right) == 0.5 + right = e("Person", birthDate="1980-03-16") + assert dob_similarity(left, right) == 0.5 + right = e("Person", birthDate="1981-04-16") + assert dob_similarity(left, right) == 0.5 right = e("Person", birthDate="1980") assert dob_year_matches(left, right) == 1.0 assert dob_day_disjoint(left, right) == 0.0 + assert dob_similarity(left, right) == 0.75 right = e("Person", birthDate="1980-04") assert dob_year_matches(left, 
right) == 1.0 assert dob_day_disjoint(left, right) == 0.0 + assert dob_similarity(left, right) == 0.75 right = e("Person", birthDate="1980-04-16T19:00:00") assert dob_matches(left, right) == 1.0 assert dob_year_matches(left, right) == 1.0 assert dob_day_disjoint(left, right) == 0.0 + assert dob_similarity(left, right) == 1.0 right = e("Person", birthDate="1965-04-16") assert dob_matches(left, right) == 0.0 assert dob_year_matches(left, right) == 0.0 assert dob_day_disjoint(left, right) == 1.0 assert dob_year_disjoint(left, right) == 1.0 + assert dob_similarity(left, right) == -1.0 none = e("Person", name="Harry") assert dob_matches(left, none) == 0.0 assert dob_year_matches(left, none) == 0.0 assert dob_day_disjoint(left, none) == 0.0 assert dob_year_disjoint(left, none) == 0.0 + assert dob_similarity(left, none) == 0.0 From 5c42def06b812fbde8d35e7f7f0ba23c0fb2230e Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 6 Sep 2024 11:52:03 +0100 Subject: [PATCH 4/7] Restrict name_match scale, reducing only name and country score --- nomenklatura/matching/regression_v3/names.py | 31 +++++++++++++++---- tests/matching/test_regression_v3.py | 32 ++++++++++++++++++-- 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/nomenklatura/matching/regression_v3/names.py b/nomenklatura/matching/regression_v3/names.py index 66ae6f79..312f5732 100644 --- a/nomenklatura/matching/regression_v3/names.py +++ b/nomenklatura/matching/regression_v3/names.py @@ -4,11 +4,18 @@ from nomenklatura.matching.regression_v1.util import tokenize_pair, compare_levenshtein from nomenklatura.matching.compare.util import is_disjoint, has_overlap, extract_numbers -from nomenklatura.matching.util import props_pair, type_pair +from nomenklatura.matching.util import has_schema, props_pair, type_pair from nomenklatura.matching.util import max_in_sets from nomenklatura.util import fingerprint_name +MATCH_BASE_SCORE = 0.7 +MAX_BONUS_LENGTH = 100 +LENGTH_BONUS_FACTOR = (1 - MATCH_BASE_SCORE) / 
MAX_BONUS_LENGTH +MAX_BONUS_QTY = 10 +QTY_BONUS_FACTOR = (1 - MATCH_BASE_SCORE) / MAX_BONUS_QTY + + def normalize_names(raws: Iterable[str]) -> Set[str]: names = set() for raw in raws: @@ -39,14 +46,26 @@ def family_name_match(left: E, right: E) -> float: def name_match(left: E, right: E) -> float: - """Check for exact name matches between the two entities.""" + """ + Check for exact name matches between the two entities. + + Having any completely matching name initially scores 0.7. + A length bonus is added based on the length of the longest common name up to 100 chars. + A quantity bonus is added based on the number of common names up to 10. + + The maximum score is 1.0. + No matches scores 0.0. + """ lv, rv = type_pair(left, right, registry.name) lvn, rvn = normalize_names(lv), normalize_names(rv) - common = [len(n) for n in lvn.intersection(rvn)] - max_common = max(common, default=0) - if max_common == 0: + common = sorted(lvn.intersection(rvn), key=lambda n: len(n), reverse=True) + if not common: return 0.0 - return float(max_common) + score = MATCH_BASE_SCORE + longest_common = common[0] + length_bonus = min(len(longest_common), MAX_BONUS_LENGTH) * LENGTH_BONUS_FACTOR + quantity_bonus = min(len(common), MAX_BONUS_QTY) * QTY_BONUS_FACTOR + return score + (length_bonus + quantity_bonus) / 2 def name_token_overlap(left: E, right: E) -> float: diff --git a/tests/matching/test_regression_v3.py b/tests/matching/test_regression_v3.py index 9bbe3f17..de19e180 100644 --- a/tests/matching/test_regression_v3.py +++ b/tests/matching/test_regression_v3.py @@ -2,6 +2,7 @@ from nomenklatura.entity import CompositeEntity as Entity from nomenklatura.matching import RegressionV3 +from nomenklatura.matching.regression_v3.names import name_match candidate = { "id": "left-putin", @@ -119,7 +120,7 @@ def test_name_country(): "id": "mike1", "schema": "Person", "properties": { - "name": ["Mykhailov Hlib Leonidovych", "Михайлов Гліб Леонідович"], + "name": ["Mykhailov Hlib 
Leonidovych"], "country": ["ru"], }, } @@ -127,5 +128,32 @@ def test_name_country(): data["id"] = "mike2" e2 = Entity.from_dict(model, data) res = RegressionV3.compare(e1, e2) - assert 0.8 < res.score < 0.96, res + assert 0.92 < res.score < 0.93, res + +def test_name_match(): + data = { + "id": "mike1", + "schema": "Person", + "properties": { + "name": [ + "John", + ], + }, + } + e1 = Entity.from_dict(model, data) + data["id"] = "mike2" + e2 = Entity.from_dict(model, data) + assert 0.72 < name_match(e1, e2) < 0.73 + + e1.set("name", ["a" * 100]) + e2.set("name", ["a" * 100]) + assert 0.86 < name_match(e1, e2) < 0.87 + + e1.set("name", []) + e2.set("name", []) + for i in range(10): + char = chr(65 + i) + e1.add("name", char * 100) + e2.add("name", char * 100) + assert 1.0 == name_match(e1, e2) From f1508ae6d39f2314b22fd61736aa7ecb2080f89b Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 6 Sep 2024 12:39:22 +0100 Subject: [PATCH 5/7] Fixes --- nomenklatura/matching/compare/countries.py | 8 +++++++- nomenklatura/matching/compare/dates.py | 8 ++++---- tests/matching/test_dates.py | 2 +- tests/test_xref.py | 3 ++- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/nomenklatura/matching/compare/countries.py b/nomenklatura/matching/compare/countries.py index 65e73e72..3732754c 100644 --- a/nomenklatura/matching/compare/countries.py +++ b/nomenklatura/matching/compare/countries.py @@ -2,7 +2,7 @@ from followthemoney.types import registry from nomenklatura.matching.util import type_pair -from nomenklatura.matching.compare.util import is_disjoint +from nomenklatura.matching.compare.util import is_disjoint, has_overlap def country_match(query: E, result: E) -> float: @@ -14,3 +14,9 @@ def country_match(query: E, result: E) -> float: elif is_disjoint(qv, rv): return -1.0 return 0.0 + + +def country_mismatch(query: E, result: E) -> float: + """Both entities are linked to different countries.""" + qv, rv = type_pair(query, result, registry.country) + return 1.0 if 
is_disjoint(qv, rv) else 0.0 diff --git a/nomenklatura/matching/compare/dates.py b/nomenklatura/matching/compare/dates.py index 4a5b655e..00010b85 100644 --- a/nomenklatura/matching/compare/dates.py +++ b/nomenklatura/matching/compare/dates.py @@ -101,15 +101,15 @@ def dob_similarity(query: E, result: E) -> float: if has_overlap(query_days, result_days): return 1.0 - # precise dates available but have no common values - if is_disjoint(query_days, result_days): - return -0.3 - # clerical errors on precise dates for qd, rd in product(query_days, result_days): if dam_levenshtein(qd, rd) <= 1: return 0.5 + # precise dates available but have no common values + if is_disjoint(query_days, result_days): + return -0.3 + # years overlap query_years = _dates_precision(query_dates, Precision.YEAR) result_years = _dates_precision(result_dates, Precision.YEAR) diff --git a/tests/matching/test_dates.py b/tests/matching/test_dates.py index 10f97997..8e82f1f3 100644 --- a/tests/matching/test_dates.py +++ b/tests/matching/test_dates.py @@ -36,7 +36,7 @@ def test_dob_matches(): assert dob_year_matches(left, right) == 0.0 assert dob_day_disjoint(left, right) == 1.0 assert dob_year_disjoint(left, right) == 1.0 - assert dob_similarity(left, right) == -1.0 + assert dob_similarity(left, right) == -0.3 none = e("Person", name="Harry") assert dob_matches(left, none) == 0.0 assert dob_year_matches(left, none) == 0.0 diff --git a/tests/test_xref.py b/tests/test_xref.py index 950fcea8..45130e27 100644 --- a/tests/test_xref.py +++ b/tests/test_xref.py @@ -6,6 +6,7 @@ from nomenklatura.entity import CompositeEntity from nomenklatura.judgement import Judgement from nomenklatura.matching.regression_v1.model import RegressionV1 +from nomenklatura.matching.regression_v3.model import RegressionV3 from nomenklatura.resolver import Resolver from nomenklatura.store import SimpleMemoryStore from nomenklatura.xref import xref @@ -81,7 +82,7 @@ def test_xref_potential_conflicts( resolver, store, 
index_path, - algorithm=RegressionV1, + algorithm=RegressionV3, conflicting_match_threshold=0.9, ) stdout = capsys.readouterr().out From 023f877b6f9027fa1f04c74f1223da8a6c534db5 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 6 Sep 2024 12:42:12 +0100 Subject: [PATCH 6/7] Add trained model --- nomenklatura/data/regression-v3.pkl | Bin 0 -> 2208 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 nomenklatura/data/regression-v3.pkl diff --git a/nomenklatura/data/regression-v3.pkl b/nomenklatura/data/regression-v3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..865064cda76f70c52472bc5248caf20ab2026e1b GIT binary patch literal 2208 zcmaJ?Yitx%6y9yO-L)-cDYcbXl~9E8SXu-mMz}3Y5M8OJtsp@UbO}I0K}?OsT3eBi!N;`@FLlAylFm0aL&6}11Bg51{A!M zJf57t37~OCuVqsl_pPu>tR;d}&v*Xu<*5i$VdP3SP=HQ z9nZ2>Jz2Pp;M4>yQVMhFY6>MoMiLby2N^`zxJ=cqB}BV6f<2a6?6!KbL2mUpYcV%W zjq~-nkQ?>nbzyf+{ln9(qbn{0N6Wm-CG1TwIkI$jQ+@%66r-cImYbkV7&C5PNfGGC zsi<$|}R$KzZBA=C=@LU9K*3Azxx10mV&)>T_D`OK_UDPzD zHKiB?59F*J?!F%5`5b(0uPfPA{}tH!-8%<&t=tDNQLWmlXql+Y=APGtaO3;RaA`sM zaWFaoV8)H2AIa|jozDKZ%XlAj{O070@&R3w3@Bnkzu zuUsKR#7|dKJfX!NR!!EON$DI|O=8KMfDAdvV%1Jgc_jrBhRVtogzhN!jZTodD5^vtRfXwnj$` zzp#JXm1EH_FS?rx79WaktQ?bkSGWS?drnWrnaf~N`|nqCqg_#0F{5E=9+YY3(I`i6P9483?D25MUJ) z6%IjOjG@eY8mQy|CyD(7G7}vGKtIh3{USCi5W9zGi8|iLLt?=`y3G0*s8m++QVQ?q z$%}?3t~0cowTm(Ncv(Tz6B#lqYoLbsXu=YD5@r_OfF*l*nY=Uu0?okKZDvOpfWhTZ zWVM)irEl0~yBMHBj1&}+9O|dc>?8w{F3IK8f8@aIwHO)AvfA|OM`pHREGQGwzp}kO Ks=Z%>LD#<$4W4cQ literal 0 HcmV?d00001 From 0d292bc97c4055ba06bc97e71f1745d5ee95bf16 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 6 Sep 2024 16:44:29 +0100 Subject: [PATCH 7/7] Use aligned levenstein for non-persons and plain for persons This is to deal with levenshtein ending up with negative coefficient resulting in often horrible scores for very similar names --- nomenklatura/data/regression-v3.pkl | Bin 2208 -> 2211 bytes nomenklatura/matching/compare/names.py | 52 ++++++++++--------- 
nomenklatura/matching/regression_v3/model.py | 6 ++- nomenklatura/matching/regression_v3/names.py | 11 +++- tests/matching/test_regression_v3.py | 10 ++-- 5 files changed, 48 insertions(+), 31 deletions(-) diff --git a/nomenklatura/data/regression-v3.pkl b/nomenklatura/data/regression-v3.pkl index 865064cda76f70c52472bc5248caf20ab2026e1b..eeaeb2ce31262cc76c7bf56a66b8ee2e4834ff69 100644 GIT binary patch delta 868 zcmY+AeK6B;9LF28-%ni+>!@ohqMJHtJ+STwA2sQLi)o2dX_z)`^DtA^NvTxyAm0fW zj&-G?nyGYk>*p>-N7EE>%+~dcL#dQj*B$$H_5b_xx!3DeLBB|+9-?F z6-Ab1iFajp-gynVG0J{T-XKH^eOUPtZ%CNpP1A)90H<`AF*M480NH}3@@&PhQ?*?PZ}y={}HFR<;oh+MvXA~!bVV|&62F2ZT5b(JKLdf zXQcmK?lUl6QfmF&Rq*up_P+Q99dK(Pm_fNv0o((_nJM*EP~V--Y4PfT^~2hGahWFA z_s$cNf0-a|TRzj%V1kgY)^^c!9rzL(z6D`wSRWmPW!`H7`t;7_cRHG2QJ(J%_6j|W zuQMrpmyFcywKY;G47AB^X;!^Yw}%Brq1AvbU#m0{ggzbZDXJux`Prq1xDSB#Sbgwm z-C3XynYEvTtBK-@$i?n2Yyz5}9Xr5cQUa!SR+VUHYY3E5d35Gr7o_`abaZhu5T7j1 z9do*LM9t~Z> zI+KH0XeB-uC)H30w!m&e56Q5F_+P{Gm@`glA{{mlv$`T9=ENp_tq*9LM`{cfT98C=qf;s1Bjh!`k?`G_lLt&LyTexoa2av~$;-dANCM#IE`t z#VVv`i#U(%8vU{{md%|Tr94j#V=51oDXZyzfB65k&*$}eS6f$FQHlT1>7e%uOe6j1{CLR~`tOQk61s}|8{m#XANNQv zabb3?U|_S956e$OK5?)fpk^B4b@n>gyQDB4`=bNLq5D58s&(*O#$#nt3m|*$INF7W z!(pS?IY`+LIV%lM8tSS*Z8S=?!}Socp+A{+_dDt`JCcQeD4%`%*#JArcqN5$Jt*-m zhBD(JFj8VdWFdNpGM@|WbgBa1@C}C8q6W?v3H9SHBT$pq{%&8<1Qc^(79MR*fS1wZ zcKyHvq{Z=rr#S5pVOWkPq*d@d=UD2X{{(o{+nn+1?E$}T_i-#y2UeSx$GkhgH7A;n zVHg92Qa!V#&ry~K#aQ5+NDG=3>m0Ba;aRnXN) zq&D+RGvLHmp_XuT7;mna=_X!H2Gz5p+KA#rSIZUg z3`ke4IJU#na7FiH)|VJ0gbAl&GVdAC3e}5dE_4g-Ye(+13v^Nk$!N}W(g{7LIp8KS z)Inp0l5*;NHCJ6^c}tWm5ymFP$>K$7k*&x(h)!CY0AW(xP4PCsl!M%HCY?+|7jY-* zRu+=r&UCU2mHuPMJk){PAFywVk$K42V0O_?)?`eBafN@ diff --git a/nomenklatura/matching/compare/names.py b/nomenklatura/matching/compare/names.py index df5a25f5..34286ac2 100644 --- a/nomenklatura/matching/compare/names.py +++ b/nomenklatura/matching/compare/names.py @@ -68,6 +68,33 @@ def person_name_jaro_winkler(query: E, result: E) -> float: return score +def aligned_levenshtein(qfp: str, 
rfp: str) -> float: + qtokens = name_words(qfp, min_length=2) + rtokens = name_words(rfp, min_length=2) + for part in name_words(clean_name_ascii(rfp), min_length=2): + if part not in rtokens: + rtokens.append(part) + + scores: Dict[Tuple[str, str], float] = {} + # compute all pairwise scores for name parts: + for q, r in product(set(qtokens), set(rtokens)): + scores[(q, r)] = levenshtein_similarity(q, r) + aligned: List[Tuple[str, str, float]] = [] + # find the best pairing for each name part by score: + for (q, r), score in sorted(scores.items(), key=lambda i: i[1], reverse=True): + # one name part can only be used once, but can show up multiple times: + while q in qtokens and r in rtokens: + qtokens.remove(q) + rtokens.remove(r) + aligned.append((q, r, score)) + # assume there should be at least a candidate for each query name part: + if len(qtokens): + return 0.0 + qaligned = "".join(p[0] for p in aligned) + raligned = "".join(p[1] for p in aligned) + return levenshtein_similarity(qaligned, raligned) + + def name_fingerprint_levenshtein(query: E, result: E) -> float: """Two non-person entities have similar fingerprinted names. This includes simplifying entity type names (e.g. 
"Limited" -> "Ltd") and uses the @@ -85,30 +112,7 @@ def name_fingerprint_levenshtein(query: E, result: E) -> float: continue score = levenshtein_similarity(qfp.replace(" ", ""), rfp.replace(" ", "")) max_score = max(max_score, score) - qtokens = name_words(qfp, min_length=2) - rtokens = name_words(rfp, min_length=2) - for part in name_words(clean_name_ascii(rfp), min_length=2): - if part not in rtokens: - rtokens.append(part) - - scores: Dict[Tuple[str, str], float] = {} - # compute all pairwise scores for name parts: - for q, r in product(set(qtokens), set(rtokens)): - scores[(q, r)] = levenshtein_similarity(q, r) - aligned: List[Tuple[str, str, float]] = [] - # find the best pairing for each name part by score: - for (q, r), score in sorted(scores.items(), key=lambda i: i[1], reverse=True): - # one name part can only be used once, but can show up multiple times: - while q in qtokens and r in rtokens: - qtokens.remove(q) - rtokens.remove(r) - aligned.append((q, r, score)) - # assume there should be at least a candidate for each query name part: - if len(qtokens): - continue - qaligned = "".join(p[0] for p in aligned) - raligned = "".join(p[1] for p in aligned) - score = levenshtein_similarity(qaligned, raligned) + score = aligned_levenshtein(qfp, rfp) max_score = max(max_score, score) return max_score diff --git a/nomenklatura/matching/regression_v3/model.py b/nomenklatura/matching/regression_v3/model.py index 2fb3159a..fbd046c1 100644 --- a/nomenklatura/matching/regression_v3/model.py +++ b/nomenklatura/matching/regression_v3/model.py @@ -5,6 +5,7 @@ from sklearn.pipeline import Pipeline # type: ignore from followthemoney.proxy import E + from nomenklatura.matching.regression_v3.names import first_name_match from nomenklatura.matching.regression_v3.names import family_name_match from nomenklatura.matching.regression_v3.names import name_levenshtein, name_match @@ -13,10 +14,11 @@ from nomenklatura.matching.regression_v3.misc import address_match, address_numbers 
from nomenklatura.matching.regression_v3.misc import identifier_match, birth_place from nomenklatura.matching.regression_v3.misc import org_identifier_match -from nomenklatura.matching.compare.countries import country_match +from nomenklatura.matching.compare.countries import country_mismatch from nomenklatura.matching.compare.gender import gender_mismatch from nomenklatura.matching.compare.dates import dob_matches, dob_year_matches from nomenklatura.matching.compare.dates import dob_year_disjoint, dob_similarity +from nomenklatura.matching.compare.names import name_fingerprint_levenshtein from nomenklatura.matching.types import FeatureDocs, FeatureDoc, MatchingResult from nomenklatura.matching.types import CompareFunction, Encoded, ScoringAlgorithm from nomenklatura.matching.util import make_github_url @@ -41,7 +43,7 @@ class RegressionV3(ScoringAlgorithm): family_name_match, birth_place, gender_mismatch, - country_match, + country_mismatch, org_identifier_match, address_match, address_numbers, diff --git a/nomenklatura/matching/regression_v3/names.py b/nomenklatura/matching/regression_v3/names.py index 312f5732..03104614 100644 --- a/nomenklatura/matching/regression_v3/names.py +++ b/nomenklatura/matching/regression_v3/names.py @@ -2,8 +2,9 @@ from followthemoney.proxy import E from followthemoney.types import registry -from nomenklatura.matching.regression_v1.util import tokenize_pair, compare_levenshtein +from nomenklatura.matching.regression_v3.util import tokenize_pair, compare_levenshtein from nomenklatura.matching.compare.util import is_disjoint, has_overlap, extract_numbers +from nomenklatura.matching.compare.names import aligned_levenshtein from nomenklatura.matching.util import has_schema, props_pair, type_pair from nomenklatura.matching.util import max_in_sets from nomenklatura.util import fingerprint_name @@ -30,7 +31,13 @@ def name_levenshtein(left: E, right: E) -> float: similar names linked to both entities.""" lv, rv = type_pair(left, right, 
registry.name) lvn, rvn = normalize_names(lv), normalize_names(rv) - return max_in_sets(lvn, rvn, compare_levenshtein) + if has_schema(left, right, "Person"): + return max_in_sets(lvn, rvn, compare_levenshtein) + else: + return max( + max_in_sets(lv, rv, aligned_levenshtein), + max_in_sets(rv, lv, aligned_levenshtein), + ) def first_name_match(left: E, right: E) -> float: diff --git a/tests/matching/test_regression_v3.py b/tests/matching/test_regression_v3.py index de19e180..d33c9d54 100644 --- a/tests/matching/test_regression_v3.py +++ b/tests/matching/test_regression_v3.py @@ -102,6 +102,10 @@ def test_compare_features(): def test_position_country(): + """ + Two names matching with country mismatch should score better than two countries + matching with name mismatch + """ et1 = Entity.from_dict(model, POS_ET) et2 = Entity.from_dict(model, POS_ET2) vu1 = Entity.from_dict(model, POS_VU) @@ -109,8 +113,8 @@ def test_position_country(): res_et1_et2 = RegressionV3.compare(et1, et2) res_et1_vu1 = RegressionV3.compare(et1, vu1) assert res_et1_et2.score > res_et1_vu1.score, (res_et1_et2, res_et1_vu1) - assert res_et1_et2.score > 0.5, res_et1_et2 - assert res_et1_vu1.score < 0.5, res_et1_vu1 + assert res_et1_et2.score > 0.3, res_et1_et2 + assert res_et1_vu1.score < 0.2, res_et1_vu1 def test_name_country(): @@ -128,7 +132,7 @@ def test_name_country(): data["id"] = "mike2" e2 = Entity.from_dict(model, data) res = RegressionV3.compare(e1, e2) - assert 0.92 < res.score < 0.93, res + assert 0.91 < res.score < 0.93, res def test_name_match():