From a4f53b01f52f1d433207f033ddf477dc1210fbc0 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 09:18:09 +0100 Subject: [PATCH 1/9] roc_chart_from_labels_column is working --- splink/accuracy.py | 77 ++++++++++++++++++++++++++++++++++++---------- splink/linker.py | 27 ++++++++++++++++ 2 files changed, 87 insertions(+), 17 deletions(-) diff --git a/splink/accuracy.py b/splink/accuracy.py index 2c8e20db43..038daaa880 100644 --- a/splink/accuracy.py +++ b/splink/accuracy.py @@ -1,4 +1,5 @@ -from copy import copy +from copy import deepcopy + from .block_from_labels import block_from_labels from .comparison_vector_values import compute_comparison_vector_values_sql @@ -213,6 +214,41 @@ def roc_table( return df_truth_space_table +def roc_table_from_labels_column( + linker, label_colname, threshold_actual=0.5, match_weight_round_to_nearest=None +): + + new_matchkey = len(linker._settings_obj._blocking_rules_to_generate_predictions) + + df_predict = _predict_from_label_column_sql( + linker, + label_colname, + ) + + sql = f""" + select + cast(({label_colname}_l = {label_colname}_r) as float) as clerical_match_score, + not (cast(match_key as int) = {new_matchkey}) + as found_by_blocking_rules, + * + from {df_predict.physical_name} + """ + + linker._enqueue_sql(sql, "__splink__labels_with_predictions") + + # c_P and c_N are clerical positive and negative, respectively + sqls = truth_space_table_from_labels_with_predictions_sqls( + threshold_actual, match_weight_round_to_nearest + ) + + for sql in sqls: + linker._enqueue_sql(sql["sql"], sql["output_table_name"]) + + df_truth_space_table = linker._execute_sql_pipeline() + + return df_truth_space_table + + def predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename): sqls = block_from_labels(linker, labels_tablename) @@ -280,20 +316,12 @@ def prediction_errors_from_labels_table( return linker._execute_sql_pipeline() -# from splink.linker import Linker +def _predict_from_label_column_sql(linker, label_colname): - -def prediction_errors_from_label_column( - linker, - label_colname, - include_false_positives=True, - include_false_negatives=True, - threshold=0.5, -): # In the case of labels, we use them to block # In the case we have a label column, we want to apply the model's blocking rules # but add in blocking on the label colname - + linker = deepcopy(linker) settings = linker._settings_obj brs = settings._blocking_rules_to_generate_predictions @@ -304,19 +332,38 @@ def prediction_errors_from_label_column( # Need the label colname to be in additional columns to retain add_cols = settings._additional_columns_to_retain_list - add_columns_to_restore = copy(add_cols) + if label_colname not in add_cols: settings._additional_columns_to_retain_list.append(label_colname) # Now we want to create predictions df_predict = linker.predict() + return df_predict + + +def prediction_errors_from_label_column( + linker, + label_colname, + include_false_positives=True, + include_false_negatives=True, + threshold=0.5, +): + + df_predict = _predict_from_label_column_sql( + linker, + label_colname, + ) + # Clerical match score is 1 where the label_colname is equal else zero + # _predict_from_label_column_sql will add a match key for matching on labels + new_matchkey = len(linker._settings_obj._blocking_rules_to_generate_predictions) + sql = f""" select cast(({label_colname}_l = {label_colname}_r) as float) as clerical_match_score, - not (cast(match_key as int) = {label_blocking_rule.match_key}) + not (cast(match_key as int) = {new_matchkey}) 
as found_by_blocking_rules, * from {df_predict.physical_name} @@ -358,8 +405,4 @@ def prediction_errors_from_label_column( predictions = linker._execute_sql_pipeline() - # Remove the blocking rule we added and restore original add cols to ret - brs.pop() - settings._additional_columns_to_retain_list = add_columns_to_restore - return predictions diff --git a/splink/linker.py b/splink/linker.py index 968d83851e..33ce5a4c42 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -46,6 +46,7 @@ roc_table, prediction_errors_from_labels_table, prediction_errors_from_label_column, + roc_table_from_labels_column, ) from .match_weights_histogram import histogram_data @@ -1435,6 +1436,32 @@ def roc_table_from_labels( match_weight_round_to_nearest=match_weight_round_to_nearest, ) + def roc_table_from_labels_column( + self, + labels_column_name, + threshold_actual=0.5, + match_weight_round_to_nearest: float = None, + ): + return roc_table_from_labels_column( + self, labels_column_name, threshold_actual, match_weight_round_to_nearest + ) + + def roc_chart_from_labels_column( + self, + labels_column_name, + threshold_actual=0.5, + match_weight_round_to_nearest: float = None, + ): + + df_truth_space = roc_table_from_labels_column( + self, + labels_column_name, + threshold_actual=threshold_actual, + match_weight_round_to_nearest=match_weight_round_to_nearest, + ) + recs = df_truth_space.as_record_dict() + return roc_chart(recs) + def match_weights_histogram( self, df_predict: SplinkDataFrame, target_bins: int = 30, width=600, height=250 ): From 7f6b180007984ae78928526ddb4ed2b74d30430f Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 19:08:45 +0100 Subject: [PATCH 2/9] works, now need tests --- pyproject.toml | 2 +- splink/__init__.py | 4 +- splink/accuracy.py | 4 +- splink/linker.py | 106 ++++++++++++++++++++++++++------------------- 4 files changed, 65 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c567d35c90..f2badd1f7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "splink" -version = "3.2.1" +version = "3.3.0" description = "Fast probabilistic data linkage at scale" authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis", "Tom Hepworth"] license = "MIT" diff --git a/splink/__init__.py b/splink/__init__.py index f47d6ed274..de8aed319e 100644 --- a/splink/__init__.py +++ b/splink/__init__.py @@ -1,3 +1 @@ -import pkg_resources - -__version__ = pkg_resources.require("splink")[0].version +__version__ = "3.3.0" diff --git a/splink/accuracy.py b/splink/accuracy.py index 038daaa880..d259423461 100644 --- a/splink/accuracy.py +++ b/splink/accuracy.py @@ -185,7 +185,7 @@ def truth_space_table_from_labels_with_predictions_sqls( return sqls -def roc_table( +def truth_space_table_from_labels_table( linker, labels_tablename, threshold_actual=0.5, match_weight_round_to_nearest=None ): @@ -214,7 +214,7 @@ def roc_table( return df_truth_space_table -def roc_table_from_labels_column( +def truth_space_table_from_labels_column( linker, label_colname, threshold_actual=0.5, match_weight_round_to_nearest=None ): diff --git a/splink/linker.py b/splink/linker.py index 33ce5a4c42..9d062885b8 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -43,10 +43,10 @@ from .vertically_concatenate import vertically_concatenate_sql from .m_from_labels import estimate_m_from_pairwise_labels from .accuracy import ( - roc_table, + truth_space_table_from_labels_table, prediction_errors_from_labels_table, 
prediction_errors_from_label_column, - roc_table_from_labels_column, + truth_space_table_from_labels_column, ) from .match_weights_histogram import histogram_data @@ -1273,13 +1273,14 @@ def estimate_m_from_pairwise_labels(self, table_name): self._initialise_df_concat_with_tf(materialise=True) estimate_m_from_pairwise_labels(self, table_name) - def roc_chart_from_labels( + def truth_space_table_from_labels_table( self, labels_tablename, threshold_actual=0.5, match_weight_round_to_nearest: float = None, - ): - """Generate a ROC chart from labelled (ground truth) data. + ) -> SplinkDataFrame: + """Generate truth statistics (false positive etc.) for each threshold value of + match_probability, suitable for plotting a ROC chart. The table of labels should be in the following format, and should be registered with your database: @@ -1310,33 +1311,34 @@ def roc_chart_from_labels( >>> # DuckDBLinker >>> labels = pd.read_csv("my_labels.csv") >>> linker._con.register("labels", labels) - >>> linker.roc_chart_from_labels("labels") + >>> linker.roc_table_from_labels("labels") >>> >>> # SparkLinker >>> labels = spark.read.csv("my_labels.csv", header=True) >>> labels.createDataFrame("labels") - >>> linker.roc_chart_from_labels("labels") - + >>> linker.roc_table_from_labels("labels") Returns: - VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite. - The vegalite spec is available as a dictionary using the `spec` - attribute. + SplinkDataFrame: Table of truth statistics """ - df_truth_space = roc_table( + + return truth_space_table_from_labels_table( self, labels_tablename, threshold_actual=threshold_actual, match_weight_round_to_nearest=match_weight_round_to_nearest, ) - recs = df_truth_space.as_record_dict() - return roc_chart(recs) - def precision_recall_chart_from_labels(self, labels_tablename): - """Generate a precision-recall chart from labelled (ground truth) data. + def roc_chart_from_labels_table( + self, + labels_tablename, + threshold_actual=0.5, + match_weight_round_to_nearest: float = None, + ): + """Generate a ROC chart from labelled (ground truth) data. The table of labels should be in the following format, and should be registered - as a table with your database: + with your database: |source_dataset_l|unique_id_l|source_dataset_r|unique_id_r|clerical_match_score| |----------------|-----------|----------------|-----------|--------------------| @@ -1359,16 +1361,17 @@ def precision_recall_chart_from_labels(self, labels_tablename): are rounded. When large numbers of labels are provided, this is sometimes necessary to reduce the size of the ROC table, and therefore the number of points plotted on the ROC chart. Defaults to None. + Examples: >>> # DuckDBLinker >>> labels = pd.read_csv("my_labels.csv") >>> linker._con.register("labels", labels) - >>> linker.precision_recall_chart_from_labels("labels") + >>> linker.roc_chart_from_labels("labels") >>> >>> # SparkLinker >>> labels = spark.read.csv("my_labels.csv", header=True) >>> labels.createDataFrame("labels") - >>> linker.precision_recall_chart_from_labels("labels") + >>> linker.roc_chart_from_labels("labels") Returns: @@ -1376,21 +1379,20 @@ def precision_recall_chart_from_labels(self, labels_tablename): The vegalite spec is available as a dictionary using the `spec` attribute. 
""" - df_truth_space = roc_table(self, labels_tablename) + df_truth_space = truth_space_table_from_labels_table( + self, + labels_tablename, + threshold_actual=threshold_actual, + match_weight_round_to_nearest=match_weight_round_to_nearest, + ) recs = df_truth_space.as_record_dict() - return precision_recall_chart(recs) + return roc_chart(recs) - def roc_table_from_labels( - self, - labels_tablename, - threshold_actual=0.5, - match_weight_round_to_nearest: float = None, - ) -> SplinkDataFrame: - """Generate truth statistics (false positive etc.) for each threshold value of - match_probability, suitable for plotting a ROC chart. + def precision_recall_chart_from_labels_table(self, labels_tablename): + """Generate a precision-recall chart from labelled (ground truth) data. The table of labels should be in the following format, and should be registered - with your database: + as a table with your database: |source_dataset_l|unique_id_l|source_dataset_r|unique_id_r|clerical_match_score| |----------------|-----------|----------------|-----------|--------------------| @@ -1413,36 +1415,34 @@ def roc_table_from_labels( are rounded. When large numbers of labels are provided, this is sometimes necessary to reduce the size of the ROC table, and therefore the number of points plotted on the ROC chart. Defaults to None. - Examples: >>> # DuckDBLinker >>> labels = pd.read_csv("my_labels.csv") >>> linker._con.register("labels", labels) - >>> linker.roc_table_from_labels("labels") + >>> linker.precision_recall_chart_from_labels("labels") >>> >>> # SparkLinker >>> labels = spark.read.csv("my_labels.csv", header=True) >>> labels.createDataFrame("labels") - >>> linker.roc_table_from_labels("labels") + >>> linker.precision_recall_chart_from_labels("labels") + Returns: - SplinkDataFrame: Table of truth statistics + VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite. + The vegalite spec is available as a dictionary using the `spec` + attribute. 
""" + df_truth_space = truth_space_table_from_labels_table(self, labels_tablename) + recs = df_truth_space.as_record_dict() + return precision_recall_chart(recs) - return roc_table( - self, - labels_tablename, - threshold_actual=threshold_actual, - match_weight_round_to_nearest=match_weight_round_to_nearest, - ) - - def roc_table_from_labels_column( + def truth_space_table_from_labels_column( self, labels_column_name, threshold_actual=0.5, match_weight_round_to_nearest: float = None, ): - return roc_table_from_labels_column( + return truth_space_table_from_labels_column( self, labels_column_name, threshold_actual, match_weight_round_to_nearest ) @@ -1453,7 +1453,7 @@ def roc_chart_from_labels_column( match_weight_round_to_nearest: float = None, ): - df_truth_space = roc_table_from_labels_column( + df_truth_space = truth_space_table_from_labels_column( self, labels_column_name, threshold_actual=threshold_actual, @@ -1462,6 +1462,22 @@ def roc_chart_from_labels_column( recs = df_truth_space.as_record_dict() return roc_chart(recs) + def precision_recall_chart_from_labels_column( + self, + labels_column_name, + threshold_actual=0.5, + match_weight_round_to_nearest: float = None, + ): + + df_truth_space = truth_space_table_from_labels_column( + self, + labels_column_name, + threshold_actual=threshold_actual, + match_weight_round_to_nearest=match_weight_round_to_nearest, + ) + recs = df_truth_space.as_record_dict() + return precision_recall_chart(recs) + def match_weights_histogram( self, df_predict: SplinkDataFrame, target_bins: int = 30, width=600, height=250 ): @@ -2030,7 +2046,7 @@ def prediction_errors_from_labels_table( threshold, ) - def prediction_errors_from_label_column( + def prediction_errors_from_labels_column( self, label_colname, include_false_positives=True, From 9830eb9f1270b198fd87007e2811f72edaf7ac7a Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 19:10:16 +0100 Subject: [PATCH 3/9] reorder functions into logical groups --- splink/linker.py | 118 +++++++++++++++++++++++------------------------ 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/splink/linker.py b/splink/linker.py index 9d062885b8..a8d51f157f 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1436,6 +1436,35 @@ def precision_recall_chart_from_labels_table(self, labels_tablename): recs = df_truth_space.as_record_dict() return precision_recall_chart(recs) + def prediction_errors_from_labels_table( + self, + labels_tablename, + include_false_positives=True, + include_false_negatives=True, + threshold=0.5, + ): + """Generate a dataframe containing false positives and false negatives + based on the comparison between the clerical_match_score in the labels + table compared with the splink predicted match probability + + Args: + labels_tablename (str): Name of labels table + include_false_positives (bool, optional): Defaults to True. + include_false_negatives (bool, optional): Defaults to True. + threshold (float, optional): Threshold above which a score is considered + to be a match. Defaults to 0.5. 
+ + Returns: + SplinkDataFrame: Table containing false positives and negatives + """ + return prediction_errors_from_labels_table( + self, + labels_tablename, + include_false_positives, + include_false_negatives, + threshold, + ) + def truth_space_table_from_labels_column( self, labels_column_name, @@ -1478,6 +1507,36 @@ def precision_recall_chart_from_labels_column( recs = df_truth_space.as_record_dict() return precision_recall_chart(recs) + def prediction_errors_from_labels_column( + self, + label_colname, + include_false_positives=True, + include_false_negatives=True, + threshold=0.5, + ): + """Generate a dataframe containing false positives and false negatives + based on the comparison between the splink match probability and the + labels column. A label column is a column in the input dataset that contains + the 'ground truth' cluster to which the record belongs + + Args: + label_colname (str): Name of labels column in input data + include_false_positives (bool, optional): Defaults to True. + include_false_negatives (bool, optional): Defaults to True. + threshold (float, optional): Threshold above which a score is considered + to be a match. Defaults to 0.5. + + Returns: + SplinkDataFrame: Table containing false positives and negatives + """ + return prediction_errors_from_label_column( + self, + label_colname, + include_false_positives, + include_false_negatives, + threshold, + ) + def match_weights_histogram( self, df_predict: SplinkDataFrame, target_bins: int = 30, width=600, height=250 ): @@ -2016,62 +2075,3 @@ def estimate_probability_two_random_records_match( " possible comparisons, we expect a total of around " f"{prob*cartesian:,.2f} matching pairs" ) - - def prediction_errors_from_labels_table( - self, - labels_tablename, - include_false_positives=True, - include_false_negatives=True, - threshold=0.5, - ): - """Generate a dataframe containing false positives and false negatives - based on the comparison between the clerical_match_score in the labels - table compared with the splink predicted match probability - - Args: - labels_tablename (str): Name of labels table - include_false_positives (bool, optional): Defaults to True. - include_false_negatives (bool, optional): Defaults to True. - threshold (float, optional): Threshold above which a score is considered - to be a match. Defaults to 0.5. - - Returns: - SplinkDataFrame: Table containing false positives and negatives - """ - return prediction_errors_from_labels_table( - self, - labels_tablename, - include_false_positives, - include_false_negatives, - threshold, - ) - - def prediction_errors_from_labels_column( - self, - label_colname, - include_false_positives=True, - include_false_negatives=True, - threshold=0.5, - ): - """Generate a dataframe containing false positives and false negatives - based on the comparison between the splink match probability and the - labels column. A label column is a column in the input dataset that contains - the 'ground truth' cluster to which the record belongs - - Args: - label_colname (str): Name of labels column in input data - include_false_positives (bool, optional): Defaults to True. - include_false_negatives (bool, optional): Defaults to True. - threshold (float, optional): Threshold above which a score is considered - to be a match. Defaults to 0.5. 
- - Returns: - SplinkDataFrame: Table containing false positives and negatives - """ - return prediction_errors_from_label_column( - self, - label_colname, - include_false_positives, - include_false_negatives, - threshold, - ) From 1c1aa59f14ed5ec8789c7b72d42fdf4e8db7c55a Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 19:12:11 +0100 Subject: [PATCH 4/9] existing tests pass --- tests/test_accuracy.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py index c552386092..68f7fa227f 100644 --- a/tests/test_accuracy.py +++ b/tests/test_accuracy.py @@ -207,7 +207,7 @@ def test_roc_chart_dedupe_only(): linker._con.register("labels", df_labels) - linker.roc_chart_from_labels("labels") + linker.roc_chart_from_labels_table("labels") def test_roc_chart_link_and_dedupe(): @@ -240,7 +240,7 @@ def test_roc_chart_link_and_dedupe(): linker._con.register("labels", df_labels) - linker.roc_chart_from_labels("labels") + linker.roc_chart_from_labels_table("labels") def test_prediction_errors_from_labels_table(): @@ -390,7 +390,9 @@ def test_prediction_errors_from_labels_column(): # linker = DuckDBLinker(df, settings) - df_res = linker.prediction_errors_from_label_column("cluster").as_pandas_dataframe() + df_res = linker.prediction_errors_from_labels_column( + "cluster" + ).as_pandas_dataframe() df_res = df_res[["unique_id_l", "unique_id_r"]] records = list(df_res.to_records(index=False)) records = [tuple(p) for p in records] @@ -403,7 +405,7 @@ def test_prediction_errors_from_labels_column(): linker = DuckDBLinker(df, settings) - df_res = linker.prediction_errors_from_label_column( + df_res = linker.prediction_errors_from_labels_column( "cluster", include_false_positives=False ).as_pandas_dataframe() df_res = df_res[["unique_id_l", "unique_id_r"]] @@ -418,7 +420,7 @@ def test_prediction_errors_from_labels_column(): linker = DuckDBLinker(df, settings) - df_res = linker.prediction_errors_from_label_column( + df_res = linker.prediction_errors_from_labels_column( "cluster", include_false_negatives=False ).as_pandas_dataframe() df_res = df_res[["unique_id_l", "unique_id_r"]] From 61b40a638170d329ac70ac826bdd1986204554af Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 19:30:32 +0100 Subject: [PATCH 5/9] update full duckdb test --- tests/test_full_example_duckdb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_full_example_duckdb.py b/tests/test_full_example_duckdb.py index 0397c9e9b7..93c8df4063 100644 --- a/tests/test_full_example_duckdb.py +++ b/tests/test_full_example_duckdb.py @@ -93,7 +93,7 @@ def test_full_example_duckdb(tmp_path): linker._con.register("labels", df_labels) # Finish create labels - linker.roc_chart_from_labels("labels") + linker.roc_chart_from_labels_table("labels") df_clusters = linker.cluster_pairwise_predictions_at_threshold(df_predict, 0.1) From 9fcdac4efd9c7840be38a59dbbf6e34df52ba0cf Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 19:51:18 +0100 Subject: [PATCH 6/9] add tests --- tests/test_accuracy.py | 134 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 1 deletion(-) diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py index 68f7fa227f..2ea33d8095 100644 --- a/tests/test_accuracy.py +++ b/tests/test_accuracy.py @@ -16,7 +16,7 @@ from splink.predict import predict_from_comparison_vectors_sqls -def test_scored_labels(): +def test_scored_labels_table(): df = 
pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") df = df.head(5) @@ -432,3 +432,135 @@ def test_prediction_errors_from_labels_column(): assert (4, 5) in records # FP assert (1, 2) not in records # TP assert (1, 5) not in records # TN + + +def test_truth_space_table_from_labels_column_dedupe_only(): + + data = [ + {"unique_id": 1, "first_name": "john", "cluster": 1}, + {"unique_id": 2, "first_name": "john", "cluster": 1}, + {"unique_id": 3, "first_name": "john", "cluster": 1}, + {"unique_id": 4, "first_name": "john", "cluster": 2}, + {"unique_id": 5, "first_name": "edith", "cluster": 3}, + {"unique_id": 6, "first_name": "mary", "cluster": 3}, + ] + + df = pd.DataFrame(data) + + settings = { + "link_type": "dedupe_only", + "probability_two_random_records_match": 0.5, + "blocking_rules_to_generate_predictions": [ + "1=1", + ], + "comparisons": [ + { + "output_column_name": "First name", + "comparison_levels": [ + { + "sql_condition": "first_name_l IS NULL OR first_name_r IS NULL", + "label_for_charts": "Null", + "is_null_level": True, + }, + { + "sql_condition": "first_name_l = first_name_r", + "label_for_charts": "Exact match", + "m_probability": 0.9, + "u_probability": 0.1, + }, + { + "sql_condition": "ELSE", + "label_for_charts": "All other comparisons", + "m_probability": 0.1, + "u_probability": 0.9, + }, + ], + }, + ], + } + + linker = DuckDBLinker(df, settings) + + tt = linker.truth_space_table_from_labels_column("cluster").as_record_dict() + # Truth threshold -3.17, meaning all comparisons get classified as positive + truth_dict = tt[0] + assert truth_dict["TP"] == 4 + assert truth_dict["FP"] == 11 + assert truth_dict["TN"] == 0 + assert truth_dict["FN"] == 0 + + # Truth threshold 3.17, meaning only comparisons where forename match get classified + # as positive + truth_dict = tt[1] + assert truth_dict["TP"] == 3 + assert truth_dict["FP"] == 3 + assert truth_dict["TN"] == 8 + assert truth_dict["FN"] == 1 + + +def test_truth_space_table_from_labels_column_link_only(): + + data_left = [ + {"unique_id": 1, "first_name": "john", "ground_truth": 1}, + {"unique_id": 2, "first_name": "mary", "ground_truth": 2}, + {"unique_id": 3, "first_name": "edith", "ground_truth": 3}, + ] + + data_right = [ + {"unique_id": 1, "first_name": "john", "ground_truth": 1}, + {"unique_id": 2, "first_name": "john", "ground_truth": 2}, + {"unique_id": 3, "first_name": "eve", "ground_truth": 3}, + ] + + df_left = pd.DataFrame(data_left) + df_right = pd.DataFrame(data_right) + + settings = { + "link_type": "link_only", + "probability_two_random_records_match": 0.5, + "blocking_rules_to_generate_predictions": [ + "1=1", + ], + "comparisons": [ + { + "output_column_name": "First name", + "comparison_levels": [ + { + "sql_condition": "first_name_l IS NULL OR first_name_r IS NULL", + "label_for_charts": "Null", + "is_null_level": True, + }, + { + "sql_condition": "first_name_l = first_name_r", + "label_for_charts": "Exact match", + "m_probability": 0.9, + "u_probability": 0.1, + }, + { + "sql_condition": "ELSE", + "label_for_charts": "All other comparisons", + "m_probability": 0.1, + "u_probability": 0.9, + }, + ], + }, + ], + } + + linker = DuckDBLinker([df_left, df_right], settings) + + tt = linker.truth_space_table_from_labels_column("ground_truth").as_record_dict() + # Truth threshold -3.17, meaning all comparisons get classified as positive + truth_dict = tt[0] + assert truth_dict["TP"] == 3 + assert truth_dict["FP"] == 6 + assert truth_dict["TN"] == 0 + assert truth_dict["FN"] == 0 + + # Truth 
threshold 3.17, meaning only comparisons where forename match get classified + # as positive + truth_dict = tt[1] + assert truth_dict["TP"] == 1 + assert truth_dict["FP"] == 1 + assert truth_dict["TN"] == 5 + assert truth_dict["FN"] == 2 From ce33a1542caca86754cc93b08b6327ad3ca5dd06 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 20:00:56 +0100 Subject: [PATCH 7/9] document new functions --- splink/linker.py | 80 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/splink/linker.py b/splink/linker.py index a8d51f157f..6f0cf796a0 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1311,12 +1311,12 @@ def truth_space_table_from_labels_table( >>> # DuckDBLinker >>> labels = pd.read_csv("my_labels.csv") >>> linker._con.register("labels", labels) - >>> linker.roc_table_from_labels("labels") + >>> linker.truth_space_table_from_labels_table("labels") >>> >>> # SparkLinker >>> labels = spark.read.csv("my_labels.csv", header=True) >>> labels.createDataFrame("labels") - >>> linker.roc_table_from_labels("labels") + >>> linker.truth_space_table_from_labels_table("labels") Returns: SplinkDataFrame: Table of truth statistics @@ -1366,12 +1366,12 @@ def roc_chart_from_labels_table( >>> # DuckDBLinker >>> labels = pd.read_csv("my_labels.csv") >>> linker._con.register("labels", labels) - >>> linker.roc_chart_from_labels("labels") + >>> linker.roc_chart_from_labels_table("labels") >>> >>> # SparkLinker >>> labels = spark.read.csv("my_labels.csv", header=True) >>> labels.createDataFrame("labels") - >>> linker.roc_chart_from_labels("labels") + >>> linker.roc_chart_from_labels_table("labels") Returns: @@ -1419,12 +1419,12 @@ def precision_recall_chart_from_labels_table(self, labels_tablename): >>> # DuckDBLinker >>> labels = pd.read_csv("my_labels.csv") >>> linker._con.register("labels", labels) - >>> linker.precision_recall_chart_from_labels("labels") + >>> linker.precision_recall_chart_from_labels_table("labels") >>> >>> # SparkLinker >>> labels = spark.read.csv("my_labels.csv", header=True) >>> labels.createDataFrame("labels") - >>> linker.precision_recall_chart_from_labels("labels") + >>> linker.precision_recall_chart_from_labels_table("labels") Returns: @@ -1471,6 +1471,29 @@ def truth_space_table_from_labels_column( threshold_actual=0.5, match_weight_round_to_nearest: float = None, ): + """Generate truth statistics (false positive etc.) for each threshold value of + match_probability, suitable for plotting a ROC chart. + + Your labels_column_name should include the ground truth cluster (unique + identifier) that groups entities which are the same. + + Args: + labels_column_name (str): Name of the column containing ground truth labels in the input data + threshold_actual (float, optional): Where the `clerical_match_score` + provided by the user is a probability rather than binary, this value + is used as the threshold to classify `clerical_match_score`s as binary + matches or non matches. Defaults to 0.5. + match_weight_round_to_nearest (float, optional): When provided, thresholds + are rounded. When large numbers of labels are provided, this is + sometimes necessary to reduce the size of the ROC table, and therefore + the number of points plotted on the ROC chart. Defaults to None. 
+ + Examples: + >>> linker.truth_space_table_from_labels_column("cluster") + + Returns: + SplinkDataFrame: Table of truth statistics + """ return truth_space_table_from_labels_column( self, labels_column_name, threshold_actual, match_weight_round_to_nearest ) @@ -1481,6 +1504,29 @@ def roc_chart_from_labels_column( threshold_actual=0.5, match_weight_round_to_nearest: float = None, ): + """Generate a ROC chart from ground truth data, whereby the ground truth + is in a column in the input dataset called `labels_column_name` + + Args: + labels_column_name (str): Column name containing labels in the input table + threshold_actual (float, optional): Where the `clerical_match_score` + provided by the user is a probability rather than binary, this value + is used as the threshold to classify `clerical_match_score`s as binary + matches or non matches. Defaults to 0.5. + match_weight_round_to_nearest (float, optional): When provided, thresholds + are rounded. When large numbers of labels are provided, this is + sometimes necessary to reduce the size of the ROC table, and therefore + the number of points plotted on the ROC chart. Defaults to None. + + Examples: + >>> linker.roc_chart_from_labels_column("labels") + + + Returns: + VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite. + The vegalite spec is available as a dictionary using the `spec` + attribute. + """ df_truth_space = truth_space_table_from_labels_column( self, @@ -1497,6 +1543,28 @@ def precision_recall_chart_from_labels_column( threshold_actual=0.5, match_weight_round_to_nearest: float = None, ): + """Generate a precision-recall chart from ground truth data, whereby the ground + truth is in a column in the input dataset called `labels_column_name` + + Args: + labels_column_name (str): Column name containing labels in the input table + threshold_actual (float, optional): Where the `clerical_match_score` + provided by the user is a probability rather than binary, this value + is used as the threshold to classify `clerical_match_score`s as binary + matches or non matches. Defaults to 0.5. + match_weight_round_to_nearest (float, optional): When provided, thresholds + are rounded. When large numbers of labels are provided, this is + sometimes necessary to reduce the size of the ROC table, and therefore + the number of points plotted on the ROC chart. Defaults to None. + Examples: + >>> linker.precision_recall_chart_from_labels_column("ground_truth") + + + Returns: + VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite. + The vegalite spec is available as a dictionary using the `spec` + attribute. 
+ """ df_truth_space = truth_space_table_from_labels_column( self, From 65586d05a6523ae1f6535d85d79fa65cef75ab0a Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 20:02:57 +0100 Subject: [PATCH 8/9] update docs mds --- docs/linker.md | 9 ++++++--- docs/linkerqa.md | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/docs/linker.md b/docs/linker.md index 5732caefaa..06a0fbf2ad 100644 --- a/docs/linker.md +++ b/docs/linker.md @@ -31,15 +31,18 @@ tags: - match_weights_chart - missingness_chart - parameter_estimate_comparisons_chart - - precision_recall_chart_from_labels + - precision_recall_chart_from_labels_column + - precision_recall_chart_from_labels_table - predict - prediction_errors_from_label_column - prediction_errors_from_labels_table - profile_columns - - roc_chart_from_labels - - roc_table_from_labels + - roc_chart_from_labels_column + - roc_chart_from_labels_table - save_settings_to_json - train_m_from_pairwise_labels + - truth_space_table_from_labels_column + - truth_space_table_from_labels_table - unlinkables_chart - waterfall_chart rendering: diff --git a/docs/linkerqa.md b/docs/linkerqa.md index 39753d9a9a..aa735a9920 100644 --- a/docs/linkerqa.md +++ b/docs/linkerqa.md @@ -15,11 +15,14 @@ tags: - match_weight_histogram - match_weights_chart - parameter_estimate_comparisons_chart - - precision_recall_chart_from_labels + - precision_recall_chart_from_labels_column + - precision_recall_chart_from_labels_table - prediction_errors_from_label_column - prediction_errors_from_labels_table - - roc_chart_from_labels - - roc_table_from_labels + - roc_chart_from_labels_column + - roc_chart_from_labels_table + - truth_space_table_from_labels_column + - truth_space_table_from_labels_table - unlinkables_chart - waterfall_chart rendering: From 135e1e59d26f3e37ce29d6e72f7c305c697961c3 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 15 Sep 2022 20:04:07 +0100 Subject: [PATCH 9/9] 3.3.0.dev01 --- pyproject.toml | 2 +- splink/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f2badd1f7f..722ca1e62c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "splink" -version = "3.3.0" +version = "3.3.0.dev01" description = "Fast probabilistic data linkage at scale" authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis", "Tom Hepworth"] license = "MIT" diff --git a/splink/__init__.py b/splink/__init__.py index de8aed319e..80eb0ad7d7 100644 --- a/splink/__init__.py +++ b/splink/__init__.py @@ -1 +1 @@ -__version__ = "3.3.0" +__version__ = "3.3.0.dev01"