Merge pull request #288 from EleutherAI/fix-platt-scaling-ccs
add platt scaling for burns + fix for leace
lauritowal authored Aug 28, 2023
2 parents 3bbe26c + 6ed5eca commit 4a6b654
Showing 2 changed files with 16 additions and 15 deletions.
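For context, Platt scaling calibrates a classifier by fitting a scalar scale and bias on top of its raw scores, minimizing binary cross-entropy on labeled data while the underlying probe stays frozen. A minimal sketch of the technique (illustrative only; names and optimizer choice are assumptions, not the repository's platt_scale implementation):

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

def fit_platt(raw_scores: torch.Tensor, labels: torch.Tensor):
    """Fit (scale, bias) so that sigmoid(scale * score + bias) is calibrated."""
    scale = torch.ones(1, requires_grad=True)
    bias = torch.zeros(1, requires_grad=True)
    opt = torch.optim.LBFGS([scale, bias], max_iter=100)

    def closure():
        opt.zero_grad()
        loss = binary_cross_entropy_with_logits(scale * raw_scores + bias, labels.float())
        loss.backward()
        return loss

    opt.step(closure)
    return scale.detach(), bias.detach()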
22 changes: 14 additions & 8 deletions elk/training/ccs_reporter.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from concept_erasure import LeaceFitter
 from torch import Tensor
+from typing_extensions import override

 from ..parsing import parse_loss
 from ..utils.typing import assert_type
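The new typing_extensions.override decorator marks a method as overriding a base-class method, so static type checkers can flag the override if the parent method is renamed or its signature changes. A minimal illustration:

from typing_extensions import override

class Base:
    def parameters(self) -> list:
        return []

class Child(Base):
    @override  # a type checker errors here if Base.parameters disappears
    def parameters(self) -> list:
        return super().parameters()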
@@ -88,6 +89,7 @@ def __init__(
         num_variants: int = 1,
     ):
         super().__init__()
+
         self.config = cfg
         self.in_features = in_features
         self.num_variants = num_variants
@@ -128,6 +130,15 @@ def __init__(
             )
         )

+    @override
+    def parameters(self, recurse=True):
+        parameters = super(CcsReporter, self).parameters(recurse=recurse)
+        for param in parameters:
+            # Exclude the Platt scaling parameters.
+            # Kind of a hack for now; we should probably find a cleaner way.
+            if param is not self.scale and param is not self.bias:
+                yield param
+
     def reset_parameters(self):
         """Reset the parameters of the probe.
@@ -164,15 +175,9 @@ def reset_parameters(self):
     def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
         assert self.norm is not None, "Must call fit() before forward()"
-
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
-        if self.config.norm == "leace":
-            return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
-
-        elif self.config.norm == "burns":
-            return raw_scores
-        else:
-            raise ValueError(f"Unknown normalization {self.config.norm}.")
+        platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
+        return platt_scaled_scores

     def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
         """Return the loss of the reporter on the contrast pair (x0, x1).
@@ -248,6 +253,7 @@ def fit(self, hiddens: Tensor) -> float:
             raise RuntimeError("Got NaN/infinite loss during training")

         self.load_state_dict(best_state)
+
         return best_loss

     def train_loop_adam(self, x_neg: Tensor, x_pos: Tensor) -> float:
9 changes: 2 additions & 7 deletions elk/training/train.py
@@ -82,13 +82,8 @@ def apply_to_layer(

             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-
-            if not self.net.norm == "burns":
-                (_, v, k, _) = first_train_h.shape
-                reporter.platt_scale(
-                    to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
-                    rearrange(first_train_h, "n v k d -> (n v k) d"),
-                )
+            labels = repeat(to_one_hot(train_gt, k), "n k -> n v k", v=v)
+            reporter.platt_scale(labels, first_train_h)

         elif isinstance(self.net, EigenFitterConfig):
             fitter = EigenFitter(
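Shape check for the new call: with first_train_h of shape (n, v, k, d), the labels are now one-hot per example and broadcast across the v prompt variants, rather than flattened into a single axis. A hypothetical standalone illustration (using torch's one_hot in place of the repo's to_one_hot helper):

import torch
from einops import repeat

n, v, k, d = 4, 3, 2, 16
train_gt = torch.randint(0, k, (n,))      # gold label per example
first_train_h = torch.randn(n, v, k, d)   # (example, variant, class, hidden)

# One-hot per example, then broadcast across the v prompt variants.
labels = repeat(torch.nn.functional.one_hot(train_gt, k), "n k -> n v k", v=v)
assert labels.shape == first_train_h.shape[:3]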
