diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 38fff767..42e9a5f4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
     hooks:
       - id: codespell
         # The promptsource templates spuriously get flagged without this
-        args: ["--skip=*.yaml"]
+        args: ["-L fpr", "--skip=*.yaml"]
diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 98ce2baa..40c07ff6 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -120,12 +120,13 @@ def evaluate_reporter(
             with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
                 lr_model = torch.load(f, map_location=device).eval()
 
-            lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
-
+            lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
             with torch.no_grad():
                 ds_preds[f"lr_{layer}"] = lr_model(val_h).cpu().numpy().squeeze(-1)
 
-            stats_row["lr_auroc"] = lr_auroc
+            stats_row["lr_auroc"] = lr_auroc_res.estimate
+            stats_row["lr_auroc_lower"] = lr_auroc_res.lower
+            stats_row["lr_auroc_upper"] = lr_auroc_res.upper
             stats_row["lr_acc"] = lr_acc
 
         row_buf.append(stats_row)
diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py
index 8b01849a..b8004f5a 100644
--- a/elk/extraction/prompt_loading.py
+++ b/elk/extraction/prompt_loading.py
@@ -130,6 +130,7 @@ def explode(self) -> list["PromptConfig"]:
 def load_prompts(
     ds_string: str,
     label_column: Optional[str] = None,
+    num_classes: int = 0,
     num_shots: int = 0,
     num_variants: int = -1,
     seed: int = 42,
@@ -184,9 +185,8 @@ def load_prompts(
     if rank == 0:
         print(f"Using {num_variants} variants of each prompt")
 
-    feats = assert_type(Features, ds.features)
-    label_column = infer_label_column(feats)
-    num_classes = infer_num_classes(feats[label_column])
+    label_column = label_column or infer_label_column(ds.features)
+    num_classes = num_classes or infer_num_classes(ds.features[label_column])
 
     rng = Random(seed)
     if num_shots > 0:
@@ -203,9 +203,6 @@ def load_prompts(
     extra_cols = list(assert_type(Features, ds.features))
     extra_cols.remove(label_column)
 
-    if label_column != "label":
-        ds = ds.rename_column(label_column, "label")
-
     for example in ds:
         yield _convert_to_prompts(
             example,
diff --git a/elk/metrics.py b/elk/metrics.py
index 46a6113a..0150f02f 100644
--- a/elk/metrics.py
+++ b/elk/metrics.py
@@ -1,3 +1,6 @@
+from typing import NamedTuple
+
+import torch
 from torch import Tensor
 
 
@@ -34,3 +37,122 @@ def accuracy(y_true: Tensor, y_pred: Tensor) -> float:
 
     hard_preds = y_pred.argmax(-1)
     return hard_preds.cpu().eq(y_true.cpu()).float().mean().item()
+
+
+class RocAucResult(NamedTuple):
+    """Named tuple for storing ROC AUC results."""
+
+    estimate: float
+    """Point estimate of the ROC AUC computed on this sample."""
+    lower: float
+    """Lower bound of the bootstrap confidence interval."""
+    upper: float
+    """Upper bound of the bootstrap confidence interval."""
+
+
+def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor:
+    """Area under the receiver operating characteristic curve (ROC AUC).
+
+    Unlike scikit-learn's implementation, this function supports batched inputs of
+    shape `(N, n)` where `N` is the number of datasets and `n` is the number of
+    samples within each dataset. This is primarily useful for efficiently computing
+    bootstrap confidence intervals.
+
+    Args:
+        y_true: Ground truth tensor of shape `(N,)` or `(N, n)`.
+        y_pred: Predicted class tensor of shape `(N,)` or `(N, n)`.
+
+    Returns:
+        Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D,
+            a tensor of shape (N,) containing the ROC AUC for each dataset.
+    """
+    if y_true.shape != y_pred.shape:
+        raise ValueError(
+            f"y_true and y_pred should have the same shape; "
+            f"got {y_true.shape} and {y_pred.shape}"
+        )
+    if y_true.dim() not in (1, 2):
+        raise ValueError("y_true and y_pred should be 1D or 2D tensors")
+
+    # Sort y_pred in descending order and get indices
+    indices = y_pred.argsort(descending=True, dim=-1)
+
+    # Reorder y_true based on sorted y_pred indices
+    y_true_sorted = y_true.gather(-1, indices)
+
+    # Calculate number of positive and negative samples
+    num_positives = y_true.sum(dim=-1)
+    num_negatives = y_true.shape[-1] - num_positives
+
+    # Calculate cumulative sum of true positive counts (TPs)
+    tps = torch.cumsum(y_true_sorted, dim=-1)
+
+    # Calculate cumulative sum of false positive counts (FPs)
+    fps = torch.cumsum(1 - y_true_sorted, dim=-1)
+
+    # Calculate true positive rate (TPR) and false positive rate (FPR)
+    tpr = tps / num_positives.view(-1, 1)
+    fpr = fps / num_negatives.view(-1, 1)
+
+    # Calculate differences between consecutive FPR values (widths of trapezoids)
+    fpr_diffs = torch.cat(
+        [fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1
+    )
+
+    # Calculate area under the ROC curve for each dataset using trapezoidal rule
+    return torch.sum(tpr * fpr_diffs, dim=-1).squeeze()
+
+
+def roc_auc_ci(
+    y_true: Tensor,
+    y_pred: Tensor,
+    *,
+    num_samples: int = 1000,
+    level: float = 0.95,
+    seed: int = 42,
+) -> RocAucResult:
+    """Bootstrap confidence interval for the ROC AUC.
+
+    Args:
+        y_true: Ground truth tensor of shape `(N,)`.
+        y_pred: Predicted class tensor of shape `(N,)`.
+        num_samples (int): Number of bootstrap samples to use.
+        level (float): Confidence level of the confidence interval.
+        seed (int): Random seed for reproducibility.
+
+    Returns:
+        RocAucResult: Named tuple containing the lower and upper bounds of the
+            confidence interval, along with the point estimate.
+    """
+    if y_true.shape != y_pred.shape:
+        raise ValueError(
+            f"y_true and y_pred should have the same shape; "
+            f"got {y_true.shape} and {y_pred.shape}"
+        )
+    if y_true.dim() != 1:
+        raise ValueError("y_true and y_pred should be 1D tensors")
+
+    device = y_true.device
+    N = y_true.shape[0]
+
+    # Generate random indices for bootstrap samples (shape: [num_bootstraps, N])
+    rng = torch.Generator(device=device).manual_seed(seed)
+    indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng)
+
+    # Create bootstrap samples of true labels and predicted probabilities
+    y_true_bootstraps = y_true[indices]
+    y_pred_bootstraps = y_pred[indices]
+
+    # Compute ROC AUC scores for bootstrap samples
+    bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps)
+
+    # Calculate the lower and upper bounds of the confidence interval. We use
+    # nanquantile instead of quantile because some bootstrap samples may have
+    # NaN values due to the fact that they have only one class.
+    alpha = (1 - level) / 2
+    q = y_pred.new_tensor([alpha, 1 - alpha])
+    lower, upper = bootstrap_aucs.nanquantile(q).tolist()
+
+    # Compute the point estimate
+    estimate = roc_auc(y_true, y_pred).item()
+    return RocAucResult(estimate, lower, upper)
diff --git a/elk/run.py b/elk/run.py
index 8b015ff7..dacd4e23 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -87,7 +87,7 @@ def get_device(self, devices, world_size: int) -> str:
 
     def prepare_data(
         self, device: str, layer: int, split_type: Literal["train", "val"]
-    ) -> dict[str, tuple[Tensor, Tensor, np.ndarray | None]]:
+    ) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
         """Prepare data for the specified layer and split type."""
         assert self.cfg.data.prompts.datasets != [
             "raw"
@@ -112,7 +112,7 @@ def prepare_individual(
         labels = assert_type(Tensor, ds["label"])
         hid = int16_to_float32(assert_type(Tensor, ds[f"hidden_{layer}"]))
 
-        with ds.formatted_as("numpy"):
+        with ds.formatted_as("torch", device=device):
             has_preds = "model_preds" in ds.features
             lm_preds = ds["model_preds"] if has_preds else None
 
diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index c4436aec..b839e3fd 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -7,10 +7,10 @@
 import torch
 import torch.nn as nn
-from sklearn.metrics import roc_auc_score
 from torch import Tensor
 from torch.nn.functional import binary_cross_entropy as bce
 
+from ..metrics import roc_auc
 from ..parsing import parse_loss
 from ..utils.typing import assert_type
 from .classifier import Classifier
 
@@ -175,8 +175,8 @@ def check_separability(
         pseudo_preds = pseudo_clf(
             # b v d -> (b v) d
             torch.cat([val_x0, val_x1]).flatten(0, 1)
-        )
-        return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu()))
+        ).squeeze(-1)
+        return roc_auc(pseudo_val_labels, pseudo_preds).item()
 
     def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
         loss = sum(
diff --git a/elk/training/reporter.py b/elk/training/reporter.py
index ad349ae5..60bbcd12 100644
--- a/elk/training/reporter.py
+++ b/elk/training/reporter.py
@@ -9,11 +9,10 @@
 import torch.nn as nn
 from einops import rearrange, repeat
 from simple_parsing.helpers import Serializable
-from sklearn.metrics import roc_auc_score
 from torch import Tensor
 
 from ..calibration import CalibrationError
-from ..metrics import accuracy, to_one_hot
+from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot
 
 
 class EvalResult(NamedTuple):
@@ -24,9 +23,12 @@ class EvalResult(NamedTuple):
         accuracy, calibrated accuracy, and AUROC.
     """
 
+    auroc: float
+    auroc_lower: float
+    auroc_upper: float
+
     acc: float
     cal_acc: float
-    auroc: float
     ece: float
 
 
@@ -101,6 +103,13 @@ def score_contrast_set(self, labels: Tensor, contrast_set: Tensor) -> EvalResult
                 where x is the proprtion of examples with ground truth label `True`.
             AUROC: averaged over the n * v * c binary questions
             ECE: Expected Calibration Error
+            accuracy, and AUROC of the probe on `contrast_set`.
+            Accuracy: top-1 accuracy averaged over questions and variants.
+            Calibrated accuracy: top-1 accuracy averaged over questions and
+                variants, calibrated so that x% of the predictions are `True`,
+                where x is the proprtion of examples with ground truth label `True`.
+            AUROC: averaged over the n * v * c binary questions
+            ECE: Expected Calibration Error
         """
         logits = self(contrast_set)
         (_, v, c) = logits.shape
@@ -123,18 +132,16 @@ def score_contrast_set(self, labels: Tensor, contrast_set: Tensor) -> EvalResult
             cal_err = 0.0
 
         Y_one_hot = to_one_hot(Y, c).long().flatten()
-        if len(labels.unique()) == 1:
-            auroc = -1.0
-        else:
-            auroc = roc_auc_score(Y_one_hot.cpu(), logits.cpu().flatten())
+        auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten())
 
         raw_preds = logits.argmax(dim=-1).long()
         raw_acc = accuracy(Y, raw_preds.flatten())
-
         return EvalResult(
+            auroc=auroc_result.estimate,
+            auroc_lower=auroc_result.lower,
+            auroc_upper=auroc_result.upper,
             acc=float(raw_acc),
             cal_acc=cal_acc,
-            auroc=float(auroc),
             ece=cal_err,
         )
 
@@ -165,14 +172,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
         preds = probs.gt(0.5).to(torch.int)
         acc = preds.flatten().eq(labels).float().mean().item()
 
-        if len(labels.unique()) == 1:
-            auroc = -1.0
-        else:
-            auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten())
+        auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels.cpu(), logits.cpu().flatten())
 
         return EvalResult(
+            auroc=auroc_result.estimate,
+            auroc_lower=auroc_result.lower,
+            auroc_upper=auroc_result.upper,
             acc=float(acc),
             cal_acc=cal_acc,
-            auroc=float(auroc),
             ece=cal_err,
         )
diff --git a/elk/training/supervised.py b/elk/training/supervised.py
index 62b7375d..ded52fb6 100644
--- a/elk/training/supervised.py
+++ b/elk/training/supervised.py
@@ -1,22 +1,21 @@
 import torch
 from einops import rearrange, repeat
-from sklearn.metrics import roc_auc_score
 from torch import Tensor
 
-from ..metrics import accuracy, to_one_hot
+from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot
 from ..utils import assert_type
 from .classifier import Classifier
 
 
 def evaluate_supervised(
     lr_model: Classifier, val_h: Tensor, val_labels: Tensor
-) -> tuple[float, float]:
+) -> tuple[RocAucResult, float]:
     if len(val_h.shape) == 4:
         # hiddens in a contrast set
-        (n, v, c, d) = val_h.shape
+        (_, v, c, _) = val_h.shape
         with torch.no_grad():
-            logits = rearrange(lr_model(val_h).cpu().squeeze(), "n v c -> (n v) c")
+            logits = rearrange(lr_model(val_h).squeeze(), "n v c -> (n v) c")
             raw_preds = to_one_hot(logits.argmax(dim=-1), c).long()
 
         labels = repeat(val_labels, "n -> (n v)", v=v)
 
@@ -35,12 +34,9 @@
         raise ValueError(f"Invalid val_h shape: {val_h.shape}")
 
     lr_acc = accuracy(labels, raw_preds.flatten())
-    if len(labels.unique()) == 1:
-        lr_auroc = -1.0
-    else:
-        lr_auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten())
+    lr_auroc = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels, logits.flatten())
 
-    return assert_type(float, lr_auroc), assert_type(float, lr_acc)
+    return lr_auroc, assert_type(float, lr_acc)
 
 
 def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier:
diff --git a/elk/training/train.py b/elk/training/train.py
index 56543c75..f874718f 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -9,10 +9,9 @@
 import torch
 from einops import rearrange, repeat
 from simple_parsing import Serializable, field, subgroups
-from sklearn.metrics import roc_auc_score
 
 from ..extraction.extraction import Extract
-from ..metrics import accuracy, to_one_hot
+from ..metrics import accuracy, roc_auc_ci, to_one_hot
 from ..run import Run
 from ..training.supervised import evaluate_supervised, train_supervised
 from ..utils import select_usable_devices
@@ -160,20 +159,6 @@ def train_reporter(
         ds_probs = {f"reporter_{layer}": reporter(val_h).cpu().numpy()}
         val_result = reporter.score_contrast_set(val_gt, val_h)
 
-        if val_lm_preds is not None:
-            (_, v, k, _) = val_h.shape
-
-            val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu()
-            val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
-            val_lm_auroc = roc_auc_score(
-                to_one_hot(val_gt_cpu, k).long().flatten(), val_lm_preds.flatten()
-            )
-
-            val_lm_acc = accuracy(val_gt_cpu, torch.from_numpy(val_lm_preds))
-        else:
-            val_lm_auroc = None
-            val_lm_acc = None
-
         row = pd.Series(
             {
                 "dataset": ds_name,
@@ -181,18 +166,33 @@ def train_reporter(
                 "layer": layer,
                 "pseudo_auroc": pseudo_auroc,
                 "train_loss": train_loss,
                 **val_result._asdict(),
-                "lm_auroc": val_lm_auroc,
-                "lm_acc": val_lm_acc,
             }
         )
 
+        if val_lm_preds is not None:
+            (_, v, k, _) = val_h.shape
+
+            val_gt_rep = repeat(val_gt, "n -> (n v)", v=v)
+            val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
+            val_lm_auroc_res = roc_auc_ci(
+                to_one_hot(val_gt_rep, k).long().flatten(), val_lm_preds.flatten()
+            )
+            row["lm_auroc"] = val_lm_auroc_res.estimate
+            row["lm_auroc_lower"] = val_lm_auroc_res.lower
+            row["lm_auroc_upper"] = val_lm_auroc_res.upper
+            row["lm_acc"] = accuracy(val_gt_rep, val_lm_preds)
+
         if lr_model is not None:
-            row["lr_auroc"], row["lr_acc"] = evaluate_supervised(
+            lr_auroc_res, row["lr_acc"] = evaluate_supervised(
                 lr_model, val_h, val_gt
             )
             with torch.no_grad():
                 ds_probs[f"lr_{layer}"] = lr_model(val_h).cpu().numpy().squeeze(-1)
 
+            row["lr_auroc"] = lr_auroc_res.estimate
+            row["lr_auroc_lower"] = lr_auroc_res.lower
+            row["lr_auroc_upper"] = lr_auroc_res.upper
+
         row_buf.append(row)
         probs_buf[ds_name] = ds_probs
diff --git a/pyproject.toml b/pyproject.toml
index f688416d..6575e57a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,8 +21,6 @@ dependencies = [
     "pandas",
     # Basically any version should work as long as it supports the user's CUDA version
     "pynvml",
-    # Doesn't really matter but before 1.0.0 there might be weird breaking changes
-    "scikit-learn>=1.0.0",
     # Needed for certain HF tokenizers
     "sentencepiece==0.1.97",
     # We upstreamed bugfixes for Literal types in 0.1.1
@@ -43,7 +41,8 @@ dev = [
     "hypothesis",
     "pre-commit",
     "pytest",
-    "pyright"
+    "pyright",
+    "scikit-learn",
 ]
 
 [project.scripts]
diff --git a/tests/test_roc_auc.py b/tests/test_roc_auc.py
new file mode 100644
index 00000000..244bdb88
--- /dev/null
+++ b/tests/test_roc_auc.py
@@ -0,0 +1,53 @@
+import numpy as np
+import torch
+from sklearn.datasets import make_classification
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import roc_auc_score
+
+from elk.metrics import roc_auc
+
+
+def test_roc_auc_score():
+    # Generate 1D binary classification dataset
+    X_1d, y_true_1d = make_classification(n_samples=1000, random_state=42)
+
+    # Generate 2D matrix of binary classification datasets
+    X_2d_1, y_true_2d_1 = make_classification(n_samples=1000, random_state=43)
+    X_2d_2, y_true_2d_2 = make_classification(n_samples=1000, random_state=44)
+
+    # Fit LR models and get predicted probabilities for 1D and 2D cases
+    lr_1d = LogisticRegression(random_state=42).fit(X_1d, y_true_1d)
+    y_scores_1d = lr_1d.predict_proba(X_1d)[:, 1]
+
+    lr_2d_1 = LogisticRegression(random_state=42).fit(X_2d_1, y_true_2d_1)
+    y_scores_2d_1 = lr_2d_1.predict_proba(X_2d_1)[:, 1]
+
+    lr_2d_2 = LogisticRegression(random_state=42).fit(X_2d_2, y_true_2d_2)
+    y_scores_2d_2 = lr_2d_2.predict_proba(X_2d_2)[:, 1]
+
+    # Stack the datasets into 2D matrices
+    y_true_2d = np.vstack((y_true_2d_1, y_true_2d_2))
+    y_scores_2d = np.vstack((y_scores_2d_1, y_scores_2d_2))
+
+    # Convert to PyTorch tensors
+    y_true_1d_torch = torch.tensor(y_true_1d)
+    y_scores_1d_torch = torch.tensor(y_scores_1d)
+    y_true_2d_torch = torch.tensor(y_true_2d)
+    y_scores_2d_torch = torch.tensor(y_scores_2d)
+
+    # Calculate ROC AUC score using batch_roc_auc_score function for 1D and 2D cases
+    roc_auc_1d_torch = roc_auc(y_true_1d_torch, y_scores_1d_torch).item()
+    roc_auc_2d_torch = roc_auc(y_true_2d_torch, y_scores_2d_torch).numpy()
+
+    # Calculate ROC AUC score with sklearn's roc_auc_score function for 1D and 2D cases
+    roc_auc_1d_sklearn = roc_auc_score(y_true_1d, y_scores_1d)
+    roc_auc_2d_sklearn = np.array(
+        [
+            roc_auc_score(y_true_2d_1, y_scores_2d_1),
+            roc_auc_score(y_true_2d_2, y_scores_2d_2),
+        ]
+    )
+
+    # Assert that the results from the two implementations are almost equal
+    np.testing.assert_almost_equal(roc_auc_1d_torch, roc_auc_1d_sklearn)
+    np.testing.assert_almost_equal(roc_auc_2d_torch, roc_auc_2d_sklearn)
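
Below is a brief usage sketch (not part of the patch) of the `roc_auc` and `roc_auc_ci` helpers added to `elk/metrics.py` above. The function names, keyword arguments, and the `RocAucResult` fields (`estimate`, `lower`, `upper`) are taken from the diff; the labels and scores are synthetic and purely illustrative.

    import torch

    from elk.metrics import roc_auc, roc_auc_ci

    # Synthetic binary labels and correlated scores (illustrative only)
    torch.manual_seed(0)
    y_true = torch.randint(0, 2, (500,))
    y_pred = y_true.float() + 0.5 * torch.randn(500)

    # Point estimate of the AUROC on the full sample
    print(roc_auc(y_true, y_pred).item())

    # Percentile bootstrap CI: resample (y_true, y_pred) with replacement 1000 times,
    # score all resamples in one batched roc_auc call, then take the 2.5%/97.5% quantiles
    result = roc_auc_ci(y_true, y_pred, num_samples=1000, level=0.95)
    print(f"AUROC = {result.estimate:.3f} [{result.lower:.3f}, {result.upper:.3f}]")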