Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 17, 2023
1 parent 1302c76 commit 8cd297a
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def evaluate_reporter(


def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
if cfg.source == 'zero-shot':
if cfg.source == "zero-shot":
evaluate_zeroshot(cfg, out_dir)

return
Expand Down Expand Up @@ -145,12 +145,12 @@ def get_logprobs_tensor(split):
# returns (num_examples, num_prompts, num_labels)
logprobs_lst = []
for ex in tqdm(ds[split]):
logprobs_lst.append(ex['logprobs'])
logprobs_lst.append(ex["logprobs"])

Check failure on line 148 in elk/evaluation/evaluate.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Argument of type "Literal['logprobs']" cannot be assigned to parameter "__s" of type "slice" in function "__getitem__"   "Literal['logprobs']" is incompatible with "slice" (reportGeneralTypeIssues)
logprobs = torch.tensor(logprobs_lst)
return logprobs

train_logprobs = get_logprobs_tensor('train')
test_logprobs = get_logprobs_tensor('test')
train_logprobs = get_logprobs_tensor("train")
test_logprobs = get_logprobs_tensor("test")

print(test_logprobs.shape)

Expand All @@ -174,7 +174,7 @@ def get_logprobs_tensor(split):
zs_preds_cal = torch.argmax(test_logprobs_cal, dim=-1)

# labels: (num_examples,) -> (num_examples, num_prompts)
labels = torch.tensor(ds['test']['label']).unsqueeze(1)
labels = torch.tensor(ds["test"]["label"]).unsqueeze(1)
labels = labels.repeat(1, num_prompts)

# accuracy calculation
Expand All @@ -183,16 +183,21 @@ def get_logprobs_tensor(split):

# get aggregate truth probability
# (1 - e^l_neg) * 0.5 + (e^l_pos) * 0.5
truth_prob = ((1 - torch.exp(test_logprobs_cal[:, :, 0])) + torch.exp(train_logprobs[:, :, 1])) * 0.5
truth_prob = (
(1 - torch.exp(test_logprobs_cal[:, :, 0])) + torch.exp(train_logprobs[:, :, 1])
) * 0.5

# auroc calculation
# logprobs -> probs
test_probs = torch.exp(test_logprobs)
auroc = roc_auc_score(labels.flatten(), truth_prob.flatten())

print('Raw accuracy: %.4f\nCalibrated accuracy:%.4f\nAUROC: %.4f' % (raw_acc, cal_acc, auroc))
print(
"Raw accuracy: %.4f\nCalibrated accuracy:%.4f\nAUROC: %.4f"
% (raw_acc, cal_acc, auroc)
)

cols = ['acc', 'cal_acc', 'auroc']
cols = ["acc", "cal_acc", "auroc"]

with open(out_dir / "eval.csv", "w") as f:
writer = csv.writer(f)
Expand Down

0 comments on commit 8cd297a

Please sign in to comment.