From 8756fa4332bc2a03b8def6dcaea76fbb12508625 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 11 Aug 2023 20:49:08 +0000 Subject: [PATCH 1/4] remove first layer (embeddings) from layer_indices --- elk/extraction/extraction.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 26e8a7f1..aad05ff7 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -194,7 +194,8 @@ def extract_hiddens( seed=cfg.seed, ) - layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) + layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)[1:]) + global_max_examples = cfg.max_examples[0 if split_type == "train" else 1] @@ -366,12 +367,13 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]: if num_dropped: print(f"Dropping {num_dropped} non-multiple choice templates") + layer_indices = cfg.layers or tuple(range(model_cfg.num_hidden_layers)[1:]) layer_cols = { f"hidden_{layer}": Array3D( dtype="int16", shape=(num_variants, num_classes, model_cfg.hidden_size), ) - for layer in cfg.layers or range(model_cfg.num_hidden_layers) + for layer in layer_indices } other_cols = { "variant_ids": Sequence( @@ -458,6 +460,7 @@ def extract( mp.set_start_method("spawn", force=True) # type: ignore[attr-defined] ds = dict() + breakpoint() for split, builder in builders.items(): builder.download_and_prepare( download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None, From 015b5f523dc7b3ba90dbe15b156e60a350d2bd7e Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 11 Aug 2023 21:17:37 +0000 Subject: [PATCH 2/4] remove breakpoint --- elk/extraction/extraction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index aad05ff7..5f6db432 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -460,7 +460,6 @@ def extract( mp.set_start_method("spawn", force=True) # type: ignore[attr-defined] ds = dict() - breakpoint() for split, builder in builders.items(): builder.download_and_prepare( download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None, From cc095cb5e2a574b3cce82933258fdba63daf7151 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 11 Aug 2023 21:18:51 +0000 Subject: [PATCH 3/4] pre-commit cleanup --- elk/extraction/extraction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 5f6db432..30bb0adc 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -196,7 +196,6 @@ def extract_hiddens( layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)[1:]) - global_max_examples = cfg.max_examples[0 if split_type == "train" else 1] # break `max_examples` among the processes roughly equally From af4c9d11b710c8c60c832d4ce8ff6a7f5607abe4 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 11 Aug 2023 21:50:21 +0000 Subject: [PATCH 4/4] replace slicing --- elk/extraction/extraction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 30bb0adc..fe4fdc84 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -194,7 +194,7 @@ def extract_hiddens( seed=cfg.seed, ) - layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)[1:]) + layer_indices = cfg.layers or tuple(range(1, model.config.num_hidden_layers)) global_max_examples = cfg.max_examples[0 if split_type == "train" else 1] @@ -366,7 +366,7 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]: if num_dropped: print(f"Dropping {num_dropped} non-multiple choice templates") - layer_indices = cfg.layers or tuple(range(model_cfg.num_hidden_layers)[1:]) + layer_indices = cfg.layers or tuple(range(1, model_cfg.num_hidden_layers)) layer_cols = { f"hidden_{layer}": Array3D( dtype="int16",