
Merge pull request #520 from allenai/add-ce-loss-metric
Add ce_loss metric and TriviaQA/NaturalQuestions tasks
OyvindTafjord authored Apr 25, 2024
2 parents 3b16e21 + ddb9d04 commit 829f1d6
Showing 3 changed files with 102 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for Grouped Query Attention.
- Added commonsense_qa and social_iqa downstream evaluation tasks
- Added ce_loss metric, with TriviaQA and NaturalQuestions tasks
- Makes it possible to read from http/https the same way we read from s3/r2.
- Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
- Tokenizer patch
100 changes: 96 additions & 4 deletions olmo/eval/downstream.py
@@ -19,7 +19,7 @@ class ICLMetric(Metric):
full_state_update: bool = False

def __init__(self, metric_type="acc") -> None:
"""metric_type: f1, acc, len_norm, pmi_dc"""
"""metric_type: f1, acc, len_norm, pmi_dc, ce_loss"""
super().__init__(sync_on_compute=True)

self.metric_type = metric_type
@@ -65,10 +65,12 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
elif self.metric_type == "acc" or self.metric_type == "f1":
# gather log-probs at continuation token indices
log_likelihood = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
elif self.metric_type == "len_norm":
elif self.metric_type == "len_norm" or self.metric_type == "ce_loss":
log_likelihood = (
torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() / batch["cont_str_len"][idx]
)
if self.metric_type == "ce_loss":
log_likelihood = -log_likelihood
else:
raise ValueError(self.metric_type)
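For ce_loss, the value recorded per document is therefore the negative log-likelihood of the gold continuation, normalized by the character length of the continuation string (batch["cont_str_len"]). A minimal standalone sketch of that computation with made-up values (not code from the repo):

import torch

# Hypothetical per-document inputs: log-softmaxed logits at the continuation
# positions, the gold continuation token ids, and the continuation's character length.
lm_cont_logits = torch.randn(4, 32000).log_softmax(dim=-1)  # (num_cont_tokens, vocab_size)
cont_tokens = torch.tensor([311, 9842, 1022, 13])           # (num_cont_tokens,)
cont_str_len = 16                                            # character length of the continuation (hypothetical)

log_likelihood = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() / cont_str_len
ce_loss = -log_likelihood  # sign flip turns the score into a loss (lower is better)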

@@ -123,8 +125,10 @@ def compute(self) -> torch.Tensor:

if skip_document:
continue

correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
if self.metric_type == "ce_loss":
correct.append(loglikelihoods[0]) # Only one answer is scored
else:
correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)

if self.metric_type == "f1":
assert preds is not None
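Because a ce_loss task exposes exactly one continuation per document, the values collected here are already per-document losses; the final metric is then just their average over the evaluation set (the reduction itself sits outside the hunk shown above). A simplified sketch with hypothetical names, assuming that mean reduction:

# One (already negated, length-normalized) value per document for ce_loss tasks.
ce_losses = [doc_loglikelihoods[0] for doc_loglikelihoods in all_doc_loglikelihoods]
metric = sum(ce_losses) / len(ce_losses)  # averaged across the evaluation set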
@@ -754,6 +758,20 @@ def __init__(self, tokenizer, dataset_path="ai2_arc", dataset_name="ARC-Challeng
)


class ArcEasyCELoss(ArcEasy):
"""ArcEasyCELoss is ARCEasy using an alternate ce_loss metric"""

metric_type = "ce_loss"

def doc_to_continuations(self, doc):
# We only consider the correct answer for this metric. doc_to_label below is fixed at 0
# (the lone continuation's index), so the gold answer is looked up via answerKey here.
answer = doc["choices"]["text"][doc["choices"]["label"].index(doc["answerKey"])]
return [" " + answer]

def doc_to_label(self, doc):
return 0
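
Hand-tracing a made-up ARC-Easy style document through the class above (an illustrative trace, not output from the code):

doc = {
    "question": "Which force pulls objects toward the center of the Earth?",
    "choices": {"text": ["magnetism", "gravity", "friction"], "label": ["A", "B", "C"]},
    "answerKey": "B",
}
# doc["choices"]["label"].index("B") == 1, so the single scored continuation is " gravity".
# doc_to_label() returns 0 because that lone continuation sits at index 0 of the returned list.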


class BasicArithmetic(ArcEasy):
"""This is a basic arithmetic task that follows the same prompt format as ArcEasy.
Example:
@@ -1250,6 +1268,77 @@ def doc_to_domain_conditional(self, doc):
return "Answer:"


class TriviaQACELoss(ICLMultiChoiceTaskDataset):
"""Sample TriviaQA entity with some fields suppressed. For CE Loss we only consider the "value"
field as the answer to score.
{
'question': 'Which Lloyd Webber musical premiered in the US on 10th December 1993?',
'question_id': 'tc_33',
'answer': {
'aliases': ['Sunset Blvd', ...],
'normalized_aliases': ['sunset boulevard', ...],
'normalized_value': 'sunset boulevard',
'value': 'Sunset Boulevard'
}
}
"""

metric_type = "ce_loss"

def __init__(self, tokenizer, dataset_path="trivia_qa", dataset_name="rc.wikipedia.nocontext"):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)

def doc_to_text(self, doc):
return "\nQuestion: " + doc["question"] + "\nAnswer:"

def doc_to_continuations(self, doc):
return [" " + doc["answer"]["value"]]

def doc_to_label(self, doc):
return 0

def doc_to_domain_conditional(self, doc):
del doc
return "Answer:"


class NaturalQuestionsCELoss(ICLMultiChoiceTaskDataset):
"""Sample NaturalQuestions entity. For CE Loss we only consider the first answer entry to score.
{
'question': 'when was the last time anyone was on the moon',
'answer': ['14 December 1972 UTC', 'December 1972']
}
"""

metric_type = "ce_loss"

def __init__(self, tokenizer, dataset_path="nq_open", dataset_name=None):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)

def doc_to_text(self, doc):
return "\nQuestion: " + doc["question"] + "\nAnswer:"

def doc_to_continuations(self, doc):
return [" " + doc["answer"][0]]

def doc_to_label(self, doc):
return 0

def doc_to_domain_conditional(self, doc):
del doc
return "Answer:"


label_to_task_map = {
"piqa": PIQA,
"hellaswag": HellaSwag,
@@ -1258,6 +1347,7 @@ def doc_to_domain_conditional(self, doc):
"boolq": BoolQ,
"sciq": SciQ,
"arc_easy": ArcEasy,
"arc_easy_ppl": ArcEasyCELoss,
"arc_challenge": ArcChallenge,
"basic_arithmetic": BasicArithmetic,
"copa": COPA,
@@ -1267,6 +1357,8 @@ def doc_to_domain_conditional(self, doc):
"sst2": SST2,
"commonsense_qa": CommonsenseQA,
"social_iqa": SocialIQa,
"trivia_qa_wiki_ppl": TriviaQACELoss,
"natural_qs_open_ppl": NaturalQuestionsCELoss,
"mmlu_stem_test": (MMLU, {"dataset_name": "stem", "split": "test"}),
"mmlu_humanities_test": (MMLU, {"dataset_name": "humanities", "split": "test"}),
"mmlu_social_sciences_test": (MMLU, {"dataset_name": "social_sciences", "split": "test"}),
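These labels are what configs and other code use to look tasks up in label_to_task_map. Roughly how an entry is resolved into a task instance (a sketch under the assumption that tuple entries carry extra constructor kwargs, as the MMLU entries above suggest; not code from this diff):

entry = label_to_task_map["trivia_qa_wiki_ppl"]          # TriviaQACELoss
if isinstance(entry, tuple):                             # e.g. the MMLU entries bundle kwargs
    task_class, task_kwargs = entry
else:
    task_class, task_kwargs = entry, {}
task = task_class(tokenizer=tokenizer, **task_kwargs)    # tokenizer assumed to come from the caller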
8 changes: 5 additions & 3 deletions olmo/eval/evaluator.py
@@ -29,9 +29,11 @@ def reset_metrics(self) -> None:
def compute_metrics(self) -> Dict[str, float]:
if self.type == EvaluatorType.downstream:
assert isinstance(self.eval_metric, ICLMetric)
return {
f"eval/downstream/{self.label}_{self.eval_metric.metric_type}": self.eval_metric.compute().item(),
}
value = self.eval_metric.compute().item()
key = f"eval/downstream/{self.label}_{self.eval_metric.metric_type}"
if self.eval_metric.metric_type == "ce_loss":
key = key.replace("/downstream/", "/downstream_ce_loss/")
return {key: value}
elif self.type == EvaluatorType.lm:
# Metric(s) = cross entropy loss
metrics: Dict[str, Metric]
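The net effect is that ce_loss results land in their own namespace instead of mixing with the accuracy-style downstream metrics. The naming rule, restated as a standalone helper (an illustrative restatement, not code from the repo):

def metric_key(label: str, metric_type: str) -> str:
    # ce_loss tasks get the separate eval/downstream_ce_loss/ prefix.
    key = f"eval/downstream/{label}_{metric_type}"
    if metric_type == "ce_loss":
        key = key.replace("/downstream/", "/downstream_ce_loss/")
    return key

print(metric_key("arc_easy", "acc"))                # eval/downstream/arc_easy_acc
print(metric_key("trivia_qa_wiki_ppl", "ce_loss"))  # eval/downstream_ce_loss/trivia_qa_wiki_ppl_ce_loss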
