From 98a8deab133b499178c23866be0498a1798d33c0 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Thu, 21 Sep 2023 07:37:38 +0000 Subject: [PATCH] lm preds --- elk/evaluation/evaluate.py | 22 ++++++++++++++++++---- elk/run.py | 10 ++++++++-- elk/training/train.py | 26 +++++++++++++++++++++----- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index fee480ee..4f034970 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -58,13 +58,27 @@ def apply_to_layer( lr=dict(), ) for mode in ("none", "full"): - # TODO save lm logprobs and add to buf + if val_data.lm_log_odds is not None: + if self.save_logprobs: + out_logprobs[ds_name]["lm"][mode] = get_logprobs( + val_data.lm_log_odds, mode + ).cpu() + row_bufs["lm_eval"].append( + { + "ensembling": mode, + **meta, + **evaluate_preds( + val_data.labels, val_data.lm_log_odds, mode + ).to_dict(), + } + ) + for i, model in enumerate(lr_models): model.eval() - val_credences = model(val_data.hiddens) + val_log_odds = model(val_data.hiddens) if self.save_logprobs: out_logprobs[ds_name]["lr"][mode][i] = get_logprobs( - val_credences, mode + val_log_odds, mode ).cpu() row_bufs["lr_eval"].append( { @@ -72,7 +86,7 @@ def apply_to_layer( "inlp_iter": i, **meta, **evaluate_preds( - val_data.labels, val_credences, mode + val_data.labels, val_log_odds, mode ).to_dict(), } ) diff --git a/elk/run.py b/elk/run.py index bdda78ba..28ac1d4f 100644 --- a/elk/run.py +++ b/elk/run.py @@ -35,7 +35,7 @@ class LayerData: hiddens: Tensor labels: Tensor - lm_preds: Tensor | None + lm_log_odds: Tensor | None texts: list[list[str]] # (n, v) row_ids: list[int] # (n,) variant_ids: list[list[str]] # (n, v) @@ -156,10 +156,16 @@ def prepare_data( if self.prompt_indices: hiddens = hiddens[:, self.prompt_indices] + if "lm_log_odds" in split.column_names: + with split.formatted_as("torch", device=device): + lm_preds = assert_type(Tensor, split["lm_log_odds"]) + else: + lm_preds = None + out[ds_name] = LayerData( hiddens=hiddens, labels=labels, - lm_preds=None, # TODO: implement + lm_log_odds=lm_preds, texts=split["texts"], row_ids=split["row_id"], variant_ids=split["variant_ids"], diff --git a/elk/training/train.py b/elk/training/train.py index bc588ebc..84a55a39 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -89,14 +89,30 @@ def apply_to_layer( ) for mode in ("none", "full"): + if val.lm_log_odds is not None: + if self.save_logprobs: + out_logprobs[ds_name]["lm"][mode] = ( + get_logprobs(val.lm_log_odds, mode).detach().cpu() + ) + + row_bufs["lm_eval"].append( + { + **meta, + "ensembling": mode, + **evaluate_preds( + val.labels, val.lm_log_odds, mode + ).to_dict(), + } + ) + for i, model in enumerate(lr_models): model.eval() - val_credences = model(val.hiddens) - train_credences = model(train.hiddens) + val_log_odds = model(val.hiddens) + train_log_odds = model(train.hiddens) if self.save_logprobs: out_logprobs[ds_name]["lr"][mode][i] = ( - get_logprobs(val_credences, mode).detach().cpu() + get_logprobs(val_log_odds, mode).detach().cpu() ) row_bufs["lr_eval"].append( @@ -104,7 +120,7 @@ def apply_to_layer( **meta, "ensembling": mode, "inlp_iter": i, - **evaluate_preds(val.labels, val_credences, mode).to_dict(), + **evaluate_preds(val.labels, val_log_odds, mode).to_dict(), } ) @@ -114,7 +130,7 @@ def apply_to_layer( "ensembling": mode, "inlp_iter": i, **evaluate_preds( - train.labels, train_credences, mode + train.labels, train_log_odds, mode ).to_dict(), } )