From 985c21205a77a1639d54c5c4aec0003f99719a81 Mon Sep 17 00:00:00 2001 From: David Brandfonbrener Date: Sat, 23 Mar 2024 20:23:47 -0400 Subject: [PATCH 1/4] Fix k_norm dimension The k_norm size should be the dimension of the key vector, which is head_dim * effective_n_kv_heads, and not d_model / effective_n_kv_heads. Also fixing an incorrect assert. --- olmo/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 882f7d6e8..555e0ca81 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -425,10 +425,10 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): self.k_norm: Optional[LayerNormBase] = None self.q_norm: Optional[LayerNormBase] = None if config.attention_layer_norm: - assert config.n_kv_heads is not None + assert config.effective_n_kv_heads is not None self.k_norm = LayerNormBase.build( config, - size=config.d_model // config.effective_n_kv_heads, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, elementwise_affine=config.attention_layer_norm_with_affine, ) self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) From a57f380332e7021755d0a36cf79406b4423cf361 Mon Sep 17 00:00:00 2001 From: David Brandfonbrener Date: Sun, 24 Mar 2024 15:10:17 -0400 Subject: [PATCH 2/4] Fix mean reduction in cross entropy loss The mean reduction should reduce the z_loss to a scalar. Also, I'm not sure why division was being used here instead of multiplication by the mask, but I changed it to multiplication. 
--- olmo/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index 1494a1b49..4454786e3 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -105,7 +105,7 @@ def cross_entropy_loss( z_squared = logits.logsumexp(-1).pow(2) if reduction == "mean": - z_squared = z_squared / (labels != ignore_index).mean() + z_squared = (z_squared * (labels != ignore_index)).mean() elif reduction == "sum": z_squared = (z_squared * (labels != ignore_index)).sum() From 322c8b65f877b88e9852de11ad7f1de5f00cf24a Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 25 Mar 2024 14:26:23 -0700 Subject: [PATCH 3/4] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13c1316c5..2a6cb79a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Don't log garbage on nodes that aren't rank 0 - Don't crash in the HF code when we are referring to a tokenizer in a local file +- Fixed the size calculation for qk layer norm ## [v0.2.5](https://github.com/allenai/OLMo/releases/tag/v0.2.5) - 2024-03-06 From a164b62dd7f1777dcde13f2bc5aa96ef38b0a5ca Mon Sep 17 00:00:00 2001 From: yulinggu-cs Date: Tue, 26 Mar 2024 13:34:03 -0700 Subject: [PATCH 4/4] Add MMLU multiple choice 5-shot --- CHANGELOG.md | 1 + olmo/eval/downstream.py | 58 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13c1316c5..8c64ef1fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for Grouped Query Attention. 
- Added commonsense_qa and social_iqa downstream evaluation tasks +- Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks ### Changed diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index b81f7927a..09df95de0 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -1163,6 +1163,7 @@ def __init__( dataset_name=None, split="validation", prompt_variations=None, + mc_labels=False, ): dataset_names = [] # Collect the relevant categories @@ -1178,9 +1179,15 @@ def __init__( if dataset_name in cats: dataset_names.append(name) self.dev_set = {} + self.mc_labels = mc_labels prompts: List[Union[None, str]] = [None] - if prompt_variations == 1: - prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"] + if prompt_variations is not None: + if prompt_variations == 1: + prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"] + elif prompt_variations == 2: + prompts = ["inst+5"] + else: + raise ValueError(f"Unknown prompt variations: {prompt_variations}") # Need to grab the dev set for the few-shot prompts for name in dataset_names: self.dev_set[name] = datasets.load_dataset( @@ -1195,7 +1202,20 @@ def __init__( ) def doc_to_text(self, doc): - output_text = "Question: " + doc["question"] + "\nAnswer:" + def format_example(doc, keys): + question_prefix = "" + if not self.mc_labels: + question_prefix = "Question: " # To make context more clear + question = question_prefix + doc["question"].strip() + choices = "" + if self.mc_labels: + choices = "".join([f"{key}. 
{choice}\n" for key, choice in zip(keys, doc["choices"])]) + prompt = f"{question}\n{choices}Answer:" + return prompt + + keys = ["A", "B", "C", "D"] + output_text = format_example(doc, keys) + if self.current_prompt is not None: prefix = "" if "inst" in self.current_prompt: @@ -1208,13 +1228,18 @@ def doc_to_text(self, doc): for idx, dev_doc in enumerate(dev_set): if idx >= num_shots_int: break - answer = dev_doc["choices"][dev_doc["answer"]] - prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n" + if self.mc_labels: + answer = keys[dev_doc["answer"]] + else: + answer = dev_doc["choices"][dev_doc["answer"]] + prefix += format_example(dev_doc, keys) + " " + answer + "\n\n" output_text = prefix + output_text return output_text def doc_to_continuations(self, doc): # add spaces in front of continuation + if self.mc_labels: + return [" A", " B", " C", " D"] return [" " + choice for choice in doc["choices"]] def doc_to_label(self, doc): @@ -1254,4 +1279,27 @@ def doc_to_domain_conditional(self, doc): "mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}), "mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}), "mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}), + "mmlu_stem_mc_5shot": (MMLU, {"dataset_name": "stem", "prompt_variations": 2, "mc_labels": True}), + "mmlu_humanities_mc_5shot": (MMLU, {"dataset_name": "humanities", "prompt_variations": 2, "mc_labels": True}), + "mmlu_social_sciences_mc_5shot": ( + MMLU, + {"dataset_name": "social_sciences", "prompt_variations": 2, "mc_labels": True}, + ), + "mmlu_other_mc_5shot": (MMLU, {"dataset_name": "other", "prompt_variations": 2, "mc_labels": True}), + "mmlu_stem_mc_5shot_test": ( + MMLU, + {"dataset_name": "stem", "split": "test", "prompt_variations": 2, "mc_labels": True}, + ), + "mmlu_humanities_mc_5shot_test": ( + MMLU, + {"dataset_name": "humanities", "split": "test", "prompt_variations": 2, 
"mc_labels": True}, + ), + "mmlu_social_sciences_mc_5shot_test": ( + MMLU, + {"dataset_name": "social_sciences", "split": "test", "prompt_variations": 2, "mc_labels": True}, + ), + "mmlu_other_mc_5shot_test": ( + MMLU, + {"dataset_name": "other", "split": "test", "prompt_variations": 2, "mc_labels": True}, + ), }