diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index d58f39857..fee480eeb 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -7,7 +7,7 @@ from simple_parsing.helpers import field from ..files import elk_reporter_dir -from ..metrics import evaluate_preds +from ..metrics import evaluate_preds, get_logprobs from ..run import Run from ..utils import Color @@ -30,7 +30,7 @@ def execute(self, highlight_color: Color = "cyan"): @torch.inference_mode() def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> dict[str, pd.DataFrame]: + ) -> tuple[dict[str, pd.DataFrame], dict]: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) val_output = self.prepare_data(device, layer, "val") @@ -43,19 +43,38 @@ def apply_to_layer( if not isinstance(lr_models, list): # backward compatibility lr_models = [lr_models] + out_logprobs = defaultdict(dict) row_bufs = defaultdict(list) - for ds_name, (val_h, val_gt) in val_output.items(): + for ds_name, val_data in val_output.items(): meta = {"dataset": ds_name, "layer": layer} + + if self.save_logprobs: + out_logprobs[ds_name] = dict( + row_ids=val_data.row_ids, + variant_ids=val_data.variant_ids, + texts=val_data.texts, + labels=val_data.labels, + lm=dict(), + lr=dict(), + ) for mode in ("none", "full"): + # TODO save lm logprobs and add to buf for i, model in enumerate(lr_models): model.eval() + val_credences = model(val_data.hiddens) + if self.save_logprobs: + out_logprobs[ds_name]["lr"][mode][i] = get_logprobs( + val_credences, mode + ).cpu() row_bufs["lr_eval"].append( { "ensembling": mode, "inlp_iter": i, **meta, - **evaluate_preds(val_gt, model(val_h), mode).to_dict(), + **evaluate_preds( + val_data.labels, val_credences, mode + ).to_dict(), } ) - return {k: pd.DataFrame(v) for k, v in row_bufs.items()} + return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs diff --git a/elk/extraction/extraction.py 
b/elk/extraction/extraction.py index 89669ba6a..faae5bc4a 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -358,6 +358,7 @@ def select_hiddens(outputs: Any) -> dict: if len(buffer[row_id]) == num_variants: # we have a complete example ex = buffer[row_id] + ex = sorted(ex, key=lambda d: d["variant_id"]) assert all(d["label"] == ex[0]["label"] for d in ex) assert len(set(d["variant_id"] for d in ex)) == num_variants out_record = dict( diff --git a/elk/metrics/__init__.py b/elk/metrics/__init__.py index 52aa3a907..25ed1b2a0 100644 --- a/elk/metrics/__init__.py +++ b/elk/metrics/__init__.py @@ -1,6 +1,6 @@ from .accuracy import accuracy_ci from .calibration import CalibrationError, CalibrationEstimate -from .eval import EvalResult, evaluate_preds +from .eval import EvalResult, evaluate_preds, get_logprobs from .roc_auc import RocAucResult, roc_auc, roc_auc_ci __all__ = [ @@ -9,6 +9,7 @@ "CalibrationEstimate", "EvalResult", "evaluate_preds", + "get_logprobs", "roc_auc", "roc_auc_ci", "RocAucResult", diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py index b68c4f0f8..8c837e8f9 100644 --- a/elk/metrics/eval.py +++ b/elk/metrics/eval.py @@ -2,6 +2,7 @@ from typing import Literal import torch +import torch.nn.functional as F from einops import repeat from torch import Tensor @@ -41,6 +42,22 @@ def to_dict(self, prefix: str = "") -> dict[str, float]: return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict} +def get_logprobs( + y_logits: Tensor, ensembling: Literal["none", "full"] = "none" +) -> Tensor: + """ + Get the log-probabilities of the positive class from a tensor of logits. + Args: + y_logits: Predicted log-odds of the positive class, tensor of shape (n, v). + Returns: + Tensor of logprobs: If ensembling is "none", a tensor of shape (n, v). + If ensembling is "full", a tensor of shape (n,). 
+ """ + if ensembling == "full": + y_logits = y_logits.mean(dim=1) + return F.logsigmoid(y_logits) + + def evaluate_preds( y_true: Tensor, y_logits: Tensor, diff --git a/elk/run.py b/elk/run.py index 04f901d0d..bdda78ba1 100644 --- a/elk/run.py +++ b/elk/run.py @@ -31,6 +31,16 @@ ) +@dataclass +class LayerData: + hiddens: Tensor + labels: Tensor + lm_preds: Tensor | None + texts: list[list[str]] # (n, v) + row_ids: list[int] # (n,) + variant_ids: list[list[str]] # (n, v) + + @dataclass class Run(ABC, Serializable): data: Extract @@ -46,6 +56,15 @@ class Run(ABC, Serializable): prompt_indices: tuple[int, ...] = () """The indices of the prompt templates to use. If empty, all prompts are used.""" + save_logprobs: bool = field(default=False, to_dict=False) + """ saves logprobs.pt containing + {ds_name: {"row_ids": [n,], "variant_ids": [n, v], + "labels": [n,], "texts": [n, v], + "lm": {"none": [n, v], "full": [n,]}, + "lr": {layer: {"none": {inlp_iter: [n, v]}, "full": {inlp_iter: [n,]}}} + }} + """ + concatenated_layer_offset: int = 0 debug: bool = False num_gpus: int = -1 @@ -96,7 +115,7 @@ def execute( devices = select_usable_devices(self.num_gpus) num_devices = len(devices) - func: Callable[[int], dict[str, pd.DataFrame]] = partial( + func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]] = partial( self.apply_to_layer, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) @@ -104,7 +123,7 @@ def execute( @abstractmethod def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> dict[str, pd.DataFrame]: + ) -> tuple[dict[str, pd.DataFrame], dict]: """Train or eval a reporter on a single layer.""" def make_reproducible(self, seed: int): @@ -123,7 +142,7 @@ def get_device(self, devices, world_size: int) -> str: def prepare_data( self, device: str, layer: int, split_type: Literal["train", "val"] - ) -> dict[str, tuple[Tensor, Tensor]]: + ) -> dict[str, LayerData]: """Prepare data for the specified layer and split type.""" out = {} 
@@ -137,7 +156,14 @@ def prepare_data( if self.prompt_indices: hiddens = hiddens[:, self.prompt_indices] - out[ds_name] = (hiddens, labels.to(hiddens.device)) + out[ds_name] = LayerData( + hiddens=hiddens, + labels=labels, + lm_preds=None, # TODO: implement + texts=split["texts"], + row_ids=split["row_id"], + variant_ids=split["variant_ids"], + ) return out @@ -150,7 +176,7 @@ def concatenate(self, layers): def apply_to_layers( self, - func: Callable[[int], dict[str, pd.DataFrame]], + func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]], num_devices: int, ): """Apply a function to each layer of the datasets in parallel @@ -173,11 +199,16 @@ def apply_to_layers( with ctx.Pool(num_devices) as pool: mapper = pool.imap_unordered if num_devices > 1 else map df_buffers = defaultdict(list) + logprobs_dicts = defaultdict(dict) try: - for df_dict in tqdm(mapper(func, layers), total=len(layers)): + for layer, (df_dict, logprobs_dict) in tqdm( + zip(layers, mapper(func, layers)), total=len(layers) + ): for k, v in df_dict.items(): df_buffers[k].append(v) + for k, v in logprobs_dict.items(): + logprobs_dicts[k][layer] = logprobs_dict[k] finally: # Make sure the CSVs are written even if we crash or get interrupted for name, dfs in df_buffers.items(): @@ -185,3 +216,21 @@ def apply_to_layers( df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False) if self.debug: save_debug_log(self.datasets, self.out_dir) + if self.save_logprobs: + save_dict = defaultdict(dict) + for ds_name, logprobs_dict in logprobs_dicts.items(): + save_dict[ds_name]["texts"] = logprobs_dict[layers[0]]["texts"] + save_dict[ds_name]["labels"] = logprobs_dict[layers[0]][ + "labels" + ] + save_dict[ds_name]["lm"] = logprobs_dict[layers[0]]["lm"] + save_dict[ds_name]["reporter"] = dict() + save_dict[ds_name]["lr"] = dict() + for layer, logprobs_dict_by_mode in logprobs_dict.items(): + save_dict[ds_name]["reporter"][ + layer + ] = logprobs_dict_by_mode["reporter"] + save_dict[ds_name]["lr"][layer] = 
logprobs_dict_by_mode[ + "lr" + ] + torch.save(save_dict, self.out_dir / "logprobs.pt") diff --git a/elk/training/classifier.py b/elk/training/classifier.py index 6370c28e8..0aab9dbd7 100644 --- a/elk/training/classifier.py +++ b/elk/training/classifier.py @@ -68,7 +68,7 @@ def fit( x: Tensor, y: Tensor, *, - l2_penalty: float = 0.0, + l2_penalty: float = 0.001, max_iter: int = 10_000, ) -> float: """Fits the model to the input data using L-BFGS with L2 regularization. diff --git a/elk/training/supervised.py b/elk/training/supervised.py index 7b0f4efe2..ce2ce9b5f 100644 --- a/elk/training/supervised.py +++ b/elk/training/supervised.py @@ -2,11 +2,12 @@ from concept_erasure import LeaceFitter from einops import rearrange, repeat +from ..run import LayerData from .classifier import Classifier def train_supervised( - data: dict[str, tuple], device: str, mode: str, erase_paraphrases: bool = False + data: dict[str, LayerData], device: str, mode: str, erase_paraphrases: bool = False ) -> list[Classifier]: assert not ( erase_paraphrases and len(data) > 1 @@ -15,9 +16,9 @@ def train_supervised( leace = None - for train_h, labels in data.values(): - (n, v, d) = train_h.shape - train_h = rearrange(train_h, "n v d -> (n v) d") + for train_data in data.values(): + (n, v, d) = train_data.hiddens.shape + train_h = rearrange(train_data.hiddens, "n v d -> (n v) d") if erase_paraphrases: if leace is None: @@ -33,7 +34,7 @@ def train_supervised( ) # (n * v, v) leace = leace.update(train_h, indicators) - labels = repeat(labels, "n -> (n v)", v=v) + labels = repeat(train_data.labels, "n -> (n v)", v=v) Xs.append(train_h) train_labels.append(labels) diff --git a/elk/training/sweep.py b/elk/training/sweep.py index e4aca5a00..36e19c0b0 100755 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -1,6 +1,5 @@ from dataclasses import InitVar, dataclass, replace -import numpy as np import torch from datasets import get_dataset_config_info from transformers import AutoConfig @@ -9,7 
+8,6 @@ from ..extraction import Extract from ..files import memorably_named_dir, sweeps_dir from ..plotting.visualize import visualize_sweep -from ..training.eigen_reporter import EigenFitterConfig from ..utils import colorize from ..utils.constants import BURNS_DATASETS from .train import Elicit @@ -38,11 +36,6 @@ class Sweep: add_pooled: InitVar[bool] = False """Whether to add a dataset that pools all of the other datasets together.""" - hparam_step: float = -1.0 - """The step size for hyperparameter sweeps. Performs a 2D - sweep over a and b in (var_weight, inv_weight, neg_cov_weight) = (a, 1 - b, b) - If negative, no hyperparameter sweeps will be performed. Only valid for Eigen.""" - skip_transfer_eval: bool = False """Whether to perform transfer eval on every pair of datasets.""" @@ -64,13 +57,6 @@ def __post_init__(self, add_pooled: bool): raise ValueError("No datasets specified") if not self.models: raise ValueError("No models specified") - # can only use hparam_step if we're using an eigen net - if self.hparam_step > 0 and not isinstance( - self.run_template.net, EigenFitterConfig - ): - raise ValueError("Can only use hparam_step with EigenFitterConfig") - elif self.hparam_step > 1: - raise ValueError("hparam_step must be in [0, 1]") # Check for the magic dataset "burns" which is a shortcut for all of the # datasets used in Burns et al., except Story Cloze, which is not available @@ -115,9 +101,6 @@ def execute(self): } ) - step = self.hparam_step - weights = np.arange(0.0, 1.0 + step, step) if step > 0 else [None] - for i, model in enumerate(self.models): print(colorize(f"===== {model} ({i + 1} of {M}) =====", "magenta")) @@ -127,52 +110,42 @@ def execute(self): # single sweep. 
train_datasets = tuple(ds.strip() for ds in dataset_str.split("+")) - for var_weight in weights: - for neg_cov_weight in weights: - out_dir = sweep_dir / model / dataset_str + out_dir = sweep_dir / model / dataset_str - data = replace( - self.run_template.data, model=model, datasets=train_datasets - ) - run = replace(self.run_template, data=data, out_dir=out_dir) - if var_weight is not None and neg_cov_weight is not None: - assert isinstance(run.net, EigenFitterConfig) - run.net.var_weight = var_weight - run.net.neg_cov_weight = neg_cov_weight - - # Add hyperparameter values to output directory if needed - assert run.out_dir is not None - run.out_dir /= f"var_weight={var_weight:.2f}" - run.out_dir /= f"neg_cov_weight={neg_cov_weight:.2f}" - - try: - run.execute() - except torch.linalg.LinAlgError as e: - print(colorize(f"LinAlgError: {e}", "red")) + data = replace( + self.run_template.data, model=model, datasets=train_datasets + ) + run = replace(self.run_template, data=data, out_dir=out_dir) + + # Add hyperparameter values to output directory if needed + assert run.out_dir is not None + + try: + run.execute() + except torch.linalg.LinAlgError as e: + print(colorize(f"LinAlgError: {e}", "red")) + continue + + if not self.skip_transfer_eval: + if len(eval_datasets) > 1: + print(colorize("== Transfer eval ==", "green")) + + # Now evaluate the reporter on the other datasets + for eval_dataset in eval_datasets: + # We already evaluated on this one during training + if eval_dataset in train_datasets: continue - if not self.skip_transfer_eval: - if len(eval_datasets) > 1: - print(colorize("== Transfer eval ==", "green")) - - # Now evaluate the reporter on the other datasets - for eval_dataset in eval_datasets: - # We already evaluated on this one during training - if eval_dataset in train_datasets: - continue - - assert run.out_dir is not None - eval = Eval( - data=replace( - run.data, model=model, datasets=(eval_dataset,) - ), - source=run.out_dir, - out_dir=run.out_dir 
/ "transfer" / eval_dataset, - num_gpus=run.num_gpus, - min_gpu_mem=run.min_gpu_mem, - skip_supervised=run.supervised == "none", - ) - eval.execute(highlight_color="green") + assert run.out_dir is not None + eval = Eval( + data=replace( + run.data, model=model, datasets=(eval_dataset,) + ), + source=run.out_dir, + out_dir=run.out_dir / "transfer" / eval_dataset, + num_gpus=run.num_gpus, + ) + eval.execute(highlight_color="green") if self.visualize: visualize_sweep(sweep_dir) diff --git a/elk/training/train.py b/elk/training/train.py index 141f4afc3..7e7aa0dfe 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -8,7 +8,7 @@ import pandas as pd import torch -from ..metrics import evaluate_preds +from ..metrics import evaluate_preds, get_logprobs from ..run import Run from ..training.supervised import train_supervised from ..utils.typing import assert_type @@ -41,7 +41,7 @@ def apply_to_layer( layer: int, devices: list[str], world_size: int, - ) -> dict[str, pd.DataFrame]: + ) -> tuple[dict[str, pd.DataFrame], dict]: """Train a single reporter on a single layer.""" self.make_reproducible(seed=self.seed + layer) @@ -50,9 +50,9 @@ def apply_to_layer( train_dict = self.prepare_data(device, layer, "train") val_dict = self.prepare_data(device, layer, "val") - (first_train_h, train_gt), *rest = train_dict.values() - (_, v, d) = first_train_h.shape - if not all(other_h.shape[-1] == d for other_h, _ in rest): + first_train_data, *rest = train_dict.values() + (_, v, d) = first_train_data.hiddens.shape + if not all(other_data.hiddens.shape[-1] == d for other_data in rest): raise ValueError("All datasets must have the same hidden state size") lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) @@ -68,20 +68,39 @@ def apply_to_layer( with open(lr_dir / f"layer_{layer}.pt", "wb") as file: torch.save(lr_models, file) + out_logprobs = defaultdict(dict) row_bufs = defaultdict(list) for ds_name in val_dict: - val_h, val_gt = val_dict[ds_name] - train_h, 
train_gt = train_dict[ds_name] + val, train = val_dict[ds_name], train_dict[ds_name] meta = {"dataset": ds_name, "layer": layer} + if self.save_logprobs: + out_logprobs[ds_name] = dict( + row_ids=val.row_ids, + variant_ids=val.variant_ids, + texts=val.texts, + labels=val.labels, + lm=dict(), + lr=dict(), + ) + for mode in ("none", "full"): for i, model in enumerate(lr_models): + model.eval() + val_credences = model(val.hiddens) + train_credences = model(train.hiddens) + + if self.save_logprobs: + out_logprobs[ds_name]["lr"][mode][i] = ( + get_logprobs(val_credences, mode).detach().cpu() + ) + row_bufs["lr_eval"].append( { **meta, "ensembling": mode, "inlp_iter": i, - **evaluate_preds(val_gt, model(val_h), mode).to_dict(), + **evaluate_preds(val.labels, val_credences, mode).to_dict(), } ) @@ -90,8 +109,10 @@ def apply_to_layer( **meta, "ensembling": mode, "inlp_iter": i, - **evaluate_preds(train_gt, model(train_h), mode).to_dict(), + **evaluate_preds( + train.labels, train_credences, mode + ).to_dict(), } ) - return {k: pd.DataFrame(v) for k, v in row_bufs.items()} + return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs