diff --git a/code/metrics/accuracies.py b/code/metrics/accuracies.py index 3d7e211..3e44571 100644 --- a/code/metrics/accuracies.py +++ b/code/metrics/accuracies.py @@ -9,7 +9,7 @@ def compute_orientability_accuracies( ): benchmarks = [] for metrics_ in metrics: - y_hat_ = torch.sigmoid(y_hat).long() + y_hat_ = torch.sigmoid(y_hat).round().long() metric = metrics_.metric.to(y_hat.device) benchmarks.append( {"name": f"{name}_{metrics_.name}", "value": metric(y_hat_, y)}