diff --git a/elk/training/train.py b/elk/training/train.py index c654ca3a..fb882240 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -82,8 +82,7 @@ def apply_to_layer( reporter = CcsReporter(self.net, d, device=device, num_variants=v) train_loss = reporter.fit(first_train_h) - labels = to_one_hot(train_gt, k) - labels = repeat(labels, "n k -> n v k", v=v) + labels = repeat(to_one_hot(train_gt, k), "n k -> n v k", v=v) reporter.platt_scale(labels, first_train_h) elif isinstance(self.net, EigenFitterConfig):