From 9255ea67bea46538f51f744680d704d88acb4b56 Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 17:27:17 -0700 Subject: [PATCH 1/8] Add ce_loss metric type --- olmo/eval/downstream.py | 24 ++++++++++++++++++++---- olmo/eval/evaluator.py | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index b81f7927a..c24daedbc 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -19,7 +19,7 @@ class ICLMetric(Metric): full_state_update: bool = False def __init__(self, metric_type="acc") -> None: - """metric_type: f1, acc, len_norm, pmi_dc""" + """metric_type: f1, acc, len_norm, pmi_dc, ce_loss""" super().__init__(sync_on_compute=True) self.metric_type = metric_type @@ -33,7 +33,7 @@ def reset( self.loglikelihoods = [] self.labels = [] - def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None): + def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None, ce_loss=None): lm_logits = F.log_softmax(lm_logits, dim=-1) if self.metric_type == "pmi_dc": @@ -69,6 +69,8 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No log_likelihood = ( torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() / batch["cont_str_len"][idx] ) + elif self.metric_type == "ce_loss": + log_likelihood = ce_loss[idx] else: raise ValueError(self.metric_type) @@ -123,8 +125,10 @@ def compute(self) -> torch.Tensor: if skip_document: continue - - correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0) + if self.metric_type == "ce_loss": + correct.append(loglikelihoods[0]) # Only one answer is scored + else: + correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0) if self.metric_type == "f1": assert preds is not None @@ -754,6 +758,17 @@ def __init__(self, tokenizer, dataset_path="ai2_arc", dataset_name="ARC-Challeng ) +class ArcEasyCELoss(ArcEasy): + """ArcEasyCELoss is ARCEasy using an alternate ce_loss metric""" + + metric_type = "ce_loss" + + def doc_to_continuations(self, doc): + # We only consider the correct answer for this metric + answer = doc["choices"]["text"][self.doc_to_label(doc)] + return [" " + answer] + + class BasicArithmetic(ArcEasy): """This is a basic arithmetic task follows the same prompt format as ArcEasy. Example: @@ -1233,6 +1248,7 @@ def doc_to_domain_conditional(self, doc): "boolq": BoolQ, "sciq": SciQ, "arc_easy": ArcEasy, + "arc_easy_ppl": ArcEasyCELoss, "arc_challenge": ArcChallenge, "basic_arithmetic": BasicArithmetic, "copa": COPA, diff --git a/olmo/eval/evaluator.py b/olmo/eval/evaluator.py index ddc85a603..3a5ca9c2d 100644 --- a/olmo/eval/evaluator.py +++ b/olmo/eval/evaluator.py @@ -70,7 +70,7 @@ def update_metrics( ) -> None: if self.type == EvaluatorType.downstream: assert isinstance(self.eval_metric, ICLMetric) - self.eval_metric.update(batch, logits) # type: ignore + self.eval_metric.update(batch, logits, ce_loss=ce_loss) # type: ignore elif self.type == EvaluatorType.lm: # Metric(s) = cross entropy loss for metadata, instance_loss in zip(batch["metadata"], ce_loss): From 64b5e660203f397da5d9dea99df20cfbb77dd68f Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 17:38:24 -0700 Subject: [PATCH 2/8] Log ce_loss evaluations to separate panel --- olmo/eval/evaluator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/olmo/eval/evaluator.py b/olmo/eval/evaluator.py index 3a5ca9c2d..abda1c96f 100644 --- a/olmo/eval/evaluator.py +++ b/olmo/eval/evaluator.py @@ -29,9 +29,11 @@ def reset_metrics(self) -> None: def compute_metrics(self) -> Dict[str, float]: if self.type == EvaluatorType.downstream: assert isinstance(self.eval_metric, ICLMetric) - return { - f"eval/downstream/{self.label}_{self.eval_metric.metric_type}": self.eval_metric.compute().item(), - } + value = self.eval_metric.compute().item() + key = f"eval/downstream/{self.label}_{self.eval_metric.metric_type}" + if self.eval_metric.metric_type == "ce_loss": + key = key.replace("/downstream/", "/downstream_ce_loss/") + return {key: value} elif self.type == EvaluatorType.lm: # Metric(s) = cross entropy loss metrics: Dict[str, Metric] From c2e630e4282eaee3598bad7f63cba9b0e0eaf77d Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 17:49:34 -0700 Subject: [PATCH 3/8] Add trivia_qa and natural_qs tasks (ce loss) --- olmo/eval/downstream.py | 76 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index c24daedbc..bf8209d17 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -768,6 +768,9 @@ def doc_to_continuations(self, doc): answer = doc["choices"]["text"][self.doc_to_label(doc)] return [" " + answer] + def doc_to_label(self, doc): + return 0 + class BasicArithmetic(ArcEasy): """This is a basic arithmetic task follows the same prompt format as ArcEasy. @@ -1240,6 +1243,77 @@ def doc_to_domain_conditional(self, doc): return "Answer:" +class TriviaQACELoss(ICLMultiChoiceTaskDataset): + """Sample TriviaQA entity with some fields suppressed. For CE Loss we only consider the "value" + field as the answer to score. + + { + 'question': 'Which Lloyd Webber musical premiered in the US on 10th December 1993?', + 'question_id': 'tc_33', + 'answer': { + 'aliases': ['Sunset Blvd', ...], + 'normalized_aliases': ['sunset boulevard', ...], + 'normalized_value': 'sunset boulevard', + 'value': 'Sunset Boulevard' + } + } + """ + + metric_type = "ce_loss" + + def __init__(self, tokenizer, dataset_path="trivia_qa", dataset_name="rc.wikipedia.nocontext"): + super().__init__( + tokenizer=tokenizer, + dataset_path=dataset_path, + dataset_name=dataset_name, + ) + + def doc_to_text(self, doc): + return "\nQuestion: " + doc["question"] + "\nAnswer:" + + def doc_to_continuations(self, doc): + return [" " + doc["answer"]["value"]] + + def doc_to_label(self, doc): + return 0 + + def doc_to_domain_conditional(self, doc): + del doc + return "Answer:" + + +class NaturalQuestionsCELoss(ICLMultiChoiceTaskDataset): + """Sample NaturalQuestions entity. For CE Loss we only consider the first answer entry to score. + + { + 'question': 'when was the last time anyone was on the moon', + 'answer': ['14 December 1972 UTC', 'December 1972'] + } + """ + + metric_type = "ce_loss" + + def __init__(self, tokenizer, dataset_path="trivia_qa", dataset_name="rc.wikipedia.nocontext"): + super().__init__( + tokenizer=tokenizer, + dataset_path=dataset_path, + dataset_name=dataset_name, + ) + + def doc_to_text(self, doc): + return "\nQuestion: " + doc["question"] + "\nAnswer:" + + def doc_to_continuations(self, doc): + return [" " + doc["answer"][0]] + + def doc_to_label(self, doc): + return 0 + + def doc_to_domain_conditional(self, doc): + del doc + return "Answer:" + + label_to_task_map = { "piqa": PIQA, "hellaswag": HellaSwag, @@ -1258,6 +1332,8 @@ def doc_to_domain_conditional(self, doc): "sst2": SST2, "commonsense_qa": CommonsenseQA, "social_iqa": SocialIQa, + "trivia_qa_web_ppl": TriviaQACELoss, + "natural_qs_open_ppl": NaturalQuestionsCELoss, "mmlu_stem_test": (MMLU, {"dataset_name": "stem", "split": "test"}), "mmlu_humanities_test": (MMLU, {"dataset_name": "humanities", "split": "test"}), "mmlu_social_sciences_test": (MMLU, {"dataset_name": "social_sciences", "split": "test"}), From cba6076f07396bddadd9aaad10118191a008815d Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 17:58:03 -0700 Subject: [PATCH 4/8] Fix bug --- olmo/eval/downstream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index bf8209d17..ee8d8ae2d 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -1293,7 +1293,7 @@ class NaturalQuestionsCELoss(ICLMultiChoiceTaskDataset): metric_type = "ce_loss" - def __init__(self, tokenizer, dataset_path="trivia_qa", dataset_name="rc.wikipedia.nocontext"): + def __init__(self, tokenizer, dataset_path="nq_open", dataset_name=None): super().__init__( tokenizer=tokenizer, dataset_path=dataset_path, From 13de1353fcf1a6fcd43a0e5437b96a830682815d Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 18:04:13 -0700 Subject: [PATCH 5/8] Fix and simplify ce_loss computation --- olmo/eval/downstream.py | 6 ++---- olmo/eval/evaluator.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index ee8d8ae2d..ccfc1ea55 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -33,7 +33,7 @@ def reset( self.loglikelihoods = [] self.labels = [] - def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None, ce_loss=None): + def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None): lm_logits = F.log_softmax(lm_logits, dim=-1) if self.metric_type == "pmi_dc": @@ -65,12 +65,10 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No elif self.metric_type == "acc" or self.metric_type == "f1": # gather log-probs at continuation token indices log_likelihood = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() - elif self.metric_type == "len_norm": + elif self.metric_type == "len_norm" or self.metric_type == "ce_loss": log_likelihood = ( torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() / batch["cont_str_len"][idx] ) - elif self.metric_type == "ce_loss": - log_likelihood = ce_loss[idx] else: raise ValueError(self.metric_type) diff --git a/olmo/eval/evaluator.py b/olmo/eval/evaluator.py index abda1c96f..85a20da5d 100644 --- a/olmo/eval/evaluator.py +++ b/olmo/eval/evaluator.py @@ -72,7 +72,7 @@ def update_metrics( ) -> None: if self.type == EvaluatorType.downstream: assert isinstance(self.eval_metric, ICLMetric) - self.eval_metric.update(batch, logits, ce_loss=ce_loss) # type: ignore + self.eval_metric.update(batch, logits) # type: ignore elif self.type == EvaluatorType.lm: # Metric(s) = cross entropy loss for metadata, instance_loss in zip(batch["metadata"], ce_loss): From ed9250b3dcf3c3a993e853001df54f9c31e6a58b Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 18:06:19 -0700 Subject: [PATCH 6/8] Fix sign of ce_loss --- olmo/eval/downstream.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index ccfc1ea55..f5e8e801b 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -69,6 +69,8 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No log_likelihood = ( torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() / batch["cont_str_len"][idx] ) + if self.metric_type == "ce_loss": + log_likelihood = -log_likelihood else: raise ValueError(self.metric_type) From 9a9cc8e8a438e6f377f9efad580d955d04c1c536 Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 18:18:48 -0700 Subject: [PATCH 7/8] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13c1316c5..9cd559257 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for Grouped Query Attention. - Added commonsense_qa and social_iqa downstream evaluation tasks +- Added ce_loss metric, with TriviaQA and NaturalQuestions tasks ### Changed From ea54c12c8517d384fa19637743d37b2fef9d9e2b Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 22 Mar 2024 18:22:55 -0700 Subject: [PATCH 8/8] Rename _web to _wiki --- olmo/eval/downstream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index f5e8e801b..ccb37e0ea 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -1332,7 +1332,7 @@ def doc_to_domain_conditional(self, doc): "sst2": SST2, "commonsense_qa": CommonsenseQA, "social_iqa": SocialIQa, - "trivia_qa_web_ppl": TriviaQACELoss, + "trivia_qa_wiki_ppl": TriviaQACELoss, "natural_qs_open_ppl": NaturalQuestionsCELoss, "mmlu_stem_test": (MMLU, {"dataset_name": "stem", "split": "test"}), "mmlu_humanities_test": (MMLU, {"dataset_name": "humanities", "split": "test"}),