Add MMLU prompt variants #484

Merged
merged 10 commits on Mar 5, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states.
- Added MMLU downstream evaluation tasks.
- Added MMLU downstream evaluation tasks, with prompt variations.
- Added support for PyTorch v2.2.

## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02
161 changes: 112 additions & 49 deletions olmo/eval/downstream.py
@@ -1,4 +1,5 @@
import abc
import logging
import re
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union

@@ -10,6 +11,8 @@

from ..tokenizer import Tokenizer

log = logging.getLogger(__name__)


class ICLMetric(Metric):
# update method does not require access to global metric state
@@ -152,13 +155,17 @@ def __init__(
dataset_name: Union[str, Sequence[str], None] = None,
model_ctx_len: int = 2048,
split="validation",
prompts=[None], # List of prompt variants to use
):
super().__init__()

self.tokenizer = tokenizer
self.dataset_path = dataset_path
self.dataset_name = dataset_name
self.model_ctx_len = model_ctx_len
self.prompts = prompts
self.current_prompt = None
self.log_instances = 5 # Log the first few instances as a sanity check

self.samples: List[Dict[str, Any]] = []
dataset_names: Sequence[Optional[str]]
@@ -174,6 +181,7 @@ def __init__(
path=self.dataset_path,
name=ds_name,
split=split,
trust_remote_code=True,
)
)
self.dataset = datasets.concatenate_datasets(dataset_list)
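
For orientation, a standalone sketch of the dataset load the constructor now performs; the `"anatomy"` subset name below is only an example, not something this diff pins down:

```python
import datasets

# Mirrors the constructor's load call; path and subset are illustrative.
ds = datasets.load_dataset(
    path="hails/mmlu_no_train",
    name="anatomy",
    split="validation",
    trust_remote_code=True,  # the dataset ships a loader script, so newer `datasets` versions ask for this flag
)
print(len(ds), ds[0]["question"][:80])
```
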
@@ -191,51 +199,65 @@ def prep_examples(self):
"""Append doc_ids to each example so that they are processed together in the metric"""
doc_id = 0
for doc in self.dataset:
# from EAI harness
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

continuations = self.doc_to_continuations(doc)
label_id = self.doc_to_label(doc)
ctx = self.token_encode(self.doc_to_text(doc))
dc = self.token_encode(self.doc_to_domain_conditional(doc))

for cont_id, continuation_str in enumerate(continuations):
cont_str_len = len(continuation_str) - 1 # continuation contain leading blank
continuation = self.token_encode(continuation_str)

# query, remove last token from continuation, truncate from left is longer than model ctx length
query = ctx + continuation[:-1]
query = query[-self.model_ctx_len :]

# get domain conditional query
# we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
dc_query = dc + continuation[:-1]

# form a sample
self.samples.append(
{
"doc_id": doc_id,
"cont_id": cont_id,
"ctx": ctx,
"continuation": continuation,
"ctx_len": len(ctx),
"dc_len": len(dc),
"cont_len": len(
continuation
), # even if query has last token removed, LM will output same cont len
"cont_str_len": cont_str_len,
"query": query, # remove last token from continuation
"dc_query": dc_query,
"label_id": label_id,
}
)

doc_id += 1
for prompt in self.prompts:
self.current_prompt = prompt
# from EAI harness
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

continuations = self.doc_to_continuations(doc)
label_id = self.doc_to_label(doc)
doc_text = self.doc_to_text(doc)
ctx = self.token_encode(doc_text)
dc = self.token_encode(self.doc_to_domain_conditional(doc))
if self.log_instances > 0:
self.log_instances -= 1
ds_name = self.dataset_name
if isinstance(ds_name, list):
ds_name = ds_name[0]
log.info(
f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):"
+ f"\ndoc_text: {doc_text}\ncontinuations: {continuations}"
)

for cont_id, continuation_str in enumerate(continuations):
                    cont_str_len = len(continuation_str) - 1 # continuation contains a leading blank
continuation = self.token_encode(continuation_str)

                    # query, remove last token from continuation, truncate from left if longer than model ctx length
query = ctx + continuation[:-1]
query = query[-self.model_ctx_len :]
# this will be different from len(ctx) when truncated by model_ctx_len
actual_ctx_len = len(query) - len(continuation) + 1

# get domain conditional query
# we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
dc_query = dc + continuation[:-1]

# form a sample
self.samples.append(
{
"doc_id": doc_id,
"cont_id": cont_id,
"ctx": ctx,
"continuation": continuation,
"ctx_len": actual_ctx_len,
"dc_len": len(dc),
"cont_len": len(
continuation
), # even if query has last token removed, LM will output same cont len
"cont_str_len": cont_str_len,
"query": query, # remove last token from continuation
"dc_query": dc_query,
"label_id": label_id,
}
)

doc_id += 1
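
A quick back-of-envelope sketch of what the added per-prompt loop means for sample counts; the numbers below are illustrative, not measured from this change:

```python
# Each document now yields one sample per (prompt variant, answer choice) pair.
num_docs = 100     # documents in the evaluation split (illustrative)
num_prompts = 7    # e.g. [None, "inst", "inst+1", ..., "inst+5"]
num_choices = 4    # MMLU questions have four answer choices
print(num_docs * num_prompts * num_choices)  # -> 2800 samples appended to self.samples
```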

def pad_tokens_until_max(self, tokens, max_len=2048):
"""truncate from left if len(tokens) > model_ctx_len, max_len is not considered then
@@ -655,7 +677,7 @@ def __init__(self, tokenizer, dataset_path="sciq", dataset_name=None):
)

def doc_to_text(self, doc):
return doc["support"] + "\nQuestion: " + doc["question"] + "\nAnswer:".strip()
return doc["support"].strip() + "\nQuestion: " + doc["question"] + "\nAnswer:"

def doc_to_continuations(self, doc):
# add spaces in front of continuation
@@ -1055,7 +1077,14 @@ class MMLU(ICLMultiChoiceTaskDataset):
"other": ["other", "business", "health"],
}

def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"):
def __init__(
self,
tokenizer,
dataset_path="hails/mmlu_no_train",
dataset_name=None,
split="validation",
prompt_variations=None,
):
dataset_names = []
# Collect the relevant categories
if dataset_name in MMLU._categories:
@@ -1069,10 +1098,40 @@ def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"):
for name, cats in MMLU._subcategories.items():
if dataset_name in cats:
dataset_names.append(name)
super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_names, split=split)
self.dev_set = {}
if prompt_variations == 1:
prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"]
# Need to grab the dev set for the few-shot prompts
for name in dataset_names:
self.dev_set[name] = datasets.load_dataset(
path=dataset_path, name=name, split="dev", trust_remote_code=True
)
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_names,
split=split,
prompts=prompts,
)

def doc_to_text(self, doc):
return "Question: " + doc["question"] + "\nAnswer:"
output_text = "Question: " + doc["question"] + "\nAnswer:"
if self.current_prompt is not None:
prefix = ""
if "inst" in self.current_prompt:
subject = doc.get("subject").replace("_", " ")
prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n"
num_shots = re.findall("\\+(\\d+)", self.current_prompt)
if num_shots:
dev_set = self.dev_set.get(doc.get("subject"), [])
num_shots_int = int(num_shots[0])
for idx, dev_doc in enumerate(dev_set):
if idx >= num_shots_int:
break
answer = dev_doc["choices"][dev_doc["answer"]]
prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n"
output_text = prefix + output_text
return output_text
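
To make the branching concrete, here is roughly what `doc_to_text` returns when `current_prompt` is `"inst+2"`; the `<...>` placeholders and the subject name stand in for real dataset text, and only the layout follows the code above:

```python
# Illustrative output only; two dev-set examples are prepended after the subject instruction.
expected = (
    "The following are multiple choice questions (with answers) about college biology:\n\n"
    "Question: <dev question 1>\nAnswer: <dev answer 1>\n\n"
    "Question: <dev question 2>\nAnswer: <dev answer 2>\n\n"
    "Question: <evaluation question>\nAnswer:"
)
```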

def doc_to_continuations(self, doc):
# add spaces in front of continuation
@@ -1108,4 +1167,8 @@ def doc_to_domain_conditional(self, doc):
"mmlu_humanities": (MMLU, {"dataset_name": "humanities"}),
"mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}),
"mmlu_other": (MMLU, {"dataset_name": "other"}),
"mmlu_stem_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}),
"mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
"mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}),
"mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
}
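
A minimal usage sketch for the new `_var` labels, assuming the enclosing dict is named `label_to_task_map` and that a tokenizer has already been constructed elsewhere:

```python
# Hedged sketch; each map entry is a (task class, constructor kwargs) pair.
tokenizer = ...  # an olmo.tokenizer.Tokenizer instance, construction omitted here
task_class, task_kwargs = label_to_task_map["mmlu_stem_var"]
task = task_class(tokenizer=tokenizer, **task_kwargs)  # MMLU(dataset_name="stem", prompt_variations=1)
print(len(task.samples))  # one sample per (doc, prompt variant, answer choice), assuming prep_examples ran in __init__
```
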
1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
"numpy",
"torch>=2.0,<2.3",
"omegaconf",
"fsspec==2023.5.0", # temporary fix for HF dataset downloads
"rich",
"boto3",
"google-cloud-storage",