From 1836c48f245de0a3a9802693c4af00a682ca27ac Mon Sep 17 00:00:00 2001 From: binbash Date: Sat, 31 Aug 2024 12:21:28 +0200 Subject: [PATCH] fixed computation of orientability --- code/metrics/accuracies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code/metrics/accuracies.py b/code/metrics/accuracies.py index 3d7e211..3e44571 100644 --- a/code/metrics/accuracies.py +++ b/code/metrics/accuracies.py @@ -9,7 +9,7 @@ def compute_orientability_accuracies( ): benchmarks = [] for metrics_ in metrics: - y_hat_ = torch.sigmoid(y_hat).long() + y_hat_ = torch.sigmoid(y_hat).round().long() metric = metrics_.metric.to(y_hat.device) benchmarks.append( {"name": f"{name}_{metrics_.name}", "value": metric(y_hat_, y)}