Skip to content

Commit

Permalink
fix save_logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Oct 4, 2023
1 parent d3586be commit abbec97
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 11 deletions.
4 changes: 4 additions & 0 deletions elk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# NOTE(review): this span is scraped from a GitHub diff view; the leading "+"
# markers and any list-item indentation were stripped by the scrape.
# Content shown is the full new `elk/__init__.py`: it re-exports the three
# top-level entry points and declares them as the package's public API.
from .evaluation import Eval
from .extraction import Extract
from .training.train import Elicit

# Explicit public API — only these names are exported by `from elk import *`.
__all__ = [
"Extract",
"Elicit",
"Eval",
]
3 changes: 3 additions & 0 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def apply_to_layer(
}
)

if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode] = dict()

for i, model in enumerate(lr_models):
model.eval()
val_log_odds = model(val_data.hiddens)
Expand Down
2 changes: 1 addition & 1 deletion elk/promptsource/templates/_default/templates.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ dataset: None
templates:
7eab7254-bd71-4b1d-9f8a-0fc7110f8371: !Template
answer_choices: False ||| True
id: 7eab7254-bd41-4b1d-9f8a-0fc7110f8371
id: 7eab7254-bd71-4b1d-9f8a-0fc7110f8371
jinja: "{{ statement }}"
metadata: !TemplateMetadata
choices_in_prompt: true
Expand Down
15 changes: 15 additions & 0 deletions elk/promptsource/templates/_no_suffix/templates.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Promptsource template file added whole by this commit (@@ -0,0 +1,15 @@).
# NOTE(review): YAML nesting indentation was lost in the page scrape — in the
# real file the template id is nested under `templates:` and its fields under
# the `!Template` node; confirm against the repository before reuse.
# Defines a single boolean template, "_no_suffix": it renders the raw
# `{{ statement }}` with an empty suffix and answer choices "False ||| True".
dataset: None
templates:
8eab7252-bd71-4b2d-9f8a-0fc7260f8371: !Template
answer_choices: False ||| True
id: 8eab7252-bd71-4b2d-9f8a-0fc7260f8371
jinja: "{{ statement }}"
metadata: !TemplateMetadata
choices_in_prompt: true
languages:
- en
metrics:
- Accuracy
original_task: true
name: _no_suffix
suffix: ""
13 changes: 6 additions & 7 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Run(ABC, Serializable):
concatenated_layer_offset: int = 0
debug: bool = False
num_gpus: int = -1
out_dir: Path | None = None
min_gpu_mem: int = 0
disable_cache: bool = field(default=False, to_dict=False)

def execute(
Expand Down Expand Up @@ -113,7 +113,7 @@ def execute(
meta_f,
)

devices = select_usable_devices(self.num_gpus)
devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
num_devices = len(devices)
func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]] = partial(
self.apply_to_layer, devices=devices, world_size=num_devices
Expand Down Expand Up @@ -225,18 +225,17 @@ def apply_to_layers(
if self.save_logprobs:
save_dict = defaultdict(dict)
for ds_name, logprobs_dict in logprobs_dicts.items():
save_dict[ds_name]["row_ids"] = logprobs_dict[layers[0]][
"row_ids"
]
save_dict[ds_name]["texts"] = logprobs_dict[layers[0]]["texts"]
save_dict[ds_name]["labels"] = logprobs_dict[layers[0]][
"labels"
]
save_dict[ds_name]["lm"] = logprobs_dict[layers[0]]["lm"]
save_dict[ds_name]["reporter"] = dict()
save_dict[ds_name]["lr"] = dict()
for layer, logprobs_dict_by_mode in logprobs_dict.items():
save_dict[ds_name]["reporter"][
layer
] = logprobs_dict_by_mode["reporter"]
save_dict[ds_name]["lr"][layer] = logprobs_dict_by_mode[
"lr"
]
torch.save(save_dict, self.out_dir / "logprobs.pt")
torch.save(dict(save_dict), self.out_dir / "logprobs.pt")
2 changes: 1 addition & 1 deletion elk/training/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train_supervised(
(n, v, d) = train_data.hiddens.shape
train_h = rearrange(train_data.hiddens, "n v d -> (n v) d")

if erase_paraphrases:
if erase_paraphrases and v > 1:
if leace is None:
leace = LeaceFitter(
d,
Expand Down
7 changes: 5 additions & 2 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Elicit(Run):
cross-validation. Defaults to "single", which means to train a single classifier
on the training data. "cv" means to use cross-validation."""

erase_paraphrases: bool = True
erase_paraphrases: bool = False
"""Whether to use LEACE to erase the paraphrase dimensions before training the
classifier."""

Expand All @@ -41,7 +41,7 @@ def default():
datasets=("<placeholder>",),
)
)

def create_models_dir(self, out_dir: Path):
lr_dir = out_dir / "lr_models"

Expand Down Expand Up @@ -115,6 +115,9 @@ def apply_to_layer(
}
)

if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode] = dict()

for i, model in enumerate(lr_models):
model.eval()
val_log_odds = model(val.hiddens)
Expand Down

0 comments on commit abbec97

Please sign in to comment.