Merge branch 'main' into raw-extraction

AlexTMallen committed Apr 18, 2023
2 parents e986195 + 8bb3802 commit ad4cf34
Showing 11 changed files with 235 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
args: ["--skip=*.yaml"]
args: ["-L fpr", "--skip=*.yaml"]
7 changes: 4 additions & 3 deletions elk/evaluation/evaluate.py
@@ -120,12 +120,13 @@ def evaluate_reporter(
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)

lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
with torch.no_grad():
ds_preds[f"lr_{layer}"] = lr_model(val_h).cpu().numpy().squeeze(-1)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_auroc"] = lr_auroc_res.estimate
stats_row["lr_auroc_lower"] = lr_auroc_res.lower
stats_row["lr_auroc_upper"] = lr_auroc_res.upper
stats_row["lr_acc"] = lr_acc

row_buf.append(stats_row)
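For orientation, here is a minimal sketch of the per-layer stats row these new columns produce; it is not part of the commit, the numbers are invented, and only the key names come from the diff above:

# Illustrative only: invented values, key names taken from the diff above
stats_row = {
    "lr_auroc": 0.91,        # point estimate (lr_auroc_res.estimate)
    "lr_auroc_lower": 0.88,  # lower bound of the bootstrap CI (lr_auroc_res.lower)
    "lr_auroc_upper": 0.94,  # upper bound of the bootstrap CI (lr_auroc_res.upper)
    "lr_acc": 0.85,
}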
9 changes: 3 additions & 6 deletions elk/extraction/prompt_loading.py
@@ -130,6 +130,7 @@ def explode(self) -> list["PromptConfig"]:
def load_prompts(
ds_string: str,
label_column: Optional[str] = None,
num_classes: int = 0,
num_shots: int = 0,
num_variants: int = -1,
seed: int = 42,
@@ -184,9 +185,8 @@ def load_prompts(
if rank == 0:
print(f"Using {num_variants} variants of each prompt")

feats = assert_type(Features, ds.features)
label_column = infer_label_column(feats)
num_classes = infer_num_classes(feats[label_column])
label_column = label_column or infer_label_column(ds.features)
num_classes = num_classes or infer_num_classes(ds.features[label_column])
rng = Random(seed)

if num_shots > 0:
@@ -203,9 +203,6 @@
extra_cols = list(assert_type(Features, ds.features))
extra_cols.remove(label_column)

if label_column != "label":
ds = ds.rename_column(label_column, "label")

for example in ds:
yield _convert_to_prompts(
example,
122 changes: 122 additions & 0 deletions elk/metrics.py
@@ -1,3 +1,6 @@
from typing import NamedTuple

import torch
from torch import Tensor


@@ -34,3 +37,122 @@ def accuracy(y_true: Tensor, y_pred: Tensor) -> float:
hard_preds = y_pred.argmax(-1)

return hard_preds.cpu().eq(y_true.cpu()).float().mean().item()


class RocAucResult(NamedTuple):
"""Named tuple for storing ROC AUC results."""

estimate: float
"""Point estimate of the ROC AUC computed on this sample."""
lower: float
"""Lower bound of the bootstrap confidence interval."""
upper: float
"""Upper bound of the bootstrap confidence interval."""


def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor:
"""Area under the receiver operating characteristic curve (ROC AUC).
Unlike scikit-learn's implementation, this function supports batched inputs of
shape `(N, n)` where `N` is the number of datasets and `n` is the number of samples
within each dataset. This is primarily useful for efficiently computing bootstrap
confidence intervals.
Args:
y_true: Ground truth tensor of shape `(N,)` or `(N, n)`.
y_pred: Predicted class tensor of shape `(N,)` or `(N, n)`.
Returns:
Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D,
a tensor of shape (N,) containing the ROC AUC for each dataset.
"""
if y_true.shape != y_pred.shape:
raise ValueError(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() not in (1, 2):
raise ValueError("y_true and y_pred should be 1D or 2D tensors")

# Sort y_pred in descending order and get indices
indices = y_pred.argsort(descending=True, dim=-1)

# Reorder y_true based on sorted y_pred indices
y_true_sorted = y_true.gather(-1, indices)

# Calculate number of positive and negative samples
num_positives = y_true.sum(dim=-1)
num_negatives = y_true.shape[-1] - num_positives

# Calculate cumulative sum of true positive counts (TPs)
tps = torch.cumsum(y_true_sorted, dim=-1)

# Calculate cumulative sum of false positive counts (FPs)
fps = torch.cumsum(1 - y_true_sorted, dim=-1)

# Calculate true positive rate (TPR) and false positive rate (FPR)
tpr = tps / num_positives.view(-1, 1)
fpr = fps / num_negatives.view(-1, 1)

# Calculate differences between consecutive FPR values (widths of trapezoids)
fpr_diffs = torch.cat(
[fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1
)

# Calculate area under the ROC curve for each dataset using trapezoidal rule
return torch.sum(tpr * fpr_diffs, dim=-1).squeeze()


def roc_auc_ci(
y_true: Tensor,
y_pred: Tensor,
*,
num_samples: int = 1000,
level: float = 0.95,
seed: int = 42,
) -> RocAucResult:
"""Bootstrap confidence interval for the ROC AUC.
Args:
y_true: Ground truth tensor of shape `(N,)`.
y_pred: Predicted class tensor of shape `(N,)`.
num_samples (int): Number of bootstrap samples to use.
level (float): Confidence level of the confidence interval.
seed (int): Random seed for reproducibility.
Returns:
RocAucResult: Named tuple containing the lower and upper bounds of the
confidence interval, along with the point estimate.
"""
if y_true.shape != y_pred.shape:
raise ValueError(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() != 1:
raise ValueError("y_true and y_pred should be 1D tensors")

device = y_true.device
N = y_true.shape[0]

# Generate random indices for bootstrap samples (shape: [num_bootstraps, N])
rng = torch.Generator(device=device).manual_seed(seed)
indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng)

# Create bootstrap samples of true labels and predicted probabilities
y_true_bootstraps = y_true[indices]
y_pred_bootstraps = y_pred[indices]

# Compute ROC AUC scores for bootstrap samples
bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps)

# Calculate the lower and upper bounds of the confidence interval. We use
# nanquantile instead of quantile because some bootstrap samples may have
# NaN values due to the fact that they have only one class.
alpha = (1 - level) / 2
q = y_pred.new_tensor([alpha, 1 - alpha])
lower, upper = bootstrap_aucs.nanquantile(q).tolist()

# Compute the point estimate
estimate = roc_auc(y_true, y_pred).item()
return RocAucResult(estimate, lower, upper)
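To make the intended call pattern concrete, here is a minimal usage sketch; it is not part of the commit, it assumes the module is importable as elk.metrics, and the labels and scores are invented:

import torch

from elk.metrics import roc_auc, roc_auc_ci

# Invented binary labels and scores for eight examples
y_true = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0])
y_pred = torch.tensor([0.1, 0.8, 0.7, 0.3, 0.9, 0.4, 0.6, 0.2])

point = roc_auc(y_true, y_pred)  # scalar tensor with the AUROC
ci = roc_auc_ci(y_true, y_pred)  # RocAucResult(estimate, lower, upper); 95% level by default
print(f"AUROC {ci.estimate:.3f}, bootstrap CI [{ci.lower:.3f}, {ci.upper:.3f}], point {point.item():.3f}")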
4 changes: 2 additions & 2 deletions elk/run.py
@@ -87,7 +87,7 @@ def get_device(self, devices, world_size: int) -> str:

def prepare_data(
self, device: str, layer: int, split_type: Literal["train", "val"]
) -> dict[str, tuple[Tensor, Tensor, np.ndarray | None]]:
) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
"""Prepare data for the specified layer and split type."""
assert self.cfg.data.prompts.datasets != [
"raw"
@@ -112,7 +112,7 @@ def prepare_individual(
labels = assert_type(Tensor, ds["label"])
hid = int16_to_float32(assert_type(Tensor, ds[f"hidden_{layer}"]))

with ds.formatted_as("numpy"):
with ds.formatted_as("torch", device=device):
has_preds = "model_preds" in ds.features
lm_preds = ds["model_preds"] if has_preds else None

6 changes: 3 additions & 3 deletions elk/training/ccs_reporter.py
@@ -7,10 +7,10 @@

import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch import Tensor
from torch.nn.functional import binary_cross_entropy as bce

from ..metrics import roc_auc
from ..parsing import parse_loss
from ..utils.typing import assert_type
from .classifier import Classifier
@@ -175,8 +175,8 @@ def check_separability(
pseudo_preds = pseudo_clf(
# b v d -> (b v) d
torch.cat([val_x0, val_x1]).flatten(0, 1)
)
return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu()))
).squeeze(-1)
return roc_auc(pseudo_val_labels, pseudo_preds).item()

def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
loss = sum(
34 changes: 20 additions & 14 deletions elk/training/reporter.py
@@ -9,11 +9,10 @@
import torch.nn as nn
from einops import rearrange, repeat
from simple_parsing.helpers import Serializable
from sklearn.metrics import roc_auc_score
from torch import Tensor

from ..calibration import CalibrationError
from ..metrics import accuracy, to_one_hot
from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot


class EvalResult(NamedTuple):
@@ -24,9 +23,12 @@ class EvalResult(NamedTuple):
accuracy, calibrated accuracy, and AUROC.
"""

auroc: float
auroc_lower: float
auroc_upper: float

acc: float
cal_acc: float
auroc: float
ece: float


@@ -101,6 +103,13 @@ def score_contrast_set(self, labels: Tensor, contrast_set: Tensor) -> EvalResult
where x is the proportion of examples with ground truth label `True`.
AUROC: averaged over the n * v * c binary questions
ECE: Expected Calibration Error
accuracy, and AUROC of the probe on `contrast_set`.
Accuracy: top-1 accuracy averaged over questions and variants.
Calibrated accuracy: top-1 accuracy averaged over questions and
variants, calibrated so that x% of the predictions are `True`,
where x is the proportion of examples with ground truth label `True`.
AUROC: averaged over the n * v * c binary questions
ECE: Expected Calibration Error
"""
logits = self(contrast_set)
(_, v, c) = logits.shape
@@ -123,18 +132,16 @@ def score_contrast_set(self, labels: Tensor, contrast_set: Tensor) -> EvalResult
cal_err = 0.0

Y_one_hot = to_one_hot(Y, c).long().flatten()
if len(labels.unique()) == 1:
auroc = -1.0
else:
auroc = roc_auc_score(Y_one_hot.cpu(), logits.cpu().flatten())
auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(Y_one_hot.cpu(), logits.cpu().flatten())

raw_preds = logits.argmax(dim=-1).long()
raw_acc = accuracy(Y, raw_preds.flatten())

return EvalResult(
auroc=auroc_result.estimate,
auroc_lower=auroc_result.lower,
auroc_upper=auroc_result.upper,
acc=float(raw_acc),
cal_acc=cal_acc,
auroc=float(auroc),
ece=cal_err,
)

@@ -165,14 +172,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
preds = probs.gt(0.5).to(torch.int)
acc = preds.flatten().eq(labels).float().mean().item()

if len(labels.unique()) == 1:
auroc = -1.0
else:
auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten())
auroc_result = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels.cpu(), logits.cpu().flatten())

return EvalResult(
auroc=auroc_result.estimate,
auroc_lower=auroc_result.lower,
auroc_upper=auroc_result.upper,
acc=float(acc),
cal_acc=cal_acc,
auroc=float(auroc),
ece=cal_err,
)
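A hedged sketch of the widened EvalResult returned by score and score_contrast_set; it is not part of the commit, it assumes the class is importable as elk.training.reporter.EvalResult, and the numbers are invented:

from elk.training.reporter import EvalResult

# Invented values, shown only to illustrate the new field layout
res = EvalResult(
    auroc=0.91,
    auroc_lower=0.88,  # lower bound of the 95% bootstrap CI
    auroc_upper=0.94,  # upper bound of the 95% bootstrap CI
    acc=0.85,
    cal_acc=0.87,
    ece=0.04,
)
print(f"AUROC {res.auroc:.3f} with CI [{res.auroc_lower:.3f}, {res.auroc_upper:.3f}]")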
16 changes: 6 additions & 10 deletions elk/training/supervised.py
@@ -1,22 +1,21 @@
import torch
from einops import rearrange, repeat
from sklearn.metrics import roc_auc_score
from torch import Tensor

from ..metrics import accuracy, to_one_hot
from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot
from ..utils import assert_type
from .classifier import Classifier


def evaluate_supervised(
lr_model: Classifier, val_h: Tensor, val_labels: Tensor
) -> tuple[float, float]:
) -> tuple[RocAucResult, float]:
if len(val_h.shape) == 4:
# hiddens in a contrast set
(n, v, c, d) = val_h.shape
(_, v, c, _) = val_h.shape

with torch.no_grad():
logits = rearrange(lr_model(val_h).cpu().squeeze(), "n v c -> (n v) c")
logits = rearrange(lr_model(val_h).squeeze(), "n v c -> (n v) c")
raw_preds = to_one_hot(logits.argmax(dim=-1), c).long()

labels = repeat(val_labels, "n -> (n v)", v=v)
@@ -35,12 +34,9 @@ def evaluate_supervised(
raise ValueError(f"Invalid val_h shape: {val_h.shape}")

lr_acc = accuracy(labels, raw_preds.flatten())
if len(labels.unique()) == 1:
lr_auroc = -1.0
else:
lr_auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten())
lr_auroc = RocAucResult(-1., -1., -1.) if len(labels.unique()) == 1 else roc_auc_ci(labels, logits.flatten())

return assert_type(float, lr_auroc), assert_type(float, lr_acc)
return lr_auroc, assert_type(float, lr_acc)


def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier: