diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py
index fb06ca90f..2a9d1d365 100644
--- a/olmo/eval/downstream.py
+++ b/olmo/eval/downstream.py
@@ -95,9 +95,7 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
                 torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(batch["label_id"][idx].device)
             )
 
-    def compute(self) -> Dict[str, torch.Tensor]:
-        # Task "suffix" -> tensor
-
+    def compute(self) -> torch.Tensor:
         # states should have been synced from all accelerators at this point
         # account for duplicates here because of DistributedSampler compensating for drop_last=False
         loglikelihood_dict: Dict[int, Dict[int, float]] = {}
@@ -118,9 +116,6 @@ def compute(self) -> Dict[str, torch.Tensor]:
 
         # compute acc
         correct = []
-        soft_scores = []
-        soft_log_scores = []
-
         preds: Optional[List[float]] = None
         labels: Optional[List[int]] = None
         if self.metric_type == "f1":
@@ -145,15 +140,14 @@ def compute(self) -> Dict[str, torch.Tensor]:
                 continue
             if self.metric_type in ["ce_loss", "bpb"]:
                 correct.append(loglikelihoods[0])  # Only one answer is scored
-            elif self.metric_type == "f1":
+            else:
+                correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
+
+            if self.metric_type == "f1":
                 assert preds is not None
                 assert labels is not None
                 preds.append(torch.argmax(loglikelihoods).item())
                 labels.append(label_dict[doc_id])
-            else:
-                correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
-                soft_scores.append(torch.softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item())
-                soft_log_scores.append(torch.log_softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item())
 
         if self.metric_type == "f1":
             assert preds is not None
@@ -163,15 +157,7 @@ def compute(self) -> Dict[str, torch.Tensor]:
         else:
             score = sum(correct) / len(correct)
 
-        outputs = {
-            "": torch.tensor(score),
-        }
-
-        if soft_scores:
-            outputs["_soft"] = torch.tensor(sum(soft_scores) / len(soft_scores))
-            outputs["_soft_log"] = torch.tensor(sum(soft_log_scores) / len(soft_log_scores))
-
-        return outputs
+        return torch.tensor(score)
 
 
 class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
diff --git a/olmo/eval/evaluator.py b/olmo/eval/evaluator.py
index e71a7d859..29ef049d7 100644
--- a/olmo/eval/evaluator.py
+++ b/olmo/eval/evaluator.py
@@ -29,14 +29,11 @@ def reset_metrics(self) -> None:
     def compute_metrics(self) -> Dict[str, float]:
         if self.type == EvaluatorType.downstream:
            assert isinstance(self.eval_metric, ICLMetric)
-            suffix_to_value = self.eval_metric.compute()
-            outputs = {}
-            for suffix, value in suffix_to_value.items():
-                key = f"eval/downstream/{self.label}_{self.eval_metric.metric_type}{suffix}"
-                if self.eval_metric.metric_type in ["ce_loss", "bpb"]:
-                    key = key.replace("/downstream/", f"/downstream_{self.eval_metric.metric_type}/")
-                outputs[key] = value.item()
-            return outputs
+            value = self.eval_metric.compute().item()
+            key = f"eval/downstream/{self.label}_{self.eval_metric.metric_type}"
+            if self.eval_metric.metric_type in ["ce_loss", "bpb"]:
+                key = key.replace("/downstream/", f"/downstream_{self.eval_metric.metric_type}/")
+            return {key: value}
         elif self.type == EvaluatorType.lm:
             # Metric(s) = cross entropy loss
             metrics: Dict[str, Metric]
@@ -55,7 +52,7 @@ def compute_metrics(self) -> Dict[str, float]:
                     # This can happen when the evaluator contains multiple tasks/datasets and we didn't
                     # get to this one within the current evaluation loop.
                     metric.update(0.0, 0.0)
-                loss = metric.compute()[""]  # always no suffix
+                loss = metric.compute()
                 if loss.isnan().item():
                     # This can happen when the evaluator contains multiple tasks/datasets and we didn't
                     # get to this one within the current evaluation loop.
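
Note (not part of the diff): with this change, ICLMetric.compute() returns a single tensor and Evaluator.compute_metrics() emits exactly one key per downstream task; the "_soft" and "_soft_log" variants are gone. A minimal sketch of how the remaining key is built, where label="hellaswag" and metric_type="bpb" are hypothetical example values, not taken from the diff:

    # Illustrative sketch only; "hellaswag" and "bpb" are placeholder values.
    label, metric_type = "hellaswag", "bpb"
    key = f"eval/downstream/{label}_{metric_type}"
    if metric_type in ["ce_loss", "bpb"]:
        # ce_loss/bpb metrics keep their own namespace, e.g. eval/downstream_bpb/...
        key = key.replace("/downstream/", f"/downstream_{metric_type}/")
    print(key)  # eval/downstream_bpb/hellaswag_bpb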