Skip to content

Commit

Permalink
Merge branch 'main' into dl/fix-embedding-resize
Browse files Browse the repository at this point in the history
  • Loading branch information
2015aroras authored Mar 28, 2024
2 parents e186f5b + 71f7014 commit aa5687d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for Grouped Query Attention.
- Added commonsense_qa and social_iqa downstream evaluation tasks
- Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks

### Changed

Expand All @@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Don't log garbage on nodes that aren't rank 0
- Don't crash in the HF code when we are referring to a tokenizer in a local file
- Corrected the `resize_token_embeddings` method in the `OLMoForCausalLM` class to properly update the token embeddings when resizing the vocabulary.
- Fixed the size calculation for the QK layer norm

## [v0.2.5](https://github.com/allenai/OLMo/releases/tag/v0.2.5) - 2024-03-06

Expand Down
58 changes: 53 additions & 5 deletions olmo/eval/downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,7 @@ def __init__(
dataset_name=None,
split="validation",
prompt_variations=None,
mc_labels=False,
):
dataset_names = []
# Collect the relevant categories
Expand All @@ -1178,9 +1179,15 @@ def __init__(
if dataset_name in cats:
dataset_names.append(name)
self.dev_set = {}
self.mc_labels = mc_labels
prompts: List[Union[None, str]] = [None]
if prompt_variations == 1:
prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"]
if prompt_variations is not None:
if prompt_variations == 1:
prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"]
elif prompt_variations == 2:
prompts = ["inst+5"]
else:
raise ValueError(f"Unknown prompt variations: {prompt_variations}")
# Need to grab the dev set for the few-shot prompts
for name in dataset_names:
self.dev_set[name] = datasets.load_dataset(
Expand All @@ -1195,7 +1202,20 @@ def __init__(
)

def doc_to_text(self, doc):
output_text = "Question: " + doc["question"] + "\nAnswer:"
def format_example(doc, keys):
question_prefix = ""
if not self.mc_labels:
question_prefix = "Question: " # To make context more clear
question = question_prefix + doc["question"].strip()
choices = ""
if self.mc_labels:
choices = "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
prompt = f"{question}\n{choices}Answer:"
return prompt

keys = ["A", "B", "C", "D"]
output_text = format_example(doc, keys)

if self.current_prompt is not None:
prefix = ""
if "inst" in self.current_prompt:
Expand All @@ -1208,13 +1228,18 @@ def doc_to_text(self, doc):
for idx, dev_doc in enumerate(dev_set):
if idx >= num_shots_int:
break
answer = dev_doc["choices"][dev_doc["answer"]]
prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n"
if self.mc_labels:
answer = keys[dev_doc["answer"]]
else:
answer = dev_doc["choices"][dev_doc["answer"]]
prefix += format_example(dev_doc, keys) + " " + answer + "\n\n"
output_text = prefix + output_text
return output_text

def doc_to_continuations(self, doc):
    """Return the candidate answer continuations for ``doc``.

    Each continuation carries a leading space so it concatenates cleanly
    onto the prompt produced by ``doc_to_text`` (which ends in "Answer:").
    """
    if self.mc_labels:
        # Multiple-choice variant: the model picks a letter, not the answer text.
        return [" A", " B", " C", " D"]
    # Default variant: score the full answer strings themselves.
    prefixed = []
    for option in doc["choices"]:
        prefixed.append(" " + option)
    return prefixed

def doc_to_label(self, doc):
Expand Down Expand Up @@ -1254,4 +1279,27 @@ def doc_to_domain_conditional(self, doc):
"mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
"mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}),
"mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
"mmlu_stem_mc_5shot": (MMLU, {"dataset_name": "stem", "prompt_variations": 2, "mc_labels": True}),
"mmlu_humanities_mc_5shot": (MMLU, {"dataset_name": "humanities", "prompt_variations": 2, "mc_labels": True}),
"mmlu_social_sciences_mc_5shot": (
MMLU,
{"dataset_name": "social_sciences", "prompt_variations": 2, "mc_labels": True},
),
"mmlu_other_mc_5shot": (MMLU, {"dataset_name": "other", "prompt_variations": 2, "mc_labels": True}),
"mmlu_stem_mc_5shot_test": (
MMLU,
{"dataset_name": "stem", "split": "test", "prompt_variations": 2, "mc_labels": True},
),
"mmlu_humanities_mc_5shot_test": (
MMLU,
{"dataset_name": "humanities", "split": "test", "prompt_variations": 2, "mc_labels": True},
),
"mmlu_social_sciences_mc_5shot_test": (
MMLU,
{"dataset_name": "social_sciences", "split": "test", "prompt_variations": 2, "mc_labels": True},
),
"mmlu_other_mc_5shot_test": (
MMLU,
{"dataset_name": "other", "split": "test", "prompt_variations": 2, "mc_labels": True},
),
}
4 changes: 2 additions & 2 deletions olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,10 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
self.k_norm: Optional[LayerNormBase] = None
self.q_norm: Optional[LayerNormBase] = None
if config.attention_layer_norm:
assert config.n_kv_heads is not None
assert config.effective_n_kv_heads is not None
self.k_norm = LayerNormBase.build(
config,
size=config.d_model // config.effective_n_kv_heads,
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
elementwise_affine=config.attention_layer_norm_with_affine,
)
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
Expand Down
2 changes: 1 addition & 1 deletion olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def cross_entropy_loss(

z_squared = logits.logsumexp(-1).pow(2)
if reduction == "mean":
z_squared = z_squared / (labels != ignore_index).mean()
z_squared = (z_squared * (labels != ignore_index)).mean()
elif reduction == "sum":
z_squared = (z_squared * (labels != ignore_index)).sum()

Expand Down

0 comments on commit aa5687d

Please sign in to comment.