diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 4b2c7b87..47d7788c 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -72,7 +72,7 @@ def evaluate_reporter(
 
 
 def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
-    if cfg.source == 'zero-shot':
+    if cfg.source == "zero-shot":
         evaluate_zeroshot(cfg, out_dir)
         return
 
@@ -145,12 +145,12 @@ def get_logprobs_tensor(split):
         # returns (num_examples, num_prompts, num_labels)
         logprobs_lst = []
         for ex in tqdm(ds[split]):
-            logprobs_lst.append(ex['logprobs'])
+            logprobs_lst.append(ex["logprobs"])
         logprobs = torch.tensor(logprobs_lst)
         return logprobs
 
-    train_logprobs = get_logprobs_tensor('train')
-    test_logprobs = get_logprobs_tensor('test')
+    train_logprobs = get_logprobs_tensor("train")
+    test_logprobs = get_logprobs_tensor("test")
 
     print(test_logprobs.shape)
 
@@ -174,7 +174,7 @@ def get_logprobs_tensor(split):
     zs_preds_cal = torch.argmax(test_logprobs_cal, dim=-1)
 
     # labels: (num_examples,) -> (num_examples, num_prompts)
-    labels = torch.tensor(ds['test']['label']).unsqueeze(1)
+    labels = torch.tensor(ds["test"]["label"]).unsqueeze(1)
     labels = labels.repeat(1, num_prompts)
 
     # accuracy calculation
@@ -183,16 +183,21 @@ def get_logprobs_tensor(split):
 
     # get aggregate truth probability
    # (1 - e^l_neg) * 0.5 + (e^l_pos) * 0.5
-    truth_prob = ((1 - torch.exp(test_logprobs_cal[:, :, 0])) + torch.exp(train_logprobs[:, :, 1])) * 0.5
+    truth_prob = (
+        (1 - torch.exp(test_logprobs_cal[:, :, 0])) + torch.exp(train_logprobs[:, :, 1])
+    ) * 0.5
 
     # auroc calculation
     # logprobs -> probs
     test_probs = torch.exp(test_logprobs)
     auroc = roc_auc_score(labels.flatten(), truth_prob.flatten())
 
-    print('Raw accuracy: %.4f\nCalibrated accuracy:%.4f\nAUROC: %.4f' % (raw_acc, cal_acc, auroc))
+    print(
+        "Raw accuracy: %.4f\nCalibrated accuracy:%.4f\nAUROC: %.4f"
+        % (raw_acc, cal_acc, auroc)
+    )
 
-    cols = ['acc', 'cal_acc', 'auroc']
+    cols = ["acc", "cal_acc", "auroc"]
 
     with open(out_dir / "eval.csv", "w") as f:
         writer = csv.writer(f)