Skip to content

Commit

Permalink
Merge pull request #367 from moj-analytical-services/compare_estimates
Browse files Browse the repository at this point in the history
Compare estimates
  • Loading branch information
RobinL authored Apr 1, 2022
2 parents 4ed739d + 81f09e3 commit 213d32d
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "3.0.0.dev03"
version = "3.0.0.dev04"
description = "Implementation of Fellegi-Sunter's canonical model of record linkage in Apache Spark, including EM algorithm to estimate parameters"
authors = ["Robin Linacre <[email protected]>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
Expand Down
9 changes: 9 additions & 0 deletions splink/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,12 @@ def match_weight_histogram(records, width=500, height=250, as_dict=False):
chart["width"] = width

return vegalite_or_json(chart, as_dict=as_dict)


def parameter_estimate_comparisons(records, as_dict=False):
    """Build the parameter estimate comparison chart from ``records``.

    Loads the ``parameter_estimate_comparisons.json`` chart definition,
    installs ``records`` as the chart's data values, and hands the resulting
    spec to ``vegalite_or_json`` with ``as_dict`` forwarded unchanged.
    """
    spec = load_chart_definition("parameter_estimate_comparisons.json")
    spec["data"]["values"] = records
    return vegalite_or_json(spec, as_dict=as_dict)
11 changes: 11 additions & 0 deletions splink/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,17 @@ def as_detailed_records(self):
records.append(record)
return records

@property
def parameter_estimates_as_records(self):
    """Gather the parameter estimate records of every comparison level.

    Each record produced by a level's ``parameter_estimates_as_records`` is
    tagged in place with a ``comparison_name`` key holding this comparison's
    name, then added to the returned list in level order.
    """
    collected = []
    for level in self.comparison_levels:
        for rec in level.parameter_estimates_as_records:
            rec["comparison_name"] = self.comparison_name
            collected.append(rec)
    return collected

def get_comparison_level_by_comparison_vector_value(self, value):
for cl in self.comparison_levels:

Expand Down
37 changes: 33 additions & 4 deletions splink/comparison_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,14 @@ def u_probability_description(self):
)

def add_trained_u_probability(self, val, desc="no description given"):
self.trained_u_probabilities.append({"u_probability": val, "description": desc})
self.trained_u_probabilities.append(
{"probability": val, "description": desc, "m_or_u": "u"}
)

def add_trained_m_probability(self, val, desc="no description given"):
self.trained_m_probabilities.append({"m_probability": val, "description": desc})
self.trained_m_probabilities.append(
{"probability": val, "description": desc, "m_or_u": "m"}
)

@property
def u_is_trained(self):
Expand All @@ -189,11 +193,11 @@ def is_trained(self):

@property
def trained_m_median(self):
return median([r["m_probability"] for r in self.trained_m_probabilities])
return median([r["probability"] for r in self.trained_m_probabilities])

@property
def trained_u_median(self):
return median([r["u_probability"] for r in self.trained_u_probabilities])
return median([r["probability"] for r in self.trained_u_probabilities])

@property
def bayes_factor(self):
Expand Down Expand Up @@ -474,6 +478,31 @@ def as_detailed_record(self):

return output

@property
def parameter_estimates_as_records(self):
    """Return one record per trained u/m probability estimate on this level.

    The trained u estimates are listed first, followed by the trained m
    estimates. Each record holds:

    - ``m_or_u``, ``estimated_probability`` and ``estimate_description``,
      copied from the stored trained value;
    - ``estimated_probability_as_log_odds``: ``log2(p / (1 - p))``, or
      ``None`` when that quantity is not finite or not defined;
    - ``sql_condition``, ``comparison_level_label`` and
      ``comparison_vector_value``, taken from ``as_detailed_record``.
    """
    cl_record = self.as_detailed_record
    trained_values = self.trained_u_probabilities + self.trained_m_probabilities

    output_records = []
    for trained_value in trained_values:
        p = trained_value["probability"]

        # Log odds only exist strictly inside (0, 1): p == 0 would need
        # log2(0), p == 1 would divide by zero, and anything outside [0, 1]
        # gives a negative ratio. All of these are reported as None rather
        # than raising while building chart data.
        if p is not None and 0.0 < p < 1.0:
            log_odds = math.log2(p / (1 - p))
        else:
            log_odds = None

        output_records.append(
            {
                "m_or_u": trained_value["m_or_u"],
                "estimated_probability": p,
                "estimate_description": trained_value["description"],
                "estimated_probability_as_log_odds": log_odds,
                "sql_condition": cl_record["sql_condition"],
                "comparison_level_label": cl_record["label_for_charts"],
                "comparison_vector_value": cl_record["comparison_vector_value"],
            }
        )

    return output_records

