From 4b71b929c6b3b5c19417f02964e9f98e222218b1 Mon Sep 17 00:00:00 2001 From: mbaak Date: Fri, 16 Feb 2024 15:34:33 +0100 Subject: [PATCH] ENH: make name of positive_set column fully configurable --- emm/aggregation/base_entity_aggregation.py | 6 ++++-- emm/data/create_data.py | 19 ++++++++++++++++--- emm/data/prepare_name_pairs.py | 6 ++++-- emm/pipeline/pandas_entity_matching.py | 6 ++---- emm/supervised_model/base_supervised_model.py | 2 +- emm/threshold/threshold_decision.py | 9 +++++++-- 6 files changed, 34 insertions(+), 14 deletions(-) diff --git a/emm/aggregation/base_entity_aggregation.py b/emm/aggregation/base_entity_aggregation.py index 82ad8fd..5cd818d 100644 --- a/emm/aggregation/base_entity_aggregation.py +++ b/emm/aggregation/base_entity_aggregation.py @@ -143,6 +143,7 @@ def __init__( gt_preprocessed_col: str = "gt_preprocessed", aggregation_method: Literal["max_frequency_nm_score", "mean_score"] = "max_frequency_nm_score", blacklist: list | None = None, + positive_set_col: str = "positive_set", ) -> None: self.score_col = score_col self.account_col = account_col @@ -158,6 +159,7 @@ def __init__( self.gt_preprocessed_col = gt_preprocessed_col self.aggregation_method = aggregation_method self.blacklist = blacklist or [] + self.positive_set_col = positive_set_col # perform very basic preprocessing to blacklist, remove abbreviations, to lower, etc. 
self.blacklist = [preprocess(name) for name in self.blacklist] @@ -171,8 +173,8 @@ def get_group(self, dataframe) -> list[str]: group += [self.index_col] # Useful for collect_metrics() - if "positive_set" in dataframe.columns: - group += ["positive_set"] + if self.positive_set_col in dataframe.columns: + group += [self.positive_set_col] # Notice we lose the name_to_match 'uid' column here return group diff --git a/emm/data/create_data.py b/emm/data/create_data.py index 7c818f2..2b9959a 100644 --- a/emm/data/create_data.py +++ b/emm/data/create_data.py @@ -259,6 +259,7 @@ def pandas_create_noised_data( name_col="Name", index_col="Index", random_seed=None, + positive_set_col="positive_set", ): """Create pandas noised dataset based on company names from kvk. @@ -274,6 +275,7 @@ def pandas_create_noised_data( name_col: name column in csv file index_col: name-id column in csv file (optional) random_seed: seed to use + positive_set_col: name of positive set column in csv file, default is "positive_set". Returns: ground_truth and companies_noised_pd pandas dataframes @@ -331,7 +333,7 @@ def pandas_create_noised_data( pos = shuffled_ids[: len(shuffled_ids) // 2] # ground truth only contains companies in positive set is_in_pos = companies_pd["Index"].isin(pos) - companies_pd["positive_set"] = is_in_pos + companies_pd[positive_set_col] = is_in_pos if split_pos_neg: ground_truth = companies_pd[is_in_pos].copy() @@ -387,6 +389,7 @@ def create_noised_data( index_col="Index", ret_posneg=False, random_seed=None, + positive_set_col="positive_set", ): """Create spark noised dataset based on company names from kvk. @@ -404,6 +407,7 @@ def create_noised_data( index_col: name-id column in csv file (optional) ret_posneg: if true also return original positive and negative spark true datasets random_seed: seed to use + positive_set_col: name of positive set column in csv file, default is "positive_set". 
Returns: ground_truth and companies_noised_pd spark dataframes @@ -412,8 +416,17 @@ def create_noised_data( # location of local sample of kvk unit test dataset; downloads the dataset in case not present. data_path, _ = retrieve_kvk_test_sample() + # name_col and index_col get renamed to Name and Index (ground_truth_pd, companies_noised_pd, positive_noised_pd, negative_pd) = pandas_create_noised_data( - noise_level, noise_type, noise_count, split_pos_neg, data_path, name_col, index_col, random_seed + noise_level, + noise_type, + noise_count, + split_pos_neg, + data_path, + name_col, + index_col, + random_seed, + positive_set_col, ) # Sparkify dataframes @@ -424,7 +437,7 @@ def create_noised_data( StructField("amount", FloatType(), True), StructField("counterparty_account_count_distinct", IntegerType(), nullable=True), StructField("uid", IntegerType(), nullable=True), - StructField("positive_set", BooleanType(), True), + StructField(positive_set_col, BooleanType(), True), StructField("country", StringType(), nullable=True), StructField("account", StringType(), True), ] diff --git a/emm/data/prepare_name_pairs.py b/emm/data/prepare_name_pairs.py index 30f8db9..0c0b29f 100644 --- a/emm/data/prepare_name_pairs.py +++ b/emm/data/prepare_name_pairs.py @@ -87,7 +87,7 @@ def prepare_name_pairs_pd( candidates_pd["correct"] = candidates_pd[entity_id_col] == candidates_pd[gt_entity_id_col] # negative sample creation? 
- # if so, add positive_set column for negative sample creation + # if so, add positive_set_col column for negative sample creation rng = np.random.default_rng(random_seed) create_negative_sample_fraction = min(create_negative_sample_fraction, 1) create_negative_sample = create_negative_sample_fraction > 0 @@ -98,7 +98,9 @@ def prepare_name_pairs_pd( pos_ids = list(rng.choice(ids, n_positive, replace=False)) candidates_pd[positive_set_col] = candidates_pd[entity_id_col].isin(pos_ids) elif create_negative_sample and positive_set_col in candidates_pd.columns: - logger.info("create_negative_sample_fraction is set, but positive_set already defined; using the latter.") + logger.info( + f"create_negative_sample_fraction is set, but {positive_set_col} already defined; using the latter." + ) # We remove duplicates ground-truth name candidates in the pure string similarity model (i.e. when WITHOUT_RANK_FEATURES==True) # because we noticed that when we don't do this, the model learns that perfect match are worst than non-perfect match (like different legal form) diff --git a/emm/pipeline/pandas_entity_matching.py b/emm/pipeline/pandas_entity_matching.py index c763c83..0e3d786 100644 --- a/emm/pipeline/pandas_entity_matching.py +++ b/emm/pipeline/pandas_entity_matching.py @@ -538,10 +538,8 @@ def combine_sm_results(df: pd.DataFrame, sel_cand: pd.DataFrame, test_gt: pd.Dat res = df.join(sel_cand[["gt_entity_id", "gt_name", "gt_preprocessed", "nm_score", "score_0"]], how="left") res["nm_score"] = res["nm_score"].fillna(-1) res["score_0"] = res["score_0"].fillna(-1) - res["positive_set"] = res["id"].isin(test_gt["id"]) - res["correct"] = ((res["positive_set"]) & (res["id"] == res["gt_entity_id"])) | ( - (~res["positive_set"]) & (res["id"].isnull()) - ) + is_in_pos = res["id"].isin(test_gt["id"]) + res["correct"] = ((is_in_pos) & (res["id"] == res["gt_entity_id"])) | ((~is_in_pos) & (res["id"].isnull())) return res test_candidates = self.transform(test_names_to_match.copy()) diff 
--git a/emm/supervised_model/base_supervised_model.py b/emm/supervised_model/base_supervised_model.py index 6a94923..82602a8 100644 --- a/emm/supervised_model/base_supervised_model.py +++ b/emm/supervised_model/base_supervised_model.py @@ -256,7 +256,7 @@ def train_test_model( y = dataset["correct"].astype(str) + dataset["no_candidate"].astype(str) if positive_set_col in dataset.columns: - y += dataset["positive_set"].astype(str) + y += dataset[positive_set_col].astype(str) # Train test split with consistent name-to-match account (group) and with approximately same class balance y (stratified) # it is important to have all the name in the same account, for account matching after aggregation diff --git a/emm/threshold/threshold_decision.py b/emm/threshold/threshold_decision.py index f1cb0fe..6ce272b 100644 --- a/emm/threshold/threshold_decision.py +++ b/emm/threshold/threshold_decision.py @@ -72,6 +72,7 @@ def get_threshold_curves_parameters( score_col: str = "nm_score", aggregation_layer: bool = False, aggregation_method: str = "name_clustering", + positive_set_col: str = "positive_set", ) -> dict: """Get threshold decision curves @@ -80,12 +81,16 @@ def get_threshold_curves_parameters( score_col: which score column to use, default is 'nm_score'. For aggregation use 'agg_score'. aggregation_layer: use aggregation layer? default is False. aggregation_method: which aggregation method is used? 'name_clustering' or 'mean_score'. + positive_set_col: name of positive set column in best candidates df. default is 'positive_set' Returns: dictionary with threshold decision curves """ - best_positive_df = best_candidate_df[best_candidate_df.positive_set] - best_negative_df = best_candidate_df[~best_candidate_df.positive_set] + if positive_set_col not in best_candidate_df.columns: + msg = f"positive set column {positive_set_col} not in best_candidates df." 
+ raise ValueError(msg) + best_positive_df = best_candidate_df[best_candidate_df[positive_set_col]] + best_negative_df = best_candidate_df[~best_candidate_df[positive_set_col]] n_positive_names_to_match = len(best_positive_df) name_sets = {"all": best_candidate_df, "positive": best_positive_df, "negative": best_negative_df}