Skip to content

Commit

Permalink
Merge branch 'raw-extraction' of github.com:EleutherAI/elk into raw-e…
Browse files Browse the repository at this point in the history
…xtraction
  • Loading branch information
AlexTMallen committed Apr 24, 2023
2 parents 6fb6ef7 + 4394bf9 commit 9c0e54e
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions elk/training/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def score_contrast_set(self, labels: Tensor, contrast_set: Tensor) -> EvalResult
cal_err = 0.0

Y_one_hot = to_one_hot(Y, c).long().flatten()
auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten())
auroc_result = (
RocAucResult(-1.0, -1.0, -1.0)
if len(labels.unique()) == 1
else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten())
)

raw_preds = logits.argmax(dim=-1).long()
raw_acc = accuracy(Y, raw_preds.flatten())
Expand Down Expand Up @@ -172,7 +176,11 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
preds = probs.gt(0.5).to(torch.int)
acc = preds.flatten().eq(labels).float().mean().item()

auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels.cpu(), logits.cpu().flatten())
auroc_result = (
RocAucResult(-1.0, -1.0, -1.0)
if len(labels.unique()) == 1
else roc_auc_ci(labels.cpu(), logits.cpu().flatten())
)

return EvalResult(
auroc=auroc_result.estimate,
Expand Down

0 comments on commit 9c0e54e

Please sign in to comment.