diff --git a/pyproject.toml b/pyproject.toml index f4e2f89ac7..81201520b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "splink" -version = "3.0.0.dev03" +version = "3.0.0.dev04" description = "Implementation of Fellegi-Sunter's canonical model of record linkage in Apache Spark, including EM algorithm to estimate parameters" authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis"] license = "MIT" diff --git a/splink/charts.py b/splink/charts.py index 5a3d2b01df..a29218f391 100644 --- a/splink/charts.py +++ b/splink/charts.py @@ -207,3 +207,12 @@ def match_weight_histogram(records, width=500, height=250, as_dict=False): chart["width"] = width return vegalite_or_json(chart, as_dict=as_dict) + + +def parameter_estimate_comparisons(records, as_dict=False): + chart_path = "parameter_estimate_comparisons.json" + chart = load_chart_definition(chart_path) + + chart["data"]["values"] = records + + return vegalite_or_json(chart, as_dict=as_dict) diff --git a/splink/comparison.py b/splink/comparison.py index 7fc0b0f542..488858447c 100644 --- a/splink/comparison.py +++ b/splink/comparison.py @@ -261,6 +261,17 @@ def as_detailed_records(self): records.append(record) return records + @property + def parameter_estimates_as_records(self): + records = [] + for cl in self.comparison_levels: + new_records = cl.parameter_estimates_as_records + for r in new_records: + r["comparison_name"] = self.comparison_name + records.extend(new_records) + + return records + def get_comparison_level_by_comparison_vector_value(self, value): for cl in self.comparison_levels: diff --git a/splink/comparison_level.py b/splink/comparison_level.py index 3d6ca2aa0a..f37ca58776 100644 --- a/splink/comparison_level.py +++ b/splink/comparison_level.py @@ -166,10 +166,14 @@ def u_probability_description(self): ) def add_trained_u_probability(self, val, desc="no description given"): - self.trained_u_probabilities.append({"u_probability": val, "description": desc}) + self.trained_u_probabilities.append( + {"probability": val, "description": desc, "m_or_u": "u"} + ) def add_trained_m_probability(self, val, desc="no description given"): - self.trained_m_probabilities.append({"m_probability": val, "description": desc}) + self.trained_m_probabilities.append( + {"probability": val, "description": desc, "m_or_u": "m"} + ) @property def u_is_trained(self): @@ -189,11 +193,11 @@ def is_trained(self): @property def trained_m_median(self): - return median([r["m_probability"] for r in self.trained_m_probabilities]) + return median([r["probability"] for r in self.trained_m_probabilities]) @property def trained_u_median(self): - return median([r["u_probability"] for r in self.trained_u_probabilities]) + return median([r["probability"] for r in self.trained_u_probabilities]) @property def bayes_factor(self): @@ -474,6 +478,31 @@ def as_detailed_record(self): return output + @property + def parameter_estimates_as_records(self): + + output_records = [] + + cl_record = self.as_detailed_record + trained_values = self.trained_u_probabilities + self.trained_m_probabilities + for trained_value in trained_values: + record = {} + record["m_or_u"] = trained_value["m_or_u"] + p = trained_value["probability"] + record["estimated_probability"] = p + record["estimate_description"] = trained_value["description"] + if p is not None and p > 0.0: + record["estimated_probability_as_log_odds"] = math.log2(p / (1 - p)) + else: + record["estimated_probability_as_log_odds"] = None + + record["sql_condition"] = cl_record["sql_condition"] + record["comparison_level_label"] = cl_record["label_for_charts"] + record["comparison_vector_value"] = cl_record["comparison_vector_value"] + output_records.append(record) + + return output_records + def validate(self): self._validate_sql() diff --git a/splink/comparison_levels_library.py b/splink/comparison_levels_library.py index c0046bc170..d195768885 100644 --- a/splink/comparison_levels_library.py +++ b/splink/comparison_levels_library.py @@ -9,7 +9,7 @@ def null_level(col_name): def exact_match_level(col_name, m_probability=None, term_frequency_adjustments=False): d = { "sql_condition": f"{col_name}_l = {col_name}_r", - "label_for_charts": "exact_match", + "label_for_charts": "Exact match", } if m_probability: d["m_probability"] = m_probability diff --git a/splink/files/chart_defs/parameter_estimate_comparisons.json b/splink/files/chart_defs/parameter_estimate_comparisons.json new file mode 100644 index 0000000000..4761642c36 --- /dev/null +++ b/splink/files/chart_defs/parameter_estimate_comparisons.json @@ -0,0 +1,64 @@ +{ + "$schema": "https://vega.github.io/schema/vega-lite/v5.2.0.json", + "title": { + "text": "Comparison of parameter estimates across training sessions", + "subtitle": "Use mousewheeel to zoom" + }, + "data": { + "values": [] + }, + "config": { + "view": { "continuousWidth": 400, "continuousHeight": 300 }, + "title": { "anchor": "middle" } + }, + + "mark": { "type": "point", "filled": false, "opacity": 0.7, "size": 100 }, + + "encoding": { + "color": { "type": "nominal", "field": "estimate_description" }, + "row": { + "type": "nominal", + "field": "comparison_name", + "header": { + "labelAlign": "left", + "labelAnchor": "middle", + "labelAngle": 0 + }, + "sort": { "field": "comparison_sort_order" }, + "title": null + }, + "column": { "type": "nominal", "field": "m_or_u", "title": null }, + "shape": { + "type": "nominal", + "field": "estimate_description", + "scale": { "range": ["circle", "square", "triangle", "diamond"] } + }, + "tooltip": [ + { "type": "nominal", "field": "comparison_name" }, + { "type": "nominal", "field": "estimate_description" }, + { "type": "quantitative", "field": "estimated_probability" } + ], + "x": { + "type": "quantitative", + "field": "estimated_probability_as_log_odds" + }, + "y": { + "type": "nominal", + "axis": { "grid": true, "title": null }, + "field": "comparison_level_label", + "sort": { "field": "comparison_vector_value", "order": "descending" } + } + }, + "resolve": { + "scale": { + "y": "independent" + } + }, + "selection": { + "selection_zoom": { + "type": "interval", + "bind": "scales", + "encodings": ["x"] + } + } +} diff --git a/splink/linker.py b/splink/linker.py index 5a442cf590..5f9e8e4140 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -3,7 +3,12 @@ from statistics import median import hashlib -from .charts import match_weight_histogram, precision_recall_chart, roc_chart +from .charts import ( + match_weight_histogram, + precision_recall_chart, + roc_chart, + parameter_estimate_comparisons, +) from .blocking import block_using_rules from .comparison_vector_values import compute_comparison_vector_values @@ -642,3 +647,8 @@ def splink_comparison_viewer( out_path, overwrite, ) + + def parameter_estimate_comparisons(self): + return parameter_estimate_comparisons( + self.settings_obj._parameter_estimates_as_records + ) diff --git a/splink/settings.py b/splink/settings.py index edea419a4b..0210173ef8 100644 --- a/splink/settings.py +++ b/splink/settings.py @@ -238,6 +238,16 @@ def _parameters_as_detailed_records(self): output.extend(records) return output + @property + def _parameter_estimates_as_records(self): + output = [] + for i, cc in enumerate(self.comparisons): + records = cc.parameter_estimates_as_records + for r in records: + r["comparison_sort_order"] = i + output.extend(records) + return output + @property def as_dict(self): current_settings = {