diff --git a/docs/linker.md b/docs/linker.md
index 5732caefaa..06a0fbf2ad 100644
--- a/docs/linker.md
+++ b/docs/linker.md
@@ -31,15 +31,18 @@ tags:
   - match_weights_chart
   - missingness_chart
   - parameter_estimate_comparisons_chart
-  - precision_recall_chart_from_labels
+  - precision_recall_chart_from_labels_column
+  - precision_recall_chart_from_labels_table
   - predict
   - prediction_errors_from_label_column
   - prediction_errors_from_labels_table
   - profile_columns
-  - roc_chart_from_labels
-  - roc_table_from_labels
+  - roc_chart_from_labels_column
+  - roc_chart_from_labels_table
   - save_settings_to_json
   - train_m_from_pairwise_labels
+  - truth_space_table_from_labels_column
+  - truth_space_table_from_labels_table
   - unlinkables_chart
   - waterfall_chart
 rendering:
diff --git a/docs/linkerqa.md b/docs/linkerqa.md
index 39753d9a9a..aa735a9920 100644
--- a/docs/linkerqa.md
+++ b/docs/linkerqa.md
@@ -15,11 +15,14 @@ tags:
   - match_weight_histogram
   - match_weights_chart
   - parameter_estimate_comparisons_chart
-  - precision_recall_chart_from_labels
+  - precision_recall_chart_from_labels_column
+  - precision_recall_chart_from_labels_table
   - prediction_errors_from_label_column
   - prediction_errors_from_labels_table
-  - roc_chart_from_labels
-  - roc_table_from_labels
+  - roc_chart_from_labels_column
+  - roc_chart_from_labels_table
+  - truth_space_table_from_labels_column
+  - truth_space_table_from_labels_table
   - unlinkables_chart
   - waterfall_chart
 rendering:
diff --git a/pyproject.toml b/pyproject.toml
index c567d35c90..722ca1e62c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "splink"
-version = "3.2.1"
+version = "3.3.0.dev01"
 description = "Fast probabilistic data linkage at scale"
 authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis", "Tom Hepworth"]
 license = "MIT"
diff --git a/splink/__init__.py b/splink/__init__.py
index f47d6ed274..80eb0ad7d7 100644
--- a/splink/__init__.py
+++ b/splink/__init__.py
@@ -1,3 +1 @@
-import pkg_resources
-
-__version__ = pkg_resources.require("splink")[0].version
+__version__ = "3.3.0.dev01"
diff --git a/splink/accuracy.py b/splink/accuracy.py
index 2c8e20db43..d259423461 100644
--- a/splink/accuracy.py
+++ b/splink/accuracy.py
@@ -1,4 +1,5 @@
-from copy import copy
+from copy import deepcopy
+
 from .block_from_labels import block_from_labels
 from .comparison_vector_values import compute_comparison_vector_values_sql
@@ -184,7 +185,7 @@ def truth_space_table_from_labels_with_predictions_sqls(
     return sqls


-def roc_table(
+def truth_space_table_from_labels_table(
     linker, labels_tablename, threshold_actual=0.5, match_weight_round_to_nearest=None
 ):
@@ -213,6 +214,41 @@
     return df_truth_space_table


+def truth_space_table_from_labels_column(
+    linker, label_colname, threshold_actual=0.5, match_weight_round_to_nearest=None
+):
+
+    new_matchkey = len(linker._settings_obj._blocking_rules_to_generate_predictions)
+
+    df_predict = _predict_from_label_column_sql(
+        linker,
+        label_colname,
+    )
+
+    sql = f"""
+    select
+    cast(({label_colname}_l = {label_colname}_r) as float) as clerical_match_score,
+    not (cast(match_key as int) = {new_matchkey})
+        as found_by_blocking_rules,
+    *
+    from {df_predict.physical_name}
+    """
+
+    linker._enqueue_sql(sql, "__splink__labels_with_predictions")
+
+    # c_P and c_N are clerical positive and negative, respectively
+    sqls = truth_space_table_from_labels_with_predictions_sqls(
+        threshold_actual, match_weight_round_to_nearest
+    )
+
+    for sql in sqls:
+        linker._enqueue_sql(sql["sql"], sql["output_table_name"])
+
+    df_truth_space_table = linker._execute_sql_pipeline()
+
+    return df_truth_space_table
+
+
 def predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename):

     sqls = block_from_labels(linker, labels_tablename)
@@ -280,20 +316,12 @@ def prediction_errors_from_labels_table(
     return linker._execute_sql_pipeline()


-# from splink.linker import Linker
+def _predict_from_label_column_sql(linker, label_colname):

-
-def prediction_errors_from_label_column(
-    linker,
-    label_colname,
-    include_false_positives=True,
-    include_false_negatives=True,
-    threshold=0.5,
-):
     # In the case of labels, we use them to block
     # In the case we have a label column, we want to apply the model's blocking rules
     # but add in blocking on the label colname
-
+    linker = deepcopy(linker)
     settings = linker._settings_obj

     brs = settings._blocking_rules_to_generate_predictions
@@ -304,19 +332,38 @@
     # Need the label colname to be in additional columns to retain
     add_cols = settings._additional_columns_to_retain_list
-    add_columns_to_restore = copy(add_cols)
+
     if label_colname not in add_cols:
         settings._additional_columns_to_retain_list.append(label_colname)

     # Now we want to create predictions
     df_predict = linker.predict()

+    return df_predict
+
+
+def prediction_errors_from_label_column(
+    linker,
+    label_colname,
+    include_false_positives=True,
+    include_false_negatives=True,
+    threshold=0.5,
+):
+
+    df_predict = _predict_from_label_column_sql(
+        linker,
+        label_colname,
+    )
+
     # Clerical match score is 1 where the label_colname is equal else zero
+    # _predict_from_label_column_sql will add a match key for matching on labels
+    new_matchkey = len(linker._settings_obj._blocking_rules_to_generate_predictions)
+
     sql = f"""
     select
     cast(({label_colname}_l = {label_colname}_r) as float) as clerical_match_score,
-    not (cast(match_key as int) = {label_blocking_rule.match_key})
+    not (cast(match_key as int) = {new_matchkey})
         as found_by_blocking_rules,
     *
     from {df_predict.physical_name}
@@ -358,8 +405,4 @@
     predictions = linker._execute_sql_pipeline()

-    # Remove the blocking rule we added and restore original add cols to ret
-    brs.pop()
-    settings._additional_columns_to_retain_list = add_columns_to_restore
-
     return predictions
diff --git a/splink/linker.py b/splink/linker.py
index 968d83851e..6f0cf796a0 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -43,9 +43,10 @@ from .vertically_concatenate import vertically_concatenate_sql
 from .m_from_labels import estimate_m_from_pairwise_labels
 from .accuracy import (
-    roc_table,
+    truth_space_table_from_labels_table,
     prediction_errors_from_labels_table,
     prediction_errors_from_label_column,
+    truth_space_table_from_labels_column,
 )
 from .match_weights_histogram import histogram_data
@@ -1272,13 +1273,14 @@ def estimate_m_from_pairwise_labels(self, table_name):
         self._initialise_df_concat_with_tf(materialise=True)
         estimate_m_from_pairwise_labels(self, table_name)

-    def roc_chart_from_labels(
+    def truth_space_table_from_labels_table(
         self,
         labels_tablename,
         threshold_actual=0.5,
         match_weight_round_to_nearest: float = None,
-    ):
-        """Generate a ROC chart from labelled (ground truth) data.
+    ) -> SplinkDataFrame:
+        """Generate truth statistics (false positive etc.) for each threshold value of
+        match_probability, suitable for plotting a ROC chart.

         The table of labels should be in the following format, and should be registered
         with your database:

         |source_dataset_l|unique_id_l|source_dataset_r|unique_id_r|clerical_match_score|
         |----------------|-----------|----------------|-----------|--------------------|
@@ -1309,33 +1311,34 @@ def roc_chart_from_labels(
             >>> # DuckDBLinker
             >>> labels = pd.read_csv("my_labels.csv")
             >>> linker._con.register("labels", labels)
-            >>> linker.roc_chart_from_labels("labels")
+            >>> linker.truth_space_table_from_labels_table("labels")
             >>>
             >>> # SparkLinker
             >>> labels = spark.read.csv("my_labels.csv", header=True)
             >>> labels.createDataFrame("labels")
-            >>> linker.roc_chart_from_labels("labels")
-
+            >>> linker.truth_space_table_from_labels_table("labels")

         Returns:
-            VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite.
-                The vegalite spec is available as a dictionary using the `spec`
-                attribute.
+            SplinkDataFrame: Table of truth statistics
         """
-        df_truth_space = roc_table(
+
+        return truth_space_table_from_labels_table(
             self,
             labels_tablename,
             threshold_actual=threshold_actual,
             match_weight_round_to_nearest=match_weight_round_to_nearest,
         )
-        recs = df_truth_space.as_record_dict()
-        return roc_chart(recs)

-    def precision_recall_chart_from_labels(self, labels_tablename):
-        """Generate a precision-recall chart from labelled (ground truth) data.
+    def roc_chart_from_labels_table(
+        self,
+        labels_tablename,
+        threshold_actual=0.5,
+        match_weight_round_to_nearest: float = None,
+    ):
+        """Generate a ROC chart from labelled (ground truth) data.

         The table of labels should be in the following format, and should be registered
-        as a table with your database:
+        with your database:

         |source_dataset_l|unique_id_l|source_dataset_r|unique_id_r|clerical_match_score|
         |----------------|-----------|----------------|-----------|--------------------|
@@ -1358,16 +1361,17 @@ def precision_recall_chart_from_labels(self, labels_tablename):
                 are rounded. When large numbers of labels are provided, this is
                 sometimes necessary to reduce the size of the ROC table, and therefore
                 the number of points plotted on the ROC chart. Defaults to None.
+
         Examples:
             >>> # DuckDBLinker
             >>> labels = pd.read_csv("my_labels.csv")
             >>> linker._con.register("labels", labels)
-            >>> linker.precision_recall_chart_from_labels("labels")
+            >>> linker.roc_chart_from_labels_table("labels")
             >>>
             >>> # SparkLinker
             >>> labels = spark.read.csv("my_labels.csv", header=True)
             >>> labels.createDataFrame("labels")
-            >>> linker.precision_recall_chart_from_labels("labels")
+            >>> linker.roc_chart_from_labels_table("labels")


         Returns:
@@ -1375,21 +1379,20 @@ def precision_recall_chart_from_labels(self, labels_tablename):
                 The vegalite spec is available as a dictionary using the `spec`
                 attribute.
         """
-        df_truth_space = roc_table(self, labels_tablename)
+        df_truth_space = truth_space_table_from_labels_table(
+            self,
+            labels_tablename,
+            threshold_actual=threshold_actual,
+            match_weight_round_to_nearest=match_weight_round_to_nearest,
+        )
         recs = df_truth_space.as_record_dict()
-        return precision_recall_chart(recs)
+        return roc_chart(recs)

-    def roc_table_from_labels(
-        self,
-        labels_tablename,
-        threshold_actual=0.5,
-        match_weight_round_to_nearest: float = None,
-    ) -> SplinkDataFrame:
-        """Generate truth statistics (false positive etc.) for each threshold value of
-        match_probability, suitable for plotting a ROC chart.
+    def precision_recall_chart_from_labels_table(self, labels_tablename):
+        """Generate a precision-recall chart from labelled (ground truth) data.

         The table of labels should be in the following format, and should be registered
-        with your database:
+        as a table with your database:

         |source_dataset_l|unique_id_l|source_dataset_r|unique_id_r|clerical_match_score|
         |----------------|-----------|----------------|-----------|--------------------|
@@ -1412,28 +1415,195 @@
                 are rounded. When large numbers of labels are provided, this is
                 sometimes necessary to reduce the size of the ROC table, and therefore
                 the number of points plotted on the ROC chart. Defaults to None.
-
         Examples:
             >>> # DuckDBLinker
             >>> labels = pd.read_csv("my_labels.csv")
             >>> linker._con.register("labels", labels)
-            >>> linker.roc_table_from_labels("labels")
+            >>> linker.precision_recall_chart_from_labels_table("labels")
             >>>
             >>> # SparkLinker
             >>> labels = spark.read.csv("my_labels.csv", header=True)
             >>> labels.createDataFrame("labels")
-            >>> linker.roc_table_from_labels("labels")
+            >>> linker.precision_recall_chart_from_labels_table("labels")
+
         Returns:
-            SplinkDataFrame: Table of truth statistics
+            VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite.
+                The vegalite spec is available as a dictionary using the `spec`
+                attribute.
         """
+        df_truth_space = truth_space_table_from_labels_table(self, labels_tablename)
+        recs = df_truth_space.as_record_dict()
+        return precision_recall_chart(recs)

-        return roc_table(
+    def prediction_errors_from_labels_table(
+        self,
+        labels_tablename,
+        include_false_positives=True,
+        include_false_negatives=True,
+        threshold=0.5,
+    ):
+        """Generate a dataframe containing false positives and false negatives
+        based on the comparison between the clerical_match_score in the labels
+        table compared with the splink predicted match probability
+
+        Args:
+            labels_tablename (str): Name of labels table
+            include_false_positives (bool, optional): Defaults to True.
+            include_false_negatives (bool, optional): Defaults to True.
+            threshold (float, optional): Threshold above which a score is considered
+                to be a match. Defaults to 0.5.
+
+        Returns:
+            SplinkDataFrame: Table containing false positives and negatives
+        """
+        return prediction_errors_from_labels_table(
             self,
             labels_tablename,
+            include_false_positives,
+            include_false_negatives,
+            threshold,
+        )
+
+    def truth_space_table_from_labels_column(
+        self,
+        labels_column_name,
+        threshold_actual=0.5,
+        match_weight_round_to_nearest: float = None,
+    ):
+        """Generate truth statistics (false positive etc.) for each threshold value of
+        match_probability, suitable for plotting a ROC chart.
+
+        Your labels_column_name should include the ground truth cluster (unique
+        identifier) that groups entities which are the same
+
+        Args:
+            labels_column_name (str): Column name containing labels in the input table
+            threshold_actual (float, optional): Where the `clerical_match_score`
+                provided by the user is a probability rather than binary, this value
+                is used as the threshold to classify `clerical_match_score`s as binary
+                matches or non matches. Defaults to 0.5.
+            match_weight_round_to_nearest (float, optional): When provided, thresholds
+                are rounded. When large numbers of labels are provided, this is
+                sometimes necessary to reduce the size of the ROC table, and therefore
+                the number of points plotted on the ROC chart. Defaults to None.
+
+        Examples:
+            >>> linker.truth_space_table_from_labels_column("cluster")
+
+        Returns:
+            SplinkDataFrame: Table of truth statistics
+        """
+        return truth_space_table_from_labels_column(
+            self, labels_column_name, threshold_actual, match_weight_round_to_nearest
+        )
+
+    def roc_chart_from_labels_column(
+        self,
+        labels_column_name,
+        threshold_actual=0.5,
+        match_weight_round_to_nearest: float = None,
+    ):
+        """Generate a ROC chart from ground truth data, whereby the ground truth
+        is in a column in the input dataset called `labels_column_name`
+
+        Args:
+            labels_column_name (str): Column name containing labels in the input table
+            threshold_actual (float, optional): Where the `clerical_match_score`
+                provided by the user is a probability rather than binary, this value
+                is used as the threshold to classify `clerical_match_score`s as binary
+                matches or non matches. Defaults to 0.5.
+            match_weight_round_to_nearest (float, optional): When provided, thresholds
+                are rounded. When large numbers of labels are provided, this is
+                sometimes necessary to reduce the size of the ROC table, and therefore
+                the number of points plotted on the ROC chart. Defaults to None.
+
+        Examples:
+            >>> linker.roc_chart_from_labels_column("labels")
+
+
+        Returns:
+            VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite.
+                The vegalite spec is available as a dictionary using the `spec`
+                attribute.
+        """
+
+        df_truth_space = truth_space_table_from_labels_column(
+            self,
+            labels_column_name,
             threshold_actual=threshold_actual,
             match_weight_round_to_nearest=match_weight_round_to_nearest,
         )
+        recs = df_truth_space.as_record_dict()
+        return roc_chart(recs)
+
+    def precision_recall_chart_from_labels_column(
+        self,
+        labels_column_name,
+        threshold_actual=0.5,
+        match_weight_round_to_nearest: float = None,
+    ):
+        """Generate a precision-recall chart from ground truth data, whereby the ground
+        truth is in a column in the input dataset called `labels_column_name`
+
+        Args:
+            labels_column_name (str): Column name containing labels in the input table
+            threshold_actual (float, optional): Where the `clerical_match_score`
+                provided by the user is a probability rather than binary, this value
+                is used as the threshold to classify `clerical_match_score`s as binary
+                matches or non matches. Defaults to 0.5.
+            match_weight_round_to_nearest (float, optional): When provided, thresholds
+                are rounded. When large numbers of labels are provided, this is
+                sometimes necessary to reduce the size of the ROC table, and therefore
+                the number of points plotted on the ROC chart. Defaults to None.
+        Examples:
+            >>> linker.precision_recall_chart_from_labels_column("ground_truth")
+
+
+        Returns:
+            VegaLite: A VegaLite chart object. See altair.vegalite.v4.display.VegaLite.
+                The vegalite spec is available as a dictionary using the `spec`
+                attribute.
+        """
+
+        df_truth_space = truth_space_table_from_labels_column(
+            self,
+            labels_column_name,
+            threshold_actual=threshold_actual,
+            match_weight_round_to_nearest=match_weight_round_to_nearest,
+        )
+        recs = df_truth_space.as_record_dict()
+        return precision_recall_chart(recs)
+
+    def prediction_errors_from_labels_column(
+        self,
+        label_colname,
+        include_false_positives=True,
+        include_false_negatives=True,
+        threshold=0.5,
+    ):
+        """Generate a dataframe containing false positives and false negatives
+        based on the comparison between the splink match probability and the
+        labels column. A label column is a column in the input dataset that contains
+        the 'ground truth' cluster to which the record belongs
+
+        Args:
+            label_colname (str): Name of labels column in input data
+            include_false_positives (bool, optional): Defaults to True.
+            include_false_negatives (bool, optional): Defaults to True.
+            threshold (float, optional): Threshold above which a score is considered
+                to be a match. Defaults to 0.5.
+
+        Returns:
+            SplinkDataFrame: Table containing false positives and negatives
+        """
+        return prediction_errors_from_label_column(
+            self,
+            label_colname,
+            include_false_positives,
+            include_false_negatives,
+            threshold,
+        )

     def match_weights_histogram(
         self, df_predict: SplinkDataFrame, target_bins: int = 30, width=600, height=250
@@ -1973,62 +2143,3 @@ estimate_probability_two_random_records_match(
             " possible comparisons, we expect a total of around "
             f"{prob*cartesian:,.2f} matching pairs"
         )
-
-    def prediction_errors_from_labels_table(
-        self,
-        labels_tablename,
-        include_false_positives=True,
-        include_false_negatives=True,
-        threshold=0.5,
-    ):
-        """Generate a dataframe containing false positives and false negatives
-        based on the comparison between the clerical_match_score in the labels
-        table compared with the splink predicted match probability
-
-        Args:
-            labels_tablename (str): Name of labels table
-            include_false_positives (bool, optional): Defaults to True.
-            include_false_negatives (bool, optional): Defaults to True.
-            threshold (float, optional): Threshold above which a score is considered
-                to be a match. Defaults to 0.5.
-
-        Returns:
-            SplinkDataFrame: Table containing false positives and negatives
-        """
-        return prediction_errors_from_labels_table(
-            self,
-            labels_tablename,
-            include_false_positives,
-            include_false_negatives,
-            threshold,
-        )
-
-    def prediction_errors_from_label_column(
-        self,
-        label_colname,
-        include_false_positives=True,
-        include_false_negatives=True,
-        threshold=0.5,
-    ):
-        """Generate a dataframe containing false positives and false negatives
-        based on the comparison between the splink match probability and the
-        labels column. A label column is a column in the input dataset that contains
-        the 'ground truth' cluster to which the record belongs
-
-        Args:
-            label_colname (str): Name of labels column in input data
-            include_false_positives (bool, optional): Defaults to True.
-            include_false_negatives (bool, optional): Defaults to True.
-            threshold (float, optional): Threshold above which a score is considered
-                to be a match. Defaults to 0.5.
-
-        Returns:
-            SplinkDataFrame: Table containing false positives and negatives
-        """
-        return prediction_errors_from_label_column(
-            self,
-            label_colname,
-            include_false_positives,
-            include_false_negatives,
-            threshold,
-        )
diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py
index c552386092..2ea33d8095 100644
--- a/tests/test_accuracy.py
+++ b/tests/test_accuracy.py
@@ -16,7 +16,7 @@ from splink.predict import predict_from_comparison_vectors_sqls


-def test_scored_labels():
+def test_scored_labels_table():
     df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
     df = df.head(5)

@@ -207,7 +207,7 @@ def test_roc_chart_dedupe_only():

     linker._con.register("labels", df_labels)

-    linker.roc_chart_from_labels("labels")
+    linker.roc_chart_from_labels_table("labels")


 def test_roc_chart_link_and_dedupe():
@@ -240,7 +240,7 @@ def test_roc_chart_link_and_dedupe():

     linker._con.register("labels", df_labels)

-    linker.roc_chart_from_labels("labels")
+    linker.roc_chart_from_labels_table("labels")


 def test_prediction_errors_from_labels_table():
@@ -390,7 +390,9 @@ def test_prediction_errors_from_labels_column():

     # linker = DuckDBLinker(df, settings)

-    df_res = linker.prediction_errors_from_label_column("cluster").as_pandas_dataframe()
+    df_res = linker.prediction_errors_from_labels_column(
+        "cluster"
+    ).as_pandas_dataframe()
     df_res = df_res[["unique_id_l", "unique_id_r"]]
     records = list(df_res.to_records(index=False))
     records = [tuple(p) for p in records]
@@ -403,7 +405,7 @@ def test_prediction_errors_from_labels_column():

     linker = DuckDBLinker(df, settings)

-    df_res = linker.prediction_errors_from_label_column(
+    df_res = linker.prediction_errors_from_labels_column(
         "cluster", include_false_positives=False
     ).as_pandas_dataframe()
     df_res = df_res[["unique_id_l", "unique_id_r"]]
@@ -418,7 +420,7 @@ def test_prediction_errors_from_labels_column():

     linker = DuckDBLinker(df, settings)

-    df_res = linker.prediction_errors_from_label_column(
+    df_res = linker.prediction_errors_from_labels_column(
         "cluster", include_false_negatives=False
     ).as_pandas_dataframe()
     df_res = df_res[["unique_id_l", "unique_id_r"]]
@@ -430,3 +432,135 @@ def test_prediction_errors_from_labels_column():
     assert (4, 5) in records  # FP
     assert (1, 2) not in records  # TP
     assert (1, 5) not in records  # TN
+
+
+def test_truth_space_table_from_labels_column_dedupe_only():
+
+    data = [
+        {"unique_id": 1, "first_name": "john", "cluster": 1},
+        {"unique_id": 2, "first_name": "john", "cluster": 1},
+        {"unique_id": 3, "first_name": "john", "cluster": 1},
+        {"unique_id": 4, "first_name": "john", "cluster": 2},
+        {"unique_id": 5, "first_name": "edith", "cluster": 3},
+        {"unique_id": 6, "first_name": "mary", "cluster": 3},
+    ]
+
+    df = pd.DataFrame(data)
+
+    settings = {
+        "link_type": "dedupe_only",
+        "probability_two_random_records_match": 0.5,
+        "blocking_rules_to_generate_predictions": [
+            "1=1",
+        ],
+        "comparisons": [
+            {
+                "output_column_name": "First name",
+                "comparison_levels": [
+                    {
+                        "sql_condition": "first_name_l IS NULL OR first_name_r IS NULL",
+                        "label_for_charts": "Null",
+                        "is_null_level": True,
+                    },
+                    {
+                        "sql_condition": "first_name_l = first_name_r",
+                        "label_for_charts": "Exact match",
+                        "m_probability": 0.9,
+                        "u_probability": 0.1,
+                    },
+                    {
+                        "sql_condition": "ELSE",
+                        "label_for_charts": "All other comparisons",
+                        "m_probability": 0.1,
+                        "u_probability": 0.9,
+                    },
+                ],
+            },
+        ],
+    }
+
+    linker = DuckDBLinker(df, settings)
+
+    tt = linker.truth_space_table_from_labels_column("cluster").as_record_dict()
+    # Truth threshold -3.17, meaning all comparisons get classified as positive
+    truth_dict = tt[0]
+    assert truth_dict["TP"] == 4
+    assert truth_dict["FP"] == 11
+    assert truth_dict["TN"] == 0
+    assert truth_dict["FN"] == 0
+
+    # Truth threshold 3.17, meaning only comparisons where forename match get classified
+    # as positive
+    truth_dict = tt[1]
+    assert truth_dict["TP"] == 3
+    assert truth_dict["FP"] == 3
+    assert truth_dict["TN"] == 8
+    assert truth_dict["FN"] == 1
+
+
+def test_truth_space_table_from_labels_column_link_only():
+
+    data_left = [
+        {"unique_id": 1, "first_name": "john", "ground_truth": 1},
+        {"unique_id": 2, "first_name": "mary", "ground_truth": 2},
+        {"unique_id": 3, "first_name": "edith", "ground_truth": 3},
+    ]
+
+    data_right = [
+        {"unique_id": 1, "first_name": "john", "ground_truth": 1},
+        {"unique_id": 2, "first_name": "john", "ground_truth": 2},
+        {"unique_id": 3, "first_name": "eve", "ground_truth": 3},
+    ]
+
+    df_left = pd.DataFrame(data_left)
+    df_right = pd.DataFrame(data_right)
+
+    settings = {
+        "link_type": "link_only",
+        "probability_two_random_records_match": 0.5,
+        "blocking_rules_to_generate_predictions": [
+            "1=1",
+        ],
+        "comparisons": [
+            {
+                "output_column_name": "First name",
+                "comparison_levels": [
+                    {
+                        "sql_condition": "first_name_l IS NULL OR first_name_r IS NULL",
+                        "label_for_charts": "Null",
+                        "is_null_level": True,
+                    },
+                    {
+                        "sql_condition": "first_name_l = first_name_r",
+                        "label_for_charts": "Exact match",
+                        "m_probability": 0.9,
+                        "u_probability": 0.1,
+                    },
+                    {
+                        "sql_condition": "ELSE",
+                        "label_for_charts": "All other comparisons",
+                        "m_probability": 0.1,
+                        "u_probability": 0.9,
+                    },
+                ],
+            },
+        ],
+    }
+
+    linker = DuckDBLinker([df_left, df_right], settings)
+
+    tt = linker.truth_space_table_from_labels_column("ground_truth").as_record_dict()
+    # Truth threshold -3.17, meaning all comparisons get classified as positive
+    truth_dict = tt[0]
+    assert truth_dict["TP"] == 3
+    assert truth_dict["FP"] == 6
+    assert truth_dict["TN"] == 0
+    assert truth_dict["FN"] == 0
+
+    # Truth threshold 3.17, meaning only comparisons where forename match get classified
+    # as positive
+    truth_dict = tt[1]
+    assert truth_dict["TP"] == 1
+    assert truth_dict["FP"] == 1
+    assert truth_dict["TN"] == 5
+    assert truth_dict["FN"] == 2
diff --git a/tests/test_full_example_duckdb.py b/tests/test_full_example_duckdb.py
index 0397c9e9b7..93c8df4063 100644
--- a/tests/test_full_example_duckdb.py
+++ b/tests/test_full_example_duckdb.py
@@ -93,7 +93,7 @@ def test_full_example_duckdb(tmp_path):
     linker._con.register("labels", df_labels)
     # Finish create labels

-    linker.roc_chart_from_labels("labels")
+    linker.roc_chart_from_labels_table("labels")

     df_clusters = linker.cluster_pairwise_predictions_at_threshold(df_predict, 0.1)
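
For reference, a minimal usage sketch of the labels-column evaluation API added in this patch. The data, settings and method calls are lifted directly from the tests and docstrings in the diff above; the DuckDBLinker import path is an assumption (it does not appear in the diff) and should be checked against the installed splink version.

import pandas as pd

# Assumed import path for splink 3.x; not shown in the diff itself
from splink.duckdb.duckdb_linker import DuckDBLinker

# Tiny dedupe example, copied from test_truth_space_table_from_labels_column_dedupe_only
data = [
    {"unique_id": 1, "first_name": "john", "cluster": 1},
    {"unique_id": 2, "first_name": "john", "cluster": 1},
    {"unique_id": 3, "first_name": "john", "cluster": 1},
    {"unique_id": 4, "first_name": "john", "cluster": 2},
    {"unique_id": 5, "first_name": "edith", "cluster": 3},
    {"unique_id": 6, "first_name": "mary", "cluster": 3},
]
settings = {
    "link_type": "dedupe_only",
    "probability_two_random_records_match": 0.5,
    "blocking_rules_to_generate_predictions": ["1=1"],
    "comparisons": [
        {
            "output_column_name": "First name",
            "comparison_levels": [
                {
                    "sql_condition": "first_name_l IS NULL OR first_name_r IS NULL",
                    "label_for_charts": "Null",
                    "is_null_level": True,
                },
                {
                    "sql_condition": "first_name_l = first_name_r",
                    "label_for_charts": "Exact match",
                    "m_probability": 0.9,
                    "u_probability": 0.1,
                },
                {
                    "sql_condition": "ELSE",
                    "label_for_charts": "All other comparisons",
                    "m_probability": 0.1,
                    "u_probability": 0.9,
                },
            ],
        },
    ],
}

linker = DuckDBLinker(pd.DataFrame(data), settings)

# The ground-truth cluster lives in the "cluster" column of the input data
truth_space = linker.truth_space_table_from_labels_column("cluster")
linker.roc_chart_from_labels_column("cluster")
errors = linker.prediction_errors_from_labels_column("cluster").as_pandas_dataframe()

truth_space is a SplinkDataFrame whose records (TP/FP/TN/FN per match-weight threshold) are exactly what the new tests assert on.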