Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 18, 2023
1 parent ad4cf34 commit 4394bf9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
12 changes: 10 additions & 2 deletions elk/training/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def score_contrast_set(self, labels: Tensor, contrast_set: Tensor) -> EvalResult
cal_err = 0.0

Y_one_hot = to_one_hot(Y, c).long().flatten()
auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten())
auroc_result = (
RocAucResult(-1.0, -1.0, -1.0)
if len(labels.unique()) == 1
else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten())
)

raw_preds = logits.argmax(dim=-1).long()
raw_acc = accuracy(Y, raw_preds.flatten())
Expand Down Expand Up @@ -172,7 +176,11 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
preds = probs.gt(0.5).to(torch.int)
acc = preds.flatten().eq(labels).float().mean().item()

auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels.cpu(), logits.cpu().flatten())
auroc_result = (
RocAucResult(-1.0, -1.0, -1.0)
if len(labels.unique()) == 1
else roc_auc_ci(labels.cpu(), logits.cpu().flatten())
)

return EvalResult(
auroc=auroc_result.estimate,
Expand Down
6 changes: 5 additions & 1 deletion elk/training/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def evaluate_supervised(
raise ValueError(f"Invalid val_h shape: {val_h.shape}")

lr_acc = accuracy(labels, raw_preds.flatten())
lr_auroc = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels, logits.flatten())
lr_auroc = (
RocAucResult(-1.0, -1.0, -1.0)
if len(labels.unique()) == 1
else roc_auc_ci(labels, logits.flatten())
)

return lr_auroc, assert_type(float, lr_acc)

Expand Down

0 comments on commit 4394bf9

Please sign in to comment.