diff --git a/ctc_metrics/metrics/technical/tra.py b/ctc_metrics/metrics/technical/tra.py index c9e674c..1d64248 100644 --- a/ctc_metrics/metrics/technical/tra.py +++ b/ctc_metrics/metrics/technical/tra.py @@ -44,5 +44,5 @@ def tra( AOGM_0 = w_fn * num_vertices + w_ea * num_edges # Calculate DET TRA = 1 - min(AOGM, AOGM_0) / AOGM_0 - return float(TRA) + return float(TRA), AOGM, AOGM_0 diff --git a/ctc_metrics/scripts/evaluate.py b/ctc_metrics/scripts/evaluate.py index 0d3c365..04c5ee7 100644 --- a/ctc_metrics/scripts/evaluate.py +++ b/ctc_metrics/scripts/evaluate.py @@ -198,7 +198,12 @@ def calculate_metrics( results["SEG"] = seg(segm["labels_ref"], segm["ious"]) if "TRA" in metrics: - results["TRA"] = tra(**graph_operations) + _tra, _aogm, _aogm0 = tra(**graph_operations) + results["TRA"] = _tra + results["AOGM"] = _aogm + results["AOGM_0"] = _aogm0 + for key in ["NS", "FN", "FP", "ED", "EA", "EC"]: + results[f"AOGM_{key}"] = graph_operations[key] if "LNK" in metrics: results["LNK"] = lnk(**graph_operations) @@ -265,7 +270,6 @@ def calculate_metrics( results.update(faf( traj["labels_comp_merged"], traj["mapped_comp_merged"])) - return results