add option to save reporter & lr outputs
AlexTMallen committed Apr 17, 2023
1 parent c617d6a commit e986195
Showing 4 changed files with 69 additions and 20 deletions.
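In brief: both the Eval and Elicit configs gain a preds_out_dir field, the per-layer functions now return their predictions alongside the stats DataFrame, and apply_to_layers attaches those predictions to the validation split and saves it. A minimal sketch of the new config field (field name and type taken from the diff; surrounding fields elided, and the real classes subclass simple_parsing's Serializable rather than a plain dataclass):

from dataclasses import dataclass
from pathlib import Path

@dataclass
class Eval:
    # Added by this commit: directory for saving per-layer reporter and LR
    # outputs; None (the default) disables saving. Elicit gains the same field.
    preds_out_dir: Path | None = None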
37 changes: 33 additions & 4 deletions elk/evaluation/evaluate.py
@@ -1,10 +1,15 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable

import pandas as pd
import torch
from einops import rearrange, repeat
from simple_parsing.helpers import Serializable, field
from sklearn.metrics import roc_auc_score

from elk.metrics import accuracy, to_one_hot

from ..extraction.extraction import Extract
from ..files import elk_reporter_dir
@@ -35,6 +40,7 @@ class Eval(Serializable):
data: Extract
source: str = field(positional=True)

preds_out_dir: Path | None = None
concatenated_layer_offset: int = 0
debug: bool = False
min_gpu_mem: int | None = None
@@ -56,7 +62,7 @@ class Evaluate(Run):

def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
) -> pd.DataFrame:
) -> tuple[pd.DataFrame, dict]:
"""Evaluate a single reporter on a single layer."""
is_raw = self.cfg.data.prompts.datasets == ["raw"]
device = self.get_device(devices, world_size)
@@ -75,18 +81,37 @@ def evaluate_reporter(
reporter.eval()

row_buf = []
for ds_name, (val_h, val_gt, _) in val_output.items():
preds_buf = dict()
for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items():
with torch.no_grad():
ds_preds = {f"reporter_{layer}": reporter(val_h).cpu().numpy()}
val_result = (
reporter.score(val_gt, val_h)
if is_raw
else reporter.score_contrast_set(val_gt, val_h)
)

if val_lm_preds is not None:
(_, v, k, _) = val_h.shape

val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu()
val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
val_lm_auroc = roc_auc_score(
to_one_hot(val_gt_cpu, k).long().flatten(), val_lm_preds.flatten()
)

val_lm_acc = accuracy(val_gt_cpu, torch.from_numpy(val_lm_preds))
else:
val_lm_auroc = None
val_lm_acc = None

stats_row = pd.Series(
{
"dataset": ds_name,
"layer": layer,
**val_result._asdict(),
"lm_auroc": val_lm_auroc,
"lm_acc": val_lm_acc,
}
)

@@ -97,12 +122,16 @@ def evaluate_reporter(

lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)

with torch.no_grad():
ds_preds[f"lr_{layer}"] = lr_model(val_h).cpu().numpy().squeeze(-1)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_acc"] = lr_acc

row_buf.append(stats_row)
preds_buf[ds_name] = ds_preds

return pd.DataFrame(row_buf)
return pd.DataFrame(row_buf), preds_buf

def evaluate(self):
"""Evaluate the reporter on all layers."""
@@ -111,7 +140,7 @@ def evaluate(self):
)

num_devices = len(devices)
func: Callable[[int], pd.DataFrame] = partial(
func: Callable[[int], tuple[pd.DataFrame, dict]] = partial(
self.evaluate_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)
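For reference, the new lm_auroc block above repeats each ground-truth label across the v prompt variants, one-hot encodes it over k classes, and flattens both sides before calling sklearn's roc_auc_score. A self-contained sketch of that pattern on toy shapes (to_one_hot here is a stand-in with the behavior implied by the diff, not the elk.metrics implementation; shapes and data are illustrative):

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sklearn.metrics import roc_auc_score

def to_one_hot(labels: torch.Tensor, k: int) -> torch.Tensor:
    """Stand-in for elk.metrics.to_one_hot (assumed behavior)."""
    return F.one_hot(labels, k)

n, v, k = 8, 3, 2                        # examples, prompt variants, classes
val_gt = torch.randint(0, k, (n,))       # one label per example
val_lm_preds = np.random.rand(n, v, k)   # LM score per (example, variant, class)

# Repeat each label across its v variants, then flatten both tensors so
# roc_auc_score sees one binary (label, score) pair per (n, v, k) entry.
val_gt_rep = repeat(val_gt, "n -> (n v)", v=v)
lm_flat = rearrange(val_lm_preds, "n v ... -> (n v) ...")
auroc = roc_auc_score(
    to_one_hot(val_gt_rep, k).long().flatten().numpy(), lm_flat.flatten()
)
print(f"lm_auroc ~ {auroc:.3f}")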
2 changes: 1 addition & 1 deletion elk/extraction/extraction.py
@@ -61,7 +61,6 @@ class Extract(Serializable):
prompts: PromptConfig
model: str = field(positional=True)

hiddens_out_dir: Path | None = None
layers: tuple[int, ...] = ()
layer_stride: InitVar[int] = 1
token_loc: Literal["first", "last", "mean"] = "last"
@@ -281,6 +280,7 @@ def get_max_examples(global_max_examples: int, rank: int, world_size: int) -> int
return max_examples


@torch.no_grad()
def raw_extract_hiddens(
cfg: "Extract",
*,
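Aside from dropping hiddens_out_dir, the extraction change just decorates raw_extract_hiddens with @torch.no_grad(). A quick illustration of the decorator form (toy function, not the elk code):

import torch

@torch.no_grad()  # disables gradient tracking for the whole function body
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

out = double(torch.ones(3, requires_grad=True))
assert not out.requires_grad  # no autograd graph was built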
37 changes: 25 additions & 12 deletions elk/run.py
@@ -43,16 +43,6 @@ def __post_init__(self):
for cfg in self.cfg.data.explode()
]

# Save the hidden states to disk if requested
if self.cfg.data.hiddens_out_dir is not None:
print("Saving hidden states to disk at", self.cfg.data.hiddens_out_dir)
for ds_name, ds in zip(self.cfg.data.prompts.datasets, self.datasets):
for split in ds.keys():
path = self.cfg.data.hiddens_out_dir / ds_name / split
path.mkdir(parents=True, exist_ok=True)
ds[split].save_to_disk(path)

# TODO support raw evaluation
if self.out_dir is None:
# Save in a memorably-named directory inside of
# ELK_REPORTER_DIR/<model_name>/<dataset_name>
@@ -137,7 +127,7 @@ def concatenate(self, layers):

def apply_to_layers(
self,
func: Callable[[int], pd.DataFrame],
func: Callable[[int], tuple[pd.DataFrame, dict]],
num_devices: int,
):
"""Apply a function to each layer of the datasets in parallel
@@ -161,10 +151,14 @@ def apply_to_layers(
with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f:
mapper = pool.imap_unordered if num_devices > 1 else map
df_buf = []
preds_buf = {ds_name: dict() for ds_name in self.cfg.data.prompts.datasets}

try:
for df in tqdm(mapper(func, layers), total=len(layers)):
# TODO: also save reporter outputs and LR outputs for each layer
for df, preds in tqdm(mapper(func, layers), total=len(layers)):
df_buf.append(df)
for ds_name, ds_preds in preds.items():
preds_buf[ds_name].update(ds_preds)
finally:
# Make sure the CSV is written even if we crash or get interrupted
if df_buf:
@@ -176,3 +170,22 @@ def apply_to_layers(
self.out_dir,
is_raw=self.cfg.data.prompts.datasets == ["raw"],
)

is_raw = self.cfg.data.prompts.datasets == ["raw"]

# Save the predictions to disk if requested
if self.cfg.preds_out_dir is not None:
print("Saving predictions to disk at", self.cfg.preds_out_dir)

for ds_name, ds in zip(self.cfg.data.prompts.datasets, self.datasets):
val_name = "val" if is_raw else select_train_val_splits(ds)[1]
val_ds = ds[val_name]

# Add the predictions to the dataset
ds_preds = assert_type(dict, preds_buf[ds_name])
for key, preds in ds_preds.items():
val_ds = val_ds.add_column(key, preds.tolist()) # type: ignore

path = self.cfg.preds_out_dir / ds_name
path.mkdir(parents=True, exist_ok=True)
val_ds.save_to_disk(path.as_posix())
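The save path uses the Hugging Face datasets API: add_column returns a new Dataset (hence the reassignment), and save_to_disk writes it as an Arrow directory. A minimal sketch with a toy validation split (dataset contents and the output path are illustrative):

import numpy as np
from datasets import Dataset

# Toy stand-in for the validation split.
val_ds = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 0]})
preds = np.array([0.1, 0.9, 0.2])  # e.g. reporter outputs for one layer

# add_column needs a plain list, hence .tolist(); it returns a new Dataset.
val_ds = val_ds.add_column("reporter_5", preds.tolist())
val_ds.save_to_disk("preds/toy_ds")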
13 changes: 10 additions & 3 deletions elk/training/train.py
@@ -53,6 +53,7 @@ class Elicit(Serializable):
num_gpus: int = -1
out_dir: Path | None = None
supervised: Literal["none", "single", "cv"] = "single"
preds_out_dir: Path | None = None

def __post_init__(self):
if self.data.prompts.datasets == ["raw"]:
@@ -84,7 +85,7 @@ def train_reporter(
layer: int,
devices: list[str],
world_size: int = 1,
) -> pd.DataFrame:
) -> tuple[pd.DataFrame, dict]:
"""Train a single reporter on a single layer."""
self.make_reproducible(seed=self.cfg.net.seed + layer)
device = self.get_device(devices, world_size)
@@ -153,7 +154,10 @@ def train_reporter(
lr_model = None

row_buf = []
probs_buf = dict()
for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
with torch.no_grad():
ds_probs = {f"reporter_{layer}": reporter(val_h).cpu().numpy()}
val_result = reporter.score_contrast_set(val_gt, val_h)

if val_lm_preds is not None:
@@ -186,16 +190,19 @@
row["lr_auroc"], row["lr_acc"] = evaluate_supervised(
lr_model, val_h, val_gt
)
with torch.no_grad():
ds_probs[f"lr_{layer}"] = lr_model(val_h).cpu().numpy().squeeze(-1)

row_buf.append(row)
probs_buf[ds_name] = ds_probs

return pd.DataFrame(row_buf)
return pd.DataFrame(row_buf), probs_buf

def train(self):
"""Train a reporter on each layer of the network."""
devices = select_usable_devices(self.cfg.num_gpus)
num_devices = len(devices)
func: Callable[[int], pd.DataFrame] = partial(
func: Callable[[int], tuple[pd.DataFrame, dict]] = partial(
self.train_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)

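Once saved, a split written via preds_out_dir can be reloaded with load_from_disk, with one added column per reporter/LR per layer (the path and column name below are illustrative, continuing the toy example above):

from datasets import load_from_disk

val_ds = load_from_disk("preds/toy_ds")
print(val_ds.column_names)  # e.g. ['text', 'label', 'reporter_5']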