From 5d4e7a414c4a2f629f6c298a09de4d3bef5b3496 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Wed, 23 Aug 2023 13:01:22 +0200
Subject: [PATCH 1/6] add platt scaling for burns + fix for leace

---
 elk/training/ccs_reporter.py | 13 +++++++------
 elk/training/train.py        |  9 +++------
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index 472417f5..cf9325f6 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -88,6 +88,8 @@ def __init__(
         num_variants: int = 1,
     ):
         super().__init__()
+        self._is_training = True
+
         self.config = cfg
         self.in_features = in_features
         self.num_variants = num_variants
@@ -164,15 +166,12 @@ def reset_parameters(self):
     def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
         assert self.norm is not None, "Must call fit() before forward()"
-
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
-        if self.config.norm == "leace":
-            return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
-
-        elif self.config.norm == "burns":
+        if self._is_training:
             return raw_scores
         else:
-            raise ValueError(f"Unknown normalization {self.config.norm}.")
+            platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
+            return platt_scaled_scores
 
     def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
         """Return the loss of the reporter on the contrast pair (x0, x1).
@@ -248,6 +247,8 @@ def fit(self, hiddens: Tensor) -> float:
             raise RuntimeError("Got NaN/infinite loss during training")
 
         self.load_state_dict(best_state)
+
+        self._is_training = False
         return best_loss
 
     def train_loop_adam(self, x_neg: Tensor, x_pos: Tensor) -> float:
diff --git a/elk/training/train.py b/elk/training/train.py
index a7f0ef07..cc2d0fbc 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -83,12 +83,9 @@ def apply_to_layer(
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
 
-            if not self.net.norm == "burns":
-                (_, v, k, _) = first_train_h.shape
-                reporter.platt_scale(
-                    to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
-                    rearrange(first_train_h, "n v k d -> (n v k) d"),
-                )
+            labels = to_one_hot(train_gt, k)
+            labels = repeat(train_gt, "n -> n v k", v=v, k=k)
+            reporter.platt_scale(labels, first_train_h)
 
         elif isinstance(self.net, EigenFitterConfig):
             fitter = EigenFitter(

From 707ce8e87f2f5f57c714b274d98cb1cc72eed2bb Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Wed, 23 Aug 2023 17:23:52 +0200
Subject: [PATCH 2/6] add comment

---
 elk/training/ccs_reporter.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index cf9325f6..1c65d98d 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -170,6 +170,7 @@ def forward(self, x: Tensor) -> Tensor:
         if self._is_training:
             return raw_scores
         else:
+            # only do platt scaling after training the reporters
             platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
             return platt_scaled_scores
 
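Note on patches 1-2: forward() now returns the raw probe scores while the
reporter is being fitted (self._is_training) and Platt-scaled scores
(raw * scale + bias) afterwards. As background, Platt scaling fits a single
affine map on top of frozen scores so that sigmoid(scale * raw + bias)
matches the labels. Below is a minimal stand-alone sketch of that idea for
binary labels; fit_platt is a hypothetical helper, not elk's platt_scale API.

    import torch

    def fit_platt(raw_scores: torch.Tensor, labels: torch.Tensor):
        """Fit a scalar (scale, bias) by minimizing BCE with L-BFGS."""
        scale = torch.ones(1, requires_grad=True)
        bias = torch.zeros(1, requires_grad=True)
        opt = torch.optim.LBFGS([scale, bias], line_search_fn="strong_wolfe")

        def closure():
            opt.zero_grad()
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                raw_scores * scale + bias, labels.float()
            )
            loss.backward()
            return loss

        opt.step(closure)
        return scale.detach(), bias.detach()

    # Toy usage: calibrate four well-separated raw scores.
    raw = torch.tensor([-2.0, -1.0, 1.0, 2.0])
    y = torch.tensor([0.0, 0.0, 1.0, 1.0])
    scale, bias = fit_platt(raw, y)
    probs = torch.sigmoid(raw * scale + bias)  # calibrated credences in (0, 1)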
From 777912ea1983966ca0b3c585bdad4a2291b5418a Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Wed, 23 Aug 2023 22:23:00 +0200
Subject: [PATCH 3/6] repeat one-hot labels in the template dimension

---
 elk/training/train.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/elk/training/train.py b/elk/training/train.py
index cc2d0fbc..84e81ccc 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -82,10 +82,8 @@ def apply_to_layer(
 
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-
             labels = to_one_hot(train_gt, k)
-            labels = repeat(train_gt, "n -> n v k", v=v, k=k)
-            reporter.platt_scale(labels, first_train_h)
+            labels = repeat(labels, "n k -> n v k", v=v)
 
         elif isinstance(self.net, EigenFitterConfig):
             fitter = EigenFitter(

From 89f8987cd2988b4b27483aa818c98553504e21c6 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Wed, 23 Aug 2023 23:13:33 +0200
Subject: [PATCH 4/6] re-add platt scaling

---
 elk/training/train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/elk/training/train.py b/elk/training/train.py
index 84e81ccc..c654ca3a 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -84,6 +84,7 @@ def apply_to_layer(
             train_loss = reporter.fit(first_train_h)
             labels = to_one_hot(train_gt, k)
             labels = repeat(labels, "n k -> n v k", v=v)
+            reporter.platt_scale(labels, first_train_h)
 
         elif isinstance(self.net, EigenFitterConfig):
             fitter = EigenFitter(
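Note on patches 3-4: these settle the labels handed to platt_scale(). The
ground truth train_gt of shape (n,) is one-hot encoded to (n, k) and then
repeated across the v prompt templates to (n, v, k), lining up with hidden
states of shape (n, v, k, d). A quick shape check; to_one_hot below is a
stand-in that only assumes elk's helper maps (n,) integer labels to (n, k):

    import torch
    from einops import repeat

    def to_one_hot(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
        """Stand-in: (n,) integer labels -> (n, k) one-hot floats."""
        return torch.nn.functional.one_hot(labels.long(), n_classes).float()

    train_gt = torch.tensor([0, 1, 1])  # n = 3 questions
    v, k = 2, 2                         # v templates per question, k classes

    labels = repeat(to_one_hot(train_gt, k), "n k -> n v k", v=v)
    print(labels.shape)  # torch.Size([3, 2, 2]); hiddens are (n, v, k, d)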
From f678c90220de3d47ee5dc8dac97c65ac4f899cb9 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Thu, 24 Aug 2023 00:13:18 +0200
Subject: [PATCH 5/6] remove platt scaling parameters from parameters()

---
 elk/training/ccs_reporter.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index 1c65d98d..7a55a858 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from concept_erasure import LeaceFitter
 from torch import Tensor
+from typing_extensions import override
 
 from ..parsing import parse_loss
 from ..utils.typing import assert_type
@@ -88,7 +89,6 @@ def __init__(
         num_variants: int = 1,
     ):
         super().__init__()
-        self._is_training = True
 
         self.config = cfg
         self.in_features = in_features
@@ -130,6 +130,15 @@ def __init__(
             )
         )
 
+    @override
+    def parameters(self, recurse=True):
+        parameters = super(CcsReporter, self).parameters(recurse=recurse)
+        for param in parameters:
+            # exclude the platt scaling parameters
+            # kind of a hack for now; we should probably find a cleaner way
+            if param is not self.scale and param is not self.bias:
+                yield param
+
     def reset_parameters(self):
         """Reset the parameters of the probe.
 
@@ -167,12 +176,8 @@ def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
         assert self.norm is not None, "Must call fit() before forward()"
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
-        if self._is_training:
-            return raw_scores
-        else:
-            # only do platt scaling after training the reporters
-            platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
-            return platt_scaled_scores
+        platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
+        return platt_scaled_scores
 
     def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
         """Return the loss of the reporter on the contrast pair (x0, x1).
@@ -249,7 +254,6 @@ def fit(self, hiddens: Tensor) -> float:
 
         self.load_state_dict(best_state)
 
-        self._is_training = False
         return best_loss
 
     def train_loop_adam(self, x_neg: Tensor, x_pos: Tensor) -> float:

From 6ed5eca7eb35967e5117f83179d21ba1e7ba145d Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Mon, 28 Aug 2023 00:02:28 +0200
Subject: [PATCH 6/6] remove extra line for labels

---
 elk/training/train.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/elk/training/train.py b/elk/training/train.py
index c654ca3a..fb882240 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -82,8 +82,7 @@ def apply_to_layer(
 
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-            labels = to_one_hot(train_gt, k)
-            labels = repeat(labels, "n k -> n v k", v=v)
+            labels = repeat(to_one_hot(train_gt, k), "n k -> n v k", v=v)
             reporter.platt_scale(labels, first_train_h)
 
         elif isinstance(self.net, EigenFitterConfig):
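Note on patch 5: overriding parameters() is what keeps optimizers built on
reporter.parameters() (e.g. in train_loop_adam) from updating scale and bias
during fit(); Platt scaling then optimizes those two directly afterwards.
A minimal self-contained sketch of the pattern (TinyReporter is hypothetical):

    import torch
    import torch.nn as nn

    class TinyReporter(nn.Module):
        def __init__(self, in_features: int):
            super().__init__()
            self.scale = nn.Parameter(torch.ones(1))
            self.bias = nn.Parameter(torch.zeros(1))
            self.probe = nn.Linear(in_features, 1)

        def parameters(self, recurse: bool = True):
            # Hide the Platt parameters from optimizers built on .parameters().
            for param in super().parameters(recurse=recurse):
                if param is not self.scale and param is not self.bias:
                    yield param

    reporter = TinyReporter(8)
    opt = torch.optim.Adam(reporter.parameters())
    params = [p for g in opt.param_groups for p in g["params"]]
    # Only the probe's weight and bias are optimized; scale and bias are not.
    assert all(p is not reporter.scale and p is not reporter.bias for p in params)

Because state_dict() does not go through parameters(), the best_state
save/restore in fit() still round-trips scale and bias unchanged.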