From e986195cbe07403d049240352598d3c672f8a080 Mon Sep 17 00:00:00 2001
From: Alex Mallen
Date: Mon, 17 Apr 2023 02:02:20 -0400
Subject: [PATCH] add option to save reporter & lr outputs

---
 elk/evaluation/evaluate.py   | 37 ++++++++++++++++++++++++++++++++----
 elk/extraction/extraction.py |  2 +-
 elk/run.py                   | 37 ++++++++++++++++++++++++------------
 elk/training/train.py        | 13 ++++++++++---
 4 files changed, 69 insertions(+), 20 deletions(-)

diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 88e29166..98ce2baa 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -1,10 +1,15 @@
 from dataclasses import dataclass
 from functools import partial
+from pathlib import Path
 from typing import Callable
 
 import pandas as pd
 import torch
+from einops import rearrange, repeat
 from simple_parsing.helpers import Serializable, field
+from sklearn.metrics import roc_auc_score
+
+from elk.metrics import accuracy, to_one_hot
 
 from ..extraction.extraction import Extract
 from ..files import elk_reporter_dir
@@ -35,6 +40,7 @@ class Eval(Serializable):
 
     data: Extract
     source: str = field(positional=True)
+    preds_out_dir: Path | None = None
     concatenated_layer_offset: int = 0
     debug: bool = False
     min_gpu_mem: int | None = None
@@ -56,7 +62,7 @@ class Evaluate(Run):
 
     def evaluate_reporter(
         self, layer: int, devices: list[str], world_size: int = 1
-    ) -> pd.DataFrame:
+    ) -> tuple[pd.DataFrame, dict]:
         """Evaluate a single reporter on a single layer."""
         is_raw = self.cfg.data.prompts.datasets == ["raw"]
         device = self.get_device(devices, world_size)
@@ -75,18 +81,37 @@ def evaluate_reporter(
         reporter.eval()
 
         row_buf = []
-        for ds_name, (val_h, val_gt, _) in val_output.items():
+        preds_buf = dict()
+        for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items():
+            with torch.no_grad():
+                ds_preds = {f"reporter_{layer}": reporter(val_h).cpu().numpy()}
             val_result = (
                 reporter.score(val_gt, val_h)
                 if is_raw
                 else reporter.score_contrast_set(val_gt, val_h)
             )
 
+            if val_lm_preds is not None:
+                (_, v, k, _) = val_h.shape
+
+                val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu()
+                val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
+                val_lm_auroc = roc_auc_score(
+                    to_one_hot(val_gt_cpu, k).long().flatten(), val_lm_preds.flatten()
+                )
+
+                val_lm_acc = accuracy(val_gt_cpu, torch.from_numpy(val_lm_preds))
+            else:
+                val_lm_auroc = None
+                val_lm_acc = None
+
             stats_row = pd.Series(
                 {
                     "dataset": ds_name,
                     "layer": layer,
                     **val_result._asdict(),
+                    "lm_auroc": val_lm_auroc,
+                    "lm_acc": val_lm_acc,
                 }
             )
 
@@ -97,12 +122,16 @@ def evaluate_reporter(
 
                 lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
 
+                with torch.no_grad():
+                    ds_preds[f"lr_{layer}"] = lr_model(val_h).cpu().numpy().squeeze(-1)
+
                 stats_row["lr_auroc"] = lr_auroc
                 stats_row["lr_acc"] = lr_acc
 
             row_buf.append(stats_row)
+            preds_buf[ds_name] = ds_preds
 
-        return pd.DataFrame(row_buf)
+        return pd.DataFrame(row_buf), preds_buf
 
     def evaluate(self):
         """Evaluate the reporter on all layers."""
         devices = select_usable_devices(
             self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem
         )
         num_devices = len(devices)
 
-        func: Callable[[int], pd.DataFrame] = partial(
+        func: Callable[[int], tuple[pd.DataFrame, dict]] = partial(
             self.evaluate_reporter, devices=devices, world_size=num_devices
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 4beca526..90269b5e 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -61,7 +61,6 @@ class Extract(Serializable):
 
     prompts: PromptConfig
     model: str = field(positional=True)
-    hiddens_out_dir: Path | None = None
     layers: tuple[int, ...] = ()
     layer_stride: InitVar[int] = 1
     token_loc: Literal["first", "last", "mean"] = "last"
@@ -281,6 +280,7 @@ def get_max_examples(global_max_examples: int, rank: int, world_size: int) -> in
     return max_examples
 
 
+@torch.no_grad()
 def raw_extract_hiddens(
     cfg: "Extract",
     *,
diff --git a/elk/run.py b/elk/run.py
index aad593dd..8b015ff7 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -43,16 +43,6 @@ def __post_init__(self):
             for cfg in self.cfg.data.explode()
         ]
 
-        # Save the hidden states to disk if requested
-        if self.cfg.data.hiddens_out_dir is not None:
-            print("Saving hidden states to disk at", self.cfg.data.hiddens_out_dir)
-            for ds_name, ds in zip(self.cfg.data.prompts.datasets, self.datasets):
-                for split in ds.keys():
-                    path = self.cfg.data.hiddens_out_dir / ds_name / split
-                    path.mkdir(parents=True, exist_ok=True)
-                    ds[split].save_to_disk(path)
-
-        # TODO support raw evaluation
         if self.out_dir is None:
             # Save in a memorably-named directory inside of
             # ELK_REPORTER_DIR/<model_name>/<dataset_name>
@@ -137,7 +127,7 @@ def concatenate(self, layers):
 
     def apply_to_layers(
         self,
-        func: Callable[[int], pd.DataFrame],
+        func: Callable[[int], tuple[pd.DataFrame, dict]],
         num_devices: int,
     ):
         """Apply a function to each layer of the datasets in parallel
@@ -161,10 +151,14 @@ def apply_to_layers(
         with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f:
             mapper = pool.imap_unordered if num_devices > 1 else map
             df_buf = []
+            preds_buf = {ds_name: dict() for ds_name in self.cfg.data.prompts.datasets}
 
             try:
-                for df in tqdm(mapper(func, layers), total=len(layers)):
+                # Collect reporter and LR outputs from each layer so they can be saved below
+                for df, preds in tqdm(mapper(func, layers), total=len(layers)):
                     df_buf.append(df)
+                    for ds_name, ds_preds in preds.items():
+                        preds_buf[ds_name].update(ds_preds)
             finally:
                 # Make sure the CSV is written even if we crash or get interrupted
                 if df_buf:
@@ -176,3 +170,22 @@ def apply_to_layers(
                         self.out_dir,
                         is_raw=self.cfg.data.prompts.datasets == ["raw"],
                     )
+
+        is_raw = self.cfg.data.prompts.datasets == ["raw"]
+
+        # Save the predictions to disk if requested
+        if self.cfg.preds_out_dir is not None:
+            print("Saving predictions to disk at", self.cfg.preds_out_dir)
+
+            for ds_name, ds in zip(self.cfg.data.prompts.datasets, self.datasets):
+                val_name = "val" if is_raw else select_train_val_splits(ds)[1]
+                val_ds = ds[val_name]
+
+                # Add the predictions to the dataset
+                ds_preds = assert_type(dict, preds_buf[ds_name])
+                for key, preds in ds_preds.items():
+                    val_ds = val_ds.add_column(key, preds.tolist())  # type: ignore
+
+                path = self.cfg.preds_out_dir / ds_name
+                path.mkdir(parents=True, exist_ok=True)
+                val_ds.save_to_disk(path.as_posix())
diff --git a/elk/training/train.py b/elk/training/train.py
index 518f5742..56543c75 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -53,6 +53,7 @@ class Elicit(Serializable):
     num_gpus: int = -1
     out_dir: Path | None = None
     supervised: Literal["none", "single", "cv"] = "single"
+    preds_out_dir: Path | None = None
 
     def __post_init__(self):
         if self.data.prompts.datasets == ["raw"]:
@@ -84,7 +85,7 @@ def train_reporter(
         layer: int,
         devices: list[str],
         world_size: int = 1,
-    ) -> pd.DataFrame:
+    ) -> tuple[pd.DataFrame, dict]:
         """Train a single reporter on a single layer."""
         self.make_reproducible(seed=self.cfg.net.seed + layer)
         device = self.get_device(devices, world_size)
@@ -153,7 +154,10 @@ def train_reporter(
             lr_model = None
 
         row_buf = []
+        probs_buf = dict()
         for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
+            with torch.no_grad():
+                ds_probs = {f"reporter_{layer}": reporter(val_h).cpu().numpy()}
             val_result = reporter.score_contrast_set(val_gt, val_h)
 
             if val_lm_preds is not None:
@@ -186,16 +190,19 @@ def train_reporter(
                 row["lr_auroc"], row["lr_acc"] = evaluate_supervised(
                     lr_model, val_h, val_gt
                 )
+                with torch.no_grad():
+                    ds_probs[f"lr_{layer}"] = lr_model(val_h).cpu().numpy().squeeze(-1)
 
             row_buf.append(row)
+            probs_buf[ds_name] = ds_probs
 
-        return pd.DataFrame(row_buf)
+        return pd.DataFrame(row_buf), probs_buf
 
     def train(self):
         """Train a reporter on each layer of the network."""
         devices = select_usable_devices(self.cfg.num_gpus)
         num_devices = len(devices)
 
-        func: Callable[[int], pd.DataFrame] = partial(
+        func: Callable[[int], tuple[pd.DataFrame, dict]] = partial(
             self.train_reporter, devices=devices, world_size=num_devices
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
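
Usage note: with `preds_out_dir` set, each dataset's validation split is saved
via `Dataset.save_to_disk` with one `reporter_{layer}` column (and, when a
supervised probe exists, one `lr_{layer}` column) per evaluated layer. A
minimal sketch of loading those outputs back, assuming a hypothetical run that
used `preds_out_dir=out/preds` on a dataset named "imdb":

    # Sketch only: "out/preds/imdb" stands in for <preds_out_dir>/<ds_name>;
    # substitute the paths and layer numbers from your own run.
    from datasets import load_from_disk

    val_ds = load_from_disk("out/preds/imdb")
    print(val_ds.column_names)  # e.g. [..., "reporter_5", "lr_5", ...]

    # Each reporter_{layer} column holds the raw reporter outputs for that
    # example's contrast set; each lr_{layer} column holds the corresponding
    # logistic-regression outputs.
    reporter_outputs = val_ds["reporter_5"]  # one entry per example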