From 6ed5eca7eb35967e5117f83179d21ba1e7ba145d Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Mon, 28 Aug 2023 00:02:28 +0200 Subject: [PATCH] remove extra line for label --- elk/training/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/elk/training/train.py b/elk/training/train.py index c654ca3a..fb882240 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -82,8 +82,7 @@ def apply_to_layer( reporter = CcsReporter(self.net, d, device=device, num_variants=v) train_loss = reporter.fit(first_train_h) - labels = to_one_hot(train_gt, k) - labels = repeat(labels, "n k -> n v k", v=v) + labels = repeat(to_one_hot(train_gt, k), "n k -> n v k", v=v) reporter.platt_scale(labels, first_train_h) elif isinstance(self.net, EigenFitterConfig):