Commit
Merge pull request #1887 from moj-analytical-services/metrics_dataclass
return data class instead of dictionary
zslade authored Jan 29, 2024
2 parents b0c0b00 + a0d91cd commit 47f7d20
Showing 3 changed files with 40 additions and 19 deletions.
19 changes: 19 additions & 0 deletions splink/cluster_metrics.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Dict, List

 from splink.splink_dataframe import SplinkDataFrame
@@ -137,3 +138,21 @@ def _size_density_centralisation_sql(
     sqls.append(sql)

     return sqls
+
+
+@dataclass
+class GraphMetricsResults:
+    nodes: SplinkDataFrame
+    edges: SplinkDataFrame
+    clusters: SplinkDataFrame
+
+    def __repr__(self):
+        msg = (
+            "A data class of Splink dataframes containing metrics for nodes, edges "
+            "and clusters.\n"
+            "\nAccess dataframes via attributes:\n"
+            "`compute_graph_metrics.nodes` for node metrics,\n"
+            "`compute_graph_metrics.edges` for edge metrics, and\n"
+            "`compute_graph_metrics.clusters` for cluster metrics\n"
+        )
+        return msg
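
For orientation, a minimal sketch of how the new return type is consumed. The `linker`, `df_predict` and `df_clustered` objects are assumed to exist from a prior predict/cluster run; they are not part of this diff:

results = linker._compute_graph_metrics(
    df_predict, df_clustered, threshold_match_probability=0.95
)

# Attribute access on the data class replaces the old dictionary
# lookups (results["nodes"], results["clusters"]):
df_node_metrics = results.nodes.as_pandas_dataframe()
df_cluster_metrics = results.clusters.as_pandas_dataframe()

# Printing the result falls back to the custom __repr__ above,
# which lists the available attributes:
print(results)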
26 changes: 13 additions & 13 deletions splink/linker.py
@@ -10,7 +10,6 @@
 from copy import copy, deepcopy
 from pathlib import Path
 from statistics import median
-from typing import Dict

 import sqlglot

@@ -57,6 +56,7 @@
     waterfall_chart,
 )
 from .cluster_metrics import (
+    GraphMetricsResults,
     _node_degree_sql,
     _size_density_centralisation_sql,
 )
@@ -2223,11 +2223,12 @@ def _compute_graph_metrics(
         self,
         df_predict: SplinkDataFrame,
         df_clustered: SplinkDataFrame,
+        *,
         threshold_match_probability: float,
-    ) -> Dict[str, SplinkDataFrame]:
+    ) -> GraphMetricsResults:
         """
-        Generates tables containing graph metrics (for nodes, edges, and clusters),
-        and returns a dictionary of Splink dataframes
+        Generates tables containing graph metrics (for nodes, edges and clusters),
+        and returns a data class of Splink dataframes

         Args:
             df_predict (SplinkDataFrame): The results of `linker.predict()`
@@ -2238,11 +2239,11 @@ def _compute_graph_metrics(
                 above this threshold.

         Returns:
-            dict[str, SplinkDataFrame]: A dictionary of SplinkDataFrames
-                containing cluster IDs and selected cluster, node, or edge metrics
-                key "nodes" for nodes metrics table
-                key "edges" for edge metrics table
-                key "clusters" for cluster metrics table
+            GraphMetricsResults: A data class containing SplinkDataFrames
+                of cluster IDs and selected node, edge or cluster metrics.
+                attribute "nodes" for nodes metrics table
+                attribute "edges" for edge metrics table
+                attribute "clusters" for cluster metrics table
         """

         df_node_metrics = self._compute_metrics_nodes(
@@ -2251,10 +2252,9 @@ def _compute_graph_metrics(
         # don't need edges as information is baked into node metrics
         df_cluster_metrics = self._compute_metrics_clusters(df_node_metrics)

-        return {
-            "nodes": df_node_metrics,
-            "clusters": df_cluster_metrics,
-        }
+        return GraphMetricsResults(
+            nodes=df_node_metrics, edges=None, clusters=df_cluster_metrics
+        )

     def profile_columns(
         self, column_expressions: str | list[str] = None, top_n=10, bottom_n=10
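
An aside on the bare `*` added to the signature above: it makes threshold_match_probability keyword-only, so the old positional call style stops working. A standalone sketch of the behaviour (plain Python, illustrative name, not splink code):

def _compute_graph_metrics_sketch(df_predict, df_clustered, *, threshold_match_probability):
    # everything after the bare * must be passed by keyword
    return threshold_match_probability

_compute_graph_metrics_sketch("p", "c", threshold_match_probability=0.95)  # OK
# _compute_graph_metrics_sketch("p", "c", 0.95)
# TypeError: takes 2 positional arguments but 3 were given

This is why the call in tests/test_cluster_metrics.py below is updated from a positional 0.95 to threshold_match_probability=0.95.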
14 changes: 8 additions & 6 deletions tests/test_cluster_metrics.py
@@ -41,7 +41,7 @@ def test_size_density_dedupe():

     df_result = linker._compute_graph_metrics(
         df_predict, df_clustered, threshold_match_probability=0.9
-    )["clusters"].as_pandas_dataframe()
+    ).clusters.as_pandas_dataframe()
     # not testing this here - it's not relevant for small clusters anyhow
     del df_result["cluster_centralisation"]

@@ -74,8 +74,8 @@ def test_size_density_link():
     df_result = (
         linker._compute_graph_metrics(
             df_predict, df_clustered, threshold_match_probability=0.99
-        )["clusters"]
-        .as_pandas_dataframe()
+        )
+        .clusters.as_pandas_dataframe()
         .sort_values(by="cluster_id")
         .reset_index(drop=True)
     )
@@ -225,8 +225,10 @@ def test_metrics(dialect, test_helpers):
     df_predict = linker.register_table(helper.convert_frame(df_e), "predict")
     df_clustered = linker.register_table(helper.convert_frame(df_c), "clusters")

-    cm = linker._compute_graph_metrics(df_predict, df_clustered, 0.95)
-    df_cm = cm["clusters"].as_pandas_dataframe()
+    cm = linker._compute_graph_metrics(
+        df_predict, df_clustered, threshold_match_probability=0.95
+    )
+    df_cm = cm.clusters.as_pandas_dataframe()

     expected = [
         {"cluster_id": 1, "n_nodes": 4, "n_edges": 4, "cluster_centralisation": 4 / 6},
@@ -260,7 +262,7 @@ def test_metrics(dialect, test_helpers):
             expected_row_details["cluster_centralisation"]
         )

-    df_nm = cm["nodes"].as_pandas_dataframe()
+    df_nm = cm.nodes.as_pandas_dataframe()

     for unique_id, expected_node_degree in expected_node_degrees:
         relevant_row = df_nm[df_nm["composite_unique_id"] == unique_id]
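
As these test migrations show, every cm["..."] lookup becomes attribute access. A standalone sketch (illustrative names, not splink code) of what the data class buys over the old dictionary:

from dataclasses import dataclass

@dataclass
class Results:
    nodes: object
    edges: object
    clusters: object

r = Results(nodes="node_metrics", edges=None, clusters="cluster_metrics")

r.clusters                            # "cluster_metrics"
# r.cluster                           # typo raises AttributeError immediately
# {"clusters": "..."}.get("cluster")  # a dict typo would silently return None

The fields are also discoverable by IDEs and type checkers, which dictionary keys are not.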