def validate(self):
self._validate_sql()

Expand Down
2 changes: 1 addition & 1 deletion splink/comparison_levels_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def null_level(col_name):
def exact_match_level(col_name, m_probability=None, term_frequency_adjustments=False):
d = {
"sql_condition": f"{col_name}_l = {col_name}_r",
"label_for_charts": "exact_match",
"label_for_charts": "Exact match",
}
if m_probability:
d["m_probability"] = m_probability
Expand Down
64 changes: 64 additions & 0 deletions splink/files/chart_defs/parameter_estimate_comparisons.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"$schema": "https://vega.github.io/schema/vega-lite/v5.2.0.json",
"title": {
"text": "Comparison of parameter estimates across training sessions",
"subtitle": "Use mouse wheel to zoom"
},
"data": {
"values": []
},
"config": {
"view": { "continuousWidth": 400, "continuousHeight": 300 },
"title": { "anchor": "middle" }
},

"mark": { "type": "point", "filled": false, "opacity": 0.7, "size": 100 },

"encoding": {
"color": { "type": "nominal", "field": "estimate_description" },
"row": {
"type": "nominal",
"field": "comparison_name",
"header": {
"labelAlign": "left",
"labelAnchor": "middle",
"labelAngle": 0
},
"sort": { "field": "comparison_sort_order" },
"title": null
},
"column": { "type": "nominal", "field": "m_or_u", "title": null },
"shape": {
"type": "nominal",
"field": "estimate_description",
"scale": { "range": ["circle", "square", "triangle", "diamond"] }
},
"tooltip": [
{ "type": "nominal", "field": "comparison_name" },
{ "type": "nominal", "field": "estimate_description" },
{ "type": "quantitative", "field": "estimated_probability" }
],
"x": {
"type": "quantitative",
"field": "estimated_probability_as_log_odds"
},
"y": {
"type": "nominal",
"axis": { "grid": true, "title": null },
"field": "comparison_level_label",
"sort": { "field": "comparison_vector_value", "order": "descending" }
}
},
"resolve": {
"scale": {
"y": "independent"
}
},
"selection": {
"selection_zoom": {
"type": "interval",
"bind": "scales",
"encodings": ["x"]
}
}
}
12 changes: 11 additions & 1 deletion splink/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from statistics import median
import hashlib

from .charts import match_weight_histogram, precision_recall_chart, roc_chart
from .charts import (
match_weight_histogram,
precision_recall_chart,
roc_chart,
parameter_estimate_comparisons,
)

from .blocking import block_using_rules
from .comparison_vector_values import compute_comparison_vector_values
Expand Down Expand Up @@ -642,3 +647,8 @@ def splink_comparison_viewer(
out_path,
overwrite,
)

def parameter_estimate_comparisons(self):
    """Chart the parameter estimates held by this linker's settings.

    Reads ``_parameter_estimates_as_records`` from ``self.settings_obj`` and
    passes those records to the ``parameter_estimate_comparisons`` chart
    function imported from ``.charts``.
    """
    estimate_records = self.settings_obj._parameter_estimates_as_records
    return parameter_estimate_comparisons(estimate_records)
10 changes: 10 additions & 0 deletions splink/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,16 @@ def _parameters_as_detailed_records(self):
output.extend(records)
return output

@property
def _parameter_estimates_as_records(self):
    """Gather the parameter estimate records of every comparison.

    Each record produced by a comparison's ``parameter_estimates_as_records``
    is tagged in place with ``comparison_sort_order``, the comparison's
    position in ``self.comparisons``, and added to the returned list.
    """
    collected = []
    for position, comparison in enumerate(self.comparisons):
        for rec in comparison.parameter_estimates_as_records:
            rec["comparison_sort_order"] = position
            collected.append(rec)
    return collected

@property
def as_dict(self):
current_settings = {
Expand Down

0 comments on commit 213d32d

Please sign in to comment.