From 4394bf9ca036f35edd319b8b1fbba85eabeadc87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Apr 2023 05:46:47 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/training/reporter.py | 12 ++++++++++-- elk/training/supervised.py | 6 +++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 60bbcd12..16ce5f35 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -132,7 +132,11 @@ def score_contrast_set(self, labels: Tensor, contrast_set: Tensor) -> EvalResult cal_err = 0.0 Y_one_hot = to_one_hot(Y, c).long().flatten() - auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten()) + auroc_result = ( + RocAucResult(-1.0, -1.0, -1.0) + if len(labels.unique()) == 1 + else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten()) + ) raw_preds = logits.argmax(dim=-1).long() raw_acc = accuracy(Y, raw_preds.flatten()) @@ -172,7 +176,11 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: preds = probs.gt(0.5).to(torch.int) acc = preds.flatten().eq(labels).float().mean().item() - auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels.cpu(), logits.cpu().flatten()) + auroc_result = ( + RocAucResult(-1.0, -1.0, -1.0) + if len(labels.unique()) == 1 + else roc_auc_ci(labels.cpu(), logits.cpu().flatten()) + ) return EvalResult( auroc=auroc_result.estimate, diff --git a/elk/training/supervised.py b/elk/training/supervised.py index ded52fb6..8832613e 100644 --- a/elk/training/supervised.py +++ b/elk/training/supervised.py @@ -34,7 +34,11 @@ def evaluate_supervised( raise ValueError(f"Invalid val_h shape: {val_h.shape}") lr_acc = accuracy(labels, raw_preds.flatten()) - lr_auroc = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels, logits.flatten()) + lr_auroc = ( + RocAucResult(-1.0, -1.0, -1.0) + if len(labels.unique()) == 1 + else roc_auc_ci(labels, logits.flatten()) + ) return lr_auroc, assert_type(float, lr_acc)