From d7008a4e7cc665fcc17beff3b2bcf67faf89327c Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Fri, 9 Jun 2023 15:45:43 +0200
Subject: [PATCH 01/41] add burns norm

---
 elk/training/burns_norm.py   | 17 +++++++++++++++
 elk/training/ccs_reporter.py | 41 +++++++++++++++++++++---------------
 2 files changed, 41 insertions(+), 17 deletions(-)
 create mode 100644 elk/training/burns_norm.py

diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py
new file mode 100644
index 00000000..f25c199a
--- /dev/null
+++ b/elk/training/burns_norm.py
@@ -0,0 +1,17 @@
+import torch
+from torch import Tensor, nn
+
+
+class BurnsNorm(nn.Module):
+    """ Burns et al. style normalization """
+
+    def forward(self, x: Tensor) -> Tensor:
+        breakpoint()
+        assert x.dim() == 3, "the input should have a dimension of 3."  # TODO: add info about needed shape
+
+        print("Per Prompt Normalization...")
+        x: Tensor = x - torch.mean(x, dim=0)
+        norm = torch.linalg.norm(x, dim=2)
+        avg_norm = torch.mean(norm)
+        return x / avg_norm * torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32))
+
diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index f5e5da34..a0aaf5af 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -10,6 +10,7 @@
 import torch.nn as nn
 from torch import Tensor
 
+from .burns_norm import BurnsNorm
 from ..metrics import roc_auc
 from ..parsing import parse_loss
 from ..utils.typing import assert_type
@@ -57,6 +58,7 @@ class CcsReporterConfig(ReporterConfig):
     init: Literal["default", "pca", "spherical", "zero"] = "default"
     loss: list[str] = field(default_factory=lambda: ["ccs"])
     loss_dict: dict[str, float] = field(default_factory=dict, init=False)
+    normalization: Literal["leace", "burns"] = "leace"
     num_layers: int = 1
     pre_ln: bool = False
     supervised_weight: float = 0.0
@@ -67,6 +69,7 @@
     optimizer: Literal["adam", "lbfgs"] = "lbfgs"
     weight_decay: float = 0.01
+
     @classmethod
     def reporter_class(cls) -> type[Reporter]:
         return CcsReporter
@@ -107,13 +110,15 @@ def __init__(
         self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype))
 
         hidden_size = cfg.hidden_size or 4 * in_features // 3
-
-        self.norm = ConceptEraser(
-            in_features,
-            2 * num_variants,
-            device=device,
-            dtype=dtype,
-        )
+        if self.config.normalization == "burns":
+            self.norm = BurnsNorm()
+        else:
+            self.norm = ConceptEraser(
+                in_features,
+                2 * num_variants,
+                device=device,
+                dtype=dtype,
+            )
         self.probe = nn.Sequential(
             nn.Linear(
                 in_features,
@@ -262,16 +267,18 @@ def fit(self, hiddens: Tensor) -> float:
         n, v, _ = x_neg.shape
         prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1)
 
-        self.norm.update(
-            x=x_neg,
-            # Independent indicator for each (template, pseudo-label) pair
-            y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
-        )
-        self.norm.update(
-            x=x_pos,
-            # Independent indicator for each (template, pseudo-label) pair
-            y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
-        )
+        if self.config.normalization == "leace":
+            self.norm.update(
+                x=x_neg,
+                # Independent indicator for each (template, pseudo-label) pair
+                y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
+            )
+            self.norm.update(
+                x=x_pos,
+                # Independent indicator for each (template, pseudo-label) pair
+                y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
+            )
+
        x_neg, x_pos = self.norm(x_neg), self.norm(x_pos)

        # Record the best acc, loss, and params found so far
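[Note] Patch 01 lets the CCS reporter choose between LEACE-style concept erasure and the simpler normalization from Burns et al. (2022), "Discovering Latent Knowledge in Language Models Without Supervision": hidden states are mean-centered across examples, then rescaled so the average vector norm is sqrt(d). A minimal standalone sketch of that transform (hypothetical example, not part of the patch; assumes a tensor laid out as (n examples, v templates, d features)):

    import torch

    def burns_normalize(x: torch.Tensor) -> torch.Tensor:
        # Center each (template, feature) column over the example dimension.
        x = x - x.mean(dim=0)
        # Rescale so the mean L2 norm of the feature vectors is sqrt(d).
        avg_norm = torch.linalg.norm(x, dim=2).mean()
        return x / avg_norm * x.shape[2] ** 0.5

    x = torch.randn(128, 13, 768)    # 128 examples, 13 templates, 768 features
    print(burns_normalize(x).shape)  # torch.Size([128, 13, 768])

From b22d15562601f0c10600216d5717ff71cba75328 Mon Sep 17 00:00:00 2001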
From: Walter Laurito
Date: Thu, 15 Jun 2023 13:12:55 +0200
Subject: [PATCH 02/41] cleanup class and add annotation

---
 elk/training/burns_norm.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py
index f25c199a..d0266c40 100644
--- a/elk/training/burns_norm.py
+++ b/elk/training/burns_norm.py
@@ -3,15 +3,12 @@
 
 
 class BurnsNorm(nn.Module):
-    """ Burns et al. style normalization """
+    """ Burns et al. style normalization Minimal changes from the original code. """
 
-    def forward(self, x: Tensor) -> Tensor:
-        breakpoint()
-        assert x.dim() == 3, "the input should have a dimension of 3."  # TODO: add info about needed shape
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.dim() == 3, f"the input should have a dimension of 3 not dimension {x.dim()}, shape of x: {x.shape}"
 
-        print("Per Prompt Normalization...")
         x: Tensor = x - torch.mean(x, dim=0)
         norm = torch.linalg.norm(x, dim=2)
         avg_norm = torch.mean(norm)
-        return x / avg_norm * torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32))
-
+        return x / avg_norm * torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32))
\ No newline at end of file

From 29bd0af660c32d752705255d8665bf68a2502555 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Thu, 15 Jun 2023 13:13:18 +0200
Subject: [PATCH 03/41] shorten arg name to norm

---
 elk/training/ccs_reporter.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index a0aaf5af..b04c55cf 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -58,7 +58,7 @@ class CcsReporterConfig(ReporterConfig):
     init: Literal["default", "pca", "spherical", "zero"] = "default"
     loss: list[str] = field(default_factory=lambda: ["ccs"])
     loss_dict: dict[str, float] = field(default_factory=dict, init=False)
-    normalization: Literal["leace", "burns"] = "leace"
+    norm: Literal["leace", "burns"] = "leace"  # TODO: move to parent class ?
num_layers: int = 1 pre_ln: bool = False supervised_weight: float = 0.0 @@ -110,8 +110,9 @@ def __init__( self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype)) hidden_size = cfg.hidden_size or 4 * in_features // 3 - if self.config.normalization == "burns": + if self.config.norm == "burns": self.norm = BurnsNorm() + print("buuurn") else: self.norm = ConceptEraser( in_features, @@ -236,7 +237,7 @@ def reset_parameters(self): def forward(self, x: Tensor) -> Tensor: """Return the credence assigned to the hidden state `x`.""" - raw_scores = self.probe(self.norm(x)).squeeze(-1) + raw_scores = self.probe(x).squeeze(-1) return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: @@ -267,7 +268,7 @@ def fit(self, hiddens: Tensor) -> float: n, v, _ = x_neg.shape prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1) - if self.config.normalization == "leace": + if self.config.norm == "leace": self.norm.update( x=x_neg, # Independent indicator for each (template, pseudo-label) pair @@ -278,7 +279,7 @@ def fit(self, hiddens: Tensor) -> float: # Independent indicator for each (template, pseudo-label) pair y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), ) - + x_neg, x_pos = self.norm(x_neg), self.norm(x_pos) # Record the best acc, loss, and params found so far From 87ed4b85260d710a0a8e40b6bd0497dde722f015 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Thu, 15 Jun 2023 18:05:19 +0200 Subject: [PATCH 04/41] cleanup comment --- elk/training/burns_norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index d0266c40..34de0c83 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -6,7 +6,7 @@ class BurnsNorm(nn.Module): """ Burns et al. style normalization Minimal changes from the original code. 
""" def forward(self, x: Tensor) -> Tensor: - assert x.dim() == 3, f"the input should have a dimension of 3 not dimension {x.dim()}, shape of x: {x.shape}" + assert x.dim() == 3, f"the input should have a dimension of 3 not dimension {x.dim()}, current shape of input x: {x.shape}" x: Tensor = x - torch.mean(x, dim=0) norm = torch.linalg.norm(x, dim=2) From bd7dfcebb22de2581dbab82d359a39a9298c2f67 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Thu, 15 Jun 2023 18:07:53 +0200 Subject: [PATCH 05/41] remove print --- elk/training/ccs_reporter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index b04c55cf..e83d2f37 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -112,7 +112,6 @@ def __init__( hidden_size = cfg.hidden_size or 4 * in_features // 3 if self.config.norm == "burns": self.norm = BurnsNorm() - print("buuurn") else: self.norm = ConceptEraser( in_features, From 28d6044f897a10751fbce8c8ac0e47bdb5defce5 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Thu, 15 Jun 2023 18:10:15 +0200 Subject: [PATCH 06/41] remove comment --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index e83d2f37..a410425f 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -58,7 +58,7 @@ class CcsReporterConfig(ReporterConfig): init: Literal["default", "pca", "spherical", "zero"] = "default" loss: list[str] = field(default_factory=lambda: ["ccs"]) loss_dict: dict[str, float] = field(default_factory=dict, init=False) - norm: Literal["leace", "burns"] = "leace" # TODO: move to parent class ? + norm: Literal["leace", "burns"] = "leace" num_layers: int = 1 pre_ln: bool = False supervised_weight: float = 0.0 From cd8fdc00cafb151f444c5ec7df3e58604076c97d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jun 2023 16:34:58 +0000 Subject: [PATCH 07/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/training/burns_norm.py | 10 ++++++---- elk/training/ccs_reporter.py | 5 ++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index 34de0c83..39bc8f64 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -3,12 +3,14 @@ class BurnsNorm(nn.Module): - """ Burns et al. style normalization Minimal changes from the original code. """ + """Burns et al. 
style normalization Minimal changes from the original code.""" - def forward(self, x: Tensor) -> Tensor: - assert x.dim() == 3, f"the input should have a dimension of 3 not dimension {x.dim()}, current shape of input x: {x.shape}" + def forward(self, x: Tensor) -> Tensor: + assert ( + x.dim() == 3 + ), f"the input should have a dimension of 3 not dimension {x.dim()}, current shape of input x: {x.shape}" x: Tensor = x - torch.mean(x, dim=0) norm = torch.linalg.norm(x, dim=2) avg_norm = torch.mean(norm) - return x / avg_norm * torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32)) \ No newline at end of file + return x / avg_norm * torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32)) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index a410425f..ff66b3c9 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -10,10 +10,10 @@ import torch.nn as nn from torch import Tensor -from .burns_norm import BurnsNorm from ..metrics import roc_auc from ..parsing import parse_loss from ..utils.typing import assert_type +from .burns_norm import BurnsNorm from .classifier import Classifier from .concept_eraser import ConceptEraser from .losses import LOSSES @@ -69,7 +69,6 @@ class CcsReporterConfig(ReporterConfig): optimizer: Literal["adam", "lbfgs"] = "lbfgs" weight_decay: float = 0.01 - @classmethod def reporter_class(cls) -> type[Reporter]: return CcsReporter @@ -278,7 +277,7 @@ def fit(self, hiddens: Tensor) -> float: # Independent indicator for each (template, pseudo-label) pair y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), ) - + x_neg, x_pos = self.norm(x_neg), self.norm(x_pos) # Record the best acc, loss, and params found so far From cc2dbd547ad227f83bf8a1e001a941dbaa046f01 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Thu, 15 Jun 2023 18:53:12 +0200 Subject: [PATCH 08/41] pre-commit cleanup --- .pre-commit-config.yaml | 2 +- elk/training/burns_norm.py | 11 +++++++---- elk/training/ccs_reporter.py | 7 +++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04929ca3..ded2fbfe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,4 +24,4 @@ repos: hooks: - id: codespell # The promptsource templates spuriously get flagged without this - args: ["-L fpr", "--skip=*.yaml"] + args: ["-L fpr,leace", "--skip=*.yaml"] diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index 34de0c83..2e76b6e6 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -3,12 +3,15 @@ class BurnsNorm(nn.Module): - """ Burns et al. style normalization Minimal changes from the original code. """ + """Burns et al. 
style normalization Minimal changes from the original code.""" - def forward(self, x: Tensor) -> Tensor: - assert x.dim() == 3, f"the input should have a dimension of 3 not dimension {x.dim()}, current shape of input x: {x.shape}" + def forward(self, x: Tensor) -> Tensor: + assert ( + x.dim() == 3 + ), f"the input should have a dimension of 3 not dimension {x.dim()}, \ + current shape of input x: {x.shape}" x: Tensor = x - torch.mean(x, dim=0) norm = torch.linalg.norm(x, dim=2) avg_norm = torch.mean(norm) - return x / avg_norm * torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32)) \ No newline at end of file + return x / avg_norm * torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32)) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index a410425f..f434894c 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -10,10 +10,10 @@ import torch.nn as nn from torch import Tensor -from .burns_norm import BurnsNorm from ..metrics import roc_auc from ..parsing import parse_loss from ..utils.typing import assert_type +from .burns_norm import BurnsNorm from .classifier import Classifier from .concept_eraser import ConceptEraser from .losses import LOSSES @@ -58,7 +58,7 @@ class CcsReporterConfig(ReporterConfig): init: Literal["default", "pca", "spherical", "zero"] = "default" loss: list[str] = field(default_factory=lambda: ["ccs"]) loss_dict: dict[str, float] = field(default_factory=dict, init=False) - norm: Literal["leace", "burns"] = "leace" + norm: Literal["leace", "burns"] = "leace" # codespell: ignore num_layers: int = 1 pre_ln: bool = False supervised_weight: float = 0.0 @@ -69,7 +69,6 @@ class CcsReporterConfig(ReporterConfig): optimizer: Literal["adam", "lbfgs"] = "lbfgs" weight_decay: float = 0.01 - @classmethod def reporter_class(cls) -> type[Reporter]: return CcsReporter @@ -278,7 +277,7 @@ def fit(self, hiddens: Tensor) -> float: # Independent indicator for each (template, pseudo-label) pair y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), ) - + x_neg, x_pos = self.norm(x_neg), self.norm(x_pos) # Record the best acc, loss, and params found so far From 8e65e2dd0710e6e12ac6a77184ef08bc8fa75e70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jun 2023 11:11:28 +0000 Subject: [PATCH 09/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/extraction/extraction.py | 4 ++-- elk/training/burns_norm.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 615faebf..fbb93763 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -295,7 +295,7 @@ def extract_hiddens( ) # Throw out layers we don't care about hiddens = [hiddens[i] for i in layer_indices] - + # Current shape of each element: (batch_size, seq_len, hidden_size) if cfg.token_loc == "first": hiddens = [h[..., 0, :] for h in hiddens] @@ -320,7 +320,7 @@ def extract_hiddens( # We skipped a variant because it was too long; move on to the next example if len(text_questions) != num_variants: continue - + out_record: dict[str, Any] = dict( label=example["label"], variant_ids=example["template_names"], diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index 62b7e477..d9b1281e 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -10,7 +10,7 @@ def forward(self, x: Tensor) -> Tensor: x.dim() == 
3 ), f"the input should have a dimension of 3 not dimension {x.dim()}, \ current shape of input x: {x.shape}" - + x_mean: Tensor = x - torch.mean(x, dim=0) if torch.all(x_mean == 0): # input embeddings entries are identical, which leads to x_mean having only zero entries. @@ -18,4 +18,8 @@ def forward(self, x: Tensor) -> Tensor: else: norm = torch.linalg.norm(x_mean, dim=2) avg_norm = torch.mean(norm) - return x_mean / avg_norm * torch.sqrt(torch.tensor(x_mean.shape[2], dtype=torch.float32)) + return ( + x_mean + / avg_norm + * torch.sqrt(torch.tensor(x_mean.shape[2], dtype=torch.float32)) + ) From a3b8c3cb8ce08d3d2065985264b2a0b29cd3250a Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 16 Jun 2023 14:06:50 +0100 Subject: [PATCH 10/41] Update ccs_reporter.py remove ignore --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index f434894c..ff66b3c9 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -58,7 +58,7 @@ class CcsReporterConfig(ReporterConfig): init: Literal["default", "pca", "spherical", "zero"] = "default" loss: list[str] = field(default_factory=lambda: ["ccs"]) loss_dict: dict[str, float] = field(default_factory=dict, init=False) - norm: Literal["leace", "burns"] = "leace" # codespell: ignore + norm: Literal["leace", "burns"] = "leace" num_layers: int = 1 pre_ln: bool = False supervised_weight: float = 0.0 From 73bfa32b0b34ad7eeca8d68393db95749cd0a997 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 16 Jun 2023 16:25:29 +0200 Subject: [PATCH 11/41] add nn.module --- elk/training/ccs_reporter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index ff66b3c9..57f3addd 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -110,9 +110,9 @@ def __init__( hidden_size = cfg.hidden_size or 4 * in_features // 3 if self.config.norm == "burns": - self.norm = BurnsNorm() + self.norm: nn.module = BurnsNorm() else: - self.norm = ConceptEraser( + self.norm: nn.module = ConceptEraser( in_features, 2 * num_variants, device=device, From fa8751bcf0ca7e403107b8239567e9864a956f5d Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Thu, 22 Jun 2023 11:02:04 +0000 Subject: [PATCH 12/41] correct annotation for norm --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 57f3addd..225dda05 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -38,7 +38,7 @@ class CcsReporterConfig(ReporterConfig): Example: --loss 1.0*consistency_squared 0.5*prompt_var corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var. Defaults to the loss "ccs_squared_loss". - normalization: The kind of normalization to apply to the hidden states. + norm: The kind of normalization to apply to the hidden states. num_layers: The number of layers in the MLP. Defaults to 1. pre_ln: Whether to include a LayerNorm module before the first linear layer. Defaults to False. 
From a04f85215f847b005e50346ab77e530cab9731ac Mon Sep 17 00:00:00 2001 From: jon Date: Mon, 26 Jun 2023 14:58:49 +0100 Subject: [PATCH 13/41] remove erroring type annotation --- elk/training/burns_norm.py | 3 ++- elk/training/ccs_reporter.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index d9b1281e..dec6d90b 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -13,7 +13,8 @@ def forward(self, x: Tensor) -> Tensor: x_mean: Tensor = x - torch.mean(x, dim=0) if torch.all(x_mean == 0): - # input embeddings entries are identical, which leads to x_mean having only zero entries. + # input embeddings entries are identical, + # which leads to x_mean having only zero entries. return x else: norm = torch.linalg.norm(x_mean, dim=2) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 225dda05..a40bd7b9 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -109,10 +109,11 @@ def __init__( self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype)) hidden_size = cfg.hidden_size or 4 * in_features // 3 + if self.config.norm == "burns": - self.norm: nn.module = BurnsNorm() + self.norm = BurnsNorm() else: - self.norm: nn.module = ConceptEraser( + self.norm = ConceptEraser( in_features, 2 * num_variants, device=device, From dffbe6808c31f18bb29ed564ae0e7260cce77e70 Mon Sep 17 00:00:00 2001 From: jon Date: Mon, 26 Jun 2023 16:44:25 +0100 Subject: [PATCH 14/41] patch over bug with pyright --- elk/training/ccs_reporter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index a40bd7b9..5d3008d2 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -268,12 +268,13 @@ def fit(self, hiddens: Tensor) -> float: prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1) if self.config.norm == "leace": - self.norm.update( + # type ignore because otherwise throws error, probably bug with pyright + self.norm.update( # type: ignore x=x_neg, # Independent indicator for each (template, pseudo-label) pair y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1), ) - self.norm.update( + self.norm.update( # type: ignore x=x_pos, # Independent indicator for each (template, pseudo-label) pair y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), From 80826c8f5a27d24dff4ddcf19ec0d76bb814bb7a Mon Sep 17 00:00:00 2001 From: jon Date: Fri, 30 Jun 2023 11:47:59 +0100 Subject: [PATCH 15/41] remove ignores --- elk/training/ccs_reporter.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 5d3008d2..6397a24b 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -266,15 +266,13 @@ def fit(self, hiddens: Tensor) -> float: # One-hot indicators for each prompt template n, v, _ = x_neg.shape prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1) - - if self.config.norm == "leace": - # type ignore because otherwise throws error, probably bug with pyright - self.norm.update( # type: ignore + if isinstance(self.norm, ConceptEraser): + self.norm.update( x=x_neg, # Independent indicator for each (template, pseudo-label) pair y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1), ) - self.norm.update( # type: ignore + self.norm.update( x=x_pos, # Independent indicator for each (template, pseudo-label) pair 
y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), From 70c737e750377e9ef98f40f0fc1ae9041b16e21b Mon Sep 17 00:00:00 2001 From: jon Date: Mon, 3 Jul 2023 18:42:12 +0000 Subject: [PATCH 16/41] fix nans better --- elk/training/burns_norm.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index dec6d90b..b98a5f35 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -12,15 +12,11 @@ def forward(self, x: Tensor) -> Tensor: current shape of input x: {x.shape}" x_mean: Tensor = x - torch.mean(x, dim=0) - if torch.all(x_mean == 0): - # input embeddings entries are identical, - # which leads to x_mean having only zero entries. - return x - else: - norm = torch.linalg.norm(x_mean, dim=2) - avg_norm = torch.mean(norm) - return ( - x_mean - / avg_norm - * torch.sqrt(torch.tensor(x_mean.shape[2], dtype=torch.float32)) - ) + norm = torch.linalg.norm(x_mean, dim=2) + avg_norm = torch.mean(norm) + eps = torch.finfo(x.dtype).eps + return ( + x_mean + / (avg_norm + eps) # to avoid division by zero + * torch.sqrt(torch.tensor(x_mean.shape[2], dtype=torch.float32)) + ) From 92f9086e4f0161414fec2adb9c5235fd354e7ccc Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Wed, 19 Jul 2023 17:35:41 +0000 Subject: [PATCH 17/41] remove checking for embeddings --- elk/training/burns_norm.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index d9b1281e..4fea3231 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -12,14 +12,10 @@ def forward(self, x: Tensor) -> Tensor: current shape of input x: {x.shape}" x_mean: Tensor = x - torch.mean(x, dim=0) - if torch.all(x_mean == 0): - # input embeddings entries are identical, which leads to x_mean having only zero entries. 
- return x - else: - norm = torch.linalg.norm(x_mean, dim=2) - avg_norm = torch.mean(norm) - return ( - x_mean - / avg_norm - * torch.sqrt(torch.tensor(x_mean.shape[2], dtype=torch.float32)) - ) + norm = torch.linalg.norm(x_mean, dim=2) + avg_norm = torch.mean(norm) + return ( + x_mean + / avg_norm + * torch.sqrt(torch.tensor(x_mean.shape[2], dtype=torch.float32)) + ) From 9027a925f26d571fd7ce3f3a4bb03baaba85b1a8 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Wed, 19 Jul 2023 18:53:48 +0000 Subject: [PATCH 18/41] fix precommit stuff --- elk/training/ccs_reporter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 09860314..7e80063d 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -13,7 +13,6 @@ from ..parsing import parse_loss from ..utils.typing import assert_type from .burns_norm import BurnsNorm -from .classifier import Classifier from .common import FitterConfig from .losses import LOSSES from .platt_scaling import PlattMixin From 4ae416a7297adc37fee6273257b4a0e008dfac3c Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Wed, 19 Jul 2023 19:55:46 +0000 Subject: [PATCH 19/41] readd self.norm --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 7e80063d..ad56e858 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -164,7 +164,7 @@ def reset_parameters(self): def forward(self, x: Tensor) -> Tensor: """Return the credence assigned to the hidden state `x`.""" assert self.norm is not None, "Must call fit() before forward()" - raw_scores = self.probe(x).squeeze(-1) + raw_scores = self.probe(self.norm(x)).squeeze(-1) return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: From 55c867ec448212a140d1f4f6ffbc52223d546493 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Wed, 19 Jul 2023 19:56:37 +0000 Subject: [PATCH 20/41] readd space --- elk/training/ccs_reporter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index ad56e858..39d4502e 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -164,6 +164,7 @@ def reset_parameters(self): def forward(self, x: Tensor) -> Tensor: """Return the credence assigned to the hidden state `x`.""" assert self.norm is not None, "Must call fit() before forward()" + raw_scores = self.probe(self.norm(x)).squeeze(-1) return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) From b257c97b219e2494f7a9785e873f3f852b2a7f1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jul 2023 19:56:52 +0000 Subject: [PATCH 21/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 39d4502e..01b9670c 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -164,7 +164,7 @@ def reset_parameters(self): def forward(self, x: Tensor) -> Tensor: """Return the credence assigned to the hidden state `x`.""" assert self.norm is not None, "Must call fit() before forward()" - + raw_scores = self.probe(self.norm(x)).squeeze(-1) return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) From 
6a12fdd00cead712540e28c387357aa790eb4112 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Fri, 28 Jul 2023 15:59:07 +0000
Subject: [PATCH 22/41] fix burns norm to really normalize by prompt

---
 elk/training/burns_norm.py | 47 ++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 15 deletions(-)

diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py
index 4fea3231..c337841f 100644
--- a/elk/training/burns_norm.py
+++ b/elk/training/burns_norm.py
@@ -1,21 +1,38 @@
+from dataclasses import dataclass
 import torch
 from torch import Tensor, nn
 
-
 class BurnsNorm(nn.Module):
-    """Burns et al. style normalization Minimal changes from the original code."""
+    """Burns et al. style normalization. Minimal changes from the original code."""
+
+    def __init__(self, scale:bool=True):
+        super().__init__()
+        self.scale: bool = scale
 
     def forward(self, x: Tensor) -> Tensor:
-        assert (
-            x.dim() == 3
-        ), f"the input should have a dimension of 3 not dimension {x.dim()}, \
-            current shape of input x: {x.shape}"
-
-        x_mean: Tensor = x - torch.mean(x, dim=0)
-        norm = torch.linalg.norm(x_mean, dim=2)
-        avg_norm = torch.mean(norm)
-        return (
-            x_mean
-            / avg_norm
-            * torch.sqrt(torch.tensor(x_mean.shape[2], dtype=torch.float32))
-        )
+        num_elements = x.shape[0]
+        x_normalized: Tensor = x - x.mean(dim=0) if num_elements > 1 else x
+
+        if not self.scale:
+            return x_normalized
+        else:
+            std = torch.linalg.norm(x_normalized, dim=0) / torch.sqrt(
+                torch.tensor(x_normalized.shape[0], dtype=torch.float32)
+            )
+            assert std.dim() == x.dim() - 1
+
+            # Compute the dimensions over which we want to compute the mean standard deviation
+            dims = tuple(range(1, std.dim()))  # exclude the first dimension (v)
+
+            avg_norm = std.mean(dim=dims)
+
+            # Add a singleton dimension at the beginning to allow broadcasting.
+            # This compensates for the dimension we lost when computing the norm.
+            avg_norm = avg_norm.unsqueeze(0)
+
+            # Add singleton dimensions at the end to allow broadcasting.
+            # This compensates for the dimensions we lost when computing the mean.
+            for _ in range(1, x.dim() - 1):
+                avg_norm = avg_norm.unsqueeze(-1)
+
+            return x_normalized / avg_norm
\ No newline at end of file
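[Note] Patch 22 reworks the scaling: rather than dividing everything by one global average norm, each prompt template is now centered and rescaled by its own mean standard deviation over the example dimension, and the transform generalizes to both (n, v, d) and (n, v, c, d) inputs. A rough usage sketch (hypothetical shapes; BurnsNorm as defined in the patch above):

    import torch
    from elk.training.burns_norm import BurnsNorm

    x = torch.randn(128, 13, 768)         # (n examples, v templates, d features)
    normed = BurnsNorm(scale=True)(x)     # center per template, then rescale per template
    assert normed.shape == x.shape
    centered = BurnsNorm(scale=False)(x)  # mean-centering only, no rescaling

Patch 23 below verifies this vectorized version against a slow per-template reference loop.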
From 6a20c44b8490fd64bafb925d2a311a15e43faba6 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Fri, 28 Jul 2023 15:59:27 +0000
Subject: [PATCH 23/41] add test for burns norm

---
 tests/test_burns_norm.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)
 create mode 100644 tests/test_burns_norm.py

diff --git a/tests/test_burns_norm.py b/tests/test_burns_norm.py
new file mode 100644
index 00000000..ac885673
--- /dev/null
+++ b/tests/test_burns_norm.py
@@ -0,0 +1,38 @@
+import torch
+from torch import Tensor
+
+from elk.training.burns_norm import BurnsNorm
+
+
+def correct_but_slow_normalization(x_all: Tensor, scale=True) -> Tensor:
+    res = []
+    xs = x_all.unbind(dim=1)
+
+    for x in xs:
+        num_elements = x.shape[0]
+        x_mean: Tensor = x - x.mean(dim=0) if num_elements > 1 else x
+        if scale == True:
+            std = torch.linalg.norm(x_mean, axis=0) / torch.sqrt(
+                torch.tensor(x_mean.shape[0], dtype=torch.float32)
+            )
+            avg_norm = std.mean()
+            x_mean = x_mean / avg_norm
+        res.append(x_mean)
+
+    return torch.stack(res, dim=1)
+
+def test_BurnsNorm_3d_input():
+    x_all_3d = torch.randn((2, 13, 768))
+    expected_output_3d = correct_but_slow_normalization(x_all_3d)
+    bn = BurnsNorm()
+    output_3d = bn(x_all_3d)
+    diff = output_3d - expected_output_3d
+    assert (diff == torch.zeros_like(diff)).all()
+
+def test_BurnsNorm_4d_input():
+    x_all_4d = torch.randn((2, 13, 2, 768))
+    expected_output_4d = correct_but_slow_normalization(x_all_4d)
+    bn = BurnsNorm()
+    output_4d = bn(x_all_4d)
+    diff = output_4d - expected_output_4d
+    assert (diff == torch.zeros_like(diff)).all()

From 8685c15b1383ec84d81e0b9aef830a4e1b162f30 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Fri, 28 Jul 2023 16:04:57 +0000
Subject: [PATCH 24/41] no plattscale for burns norm; sigmoid instead

---
 elk/training/ccs_reporter.py | 31 ++++++++++++++++++++++---------
 elk/training/train.py        | 13 +++++++------
 2 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index 01b9670c..29738dd0 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -100,14 +100,25 @@ def __init__(
 
         self.norm = None
 
-        self.probe = nn.Sequential(
-            nn.Linear(
-                in_features,
-                1 if cfg.num_layers < 2 else hidden_size,
-                bias=cfg.bias,
-                device=device,
-            ),
-        )
+        if self.cfg.norm == "burns":
+            self.probe = nn.Sequential(
+                nn.Linear(
+                    in_features,
+                    1 if cfg.num_layers < 2 else hidden_size,
+                    bias=cfg.bias,
+                    device=device,
+                ),
+                nn.Sigmoid()
+            )
+        else:
+            self.probe = nn.Sequential(
+                nn.Linear(
+                    in_features,
+                    1 if cfg.num_layers < 2 else hidden_size,
+                    bias=cfg.bias,
+                    device=device,
+                )
+            )
         if cfg.pre_ln:
             self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False))
 
@@ -164,8 +175,10 @@ def reset_parameters(self):
     def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
         assert self.norm is not None, "Must call fit() before forward()"
 
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
+        return raw_scores
+        breakpoint()
         return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
 
     def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
diff --git a/elk/training/train.py b/elk/training/train.py
index 8392f2d9..014125df 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -82,12 +82,13 @@ def apply_to_layer(
             reporter = CcsReporter(self.net, d,
device=device, num_variants=v) train_loss = reporter.fit(first_train_h) - - (_, v, k, _) = first_train_h.shape - reporter.platt_scale( - to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(), - rearrange(first_train_h, "n v k d -> (n v k) d"), - ) + + if self.net.norm == "leace": + (_, v, k, _) = first_train_h.shape + reporter.platt_scale( + to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(), + rearrange(first_train_h, "n v k d -> (n v k) d"), + ) elif isinstance(self.net, EigenFitterConfig): fitter = EigenFitter( From 5760f7c0d75b308aa345cd7174c119e1d682fc83 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 18:27:38 +0000 Subject: [PATCH 25/41] fix naming --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 29738dd0..3193001e 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -100,7 +100,7 @@ def __init__( self.norm = None - if self.cfg.norm == "burns": + if self.config.norm == "burns": self.probe = nn.Sequential( nn.Linear( in_features, From 7f5a88a8b88dcb620314ce4d034d4e421df001b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Jul 2023 18:27:53 +0000 Subject: [PATCH 26/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/training/burns_norm.py | 14 +++++++------- elk/training/ccs_reporter.py | 6 +++--- elk/training/train.py | 2 +- tests/test_burns_norm.py | 4 +++- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index c337841f..44322ed1 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -1,16 +1,16 @@ -from dataclasses import dataclass import torch from torch import Tensor, nn + class BurnsNorm(nn.Module): """Burns et al. style normalization. Minimal changes from the original code.""" - - def __init__(self, scale:bool=True): + + def __init__(self, scale: bool = True): super().__init__() self.scale: bool = scale def forward(self, x: Tensor) -> Tensor: - num_elements = x.shape[0] + num_elements = x.shape[0] x_normalized: Tensor = x - x.mean(dim=0) if num_elements > 1 else x if not self.scale: @@ -22,11 +22,11 @@ def forward(self, x: Tensor) -> Tensor: assert std.dim() == x.dim() - 1 # Compute the dimensions over which we want to compute the mean standard deviation - dims = tuple(range(1, std.dim())) # exclude the first dimension (v) + dims = tuple(range(1, std.dim())) # exclude the first dimension (v) avg_norm = std.mean(dim=dims) - # Add a singleton dimension at the beginning to allow broadcasting. + # Add a singleton dimension at the beginning to allow broadcasting. # This compensates for the dimension we lost when computing the norm. 
avg_norm = avg_norm.unsqueeze(0) @@ -35,4 +35,4 @@ def forward(self, x: Tensor) -> Tensor: for _ in range(1, x.dim() - 1): avg_norm = avg_norm.unsqueeze(-1) - return x_normalized / avg_norm \ No newline at end of file + return x_normalized / avg_norm diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 3193001e..76a93236 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -108,7 +108,7 @@ def __init__( bias=cfg.bias, device=device, ), - nn.Sigmoid() + nn.Sigmoid(), ) else: self.probe = nn.Sequential( @@ -118,7 +118,7 @@ def __init__( bias=cfg.bias, device=device, ) - ) + ) if cfg.pre_ln: self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) @@ -175,7 +175,7 @@ def reset_parameters(self): def forward(self, x: Tensor) -> Tensor: """Return the credence assigned to the hidden state `x`.""" assert self.norm is not None, "Must call fit() before forward()" - + raw_scores = self.probe(self.norm(x)).squeeze(-1) return raw_scores breakpoint() diff --git a/elk/training/train.py b/elk/training/train.py index 822aa2c4..b6888ea7 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -82,7 +82,7 @@ def apply_to_layer( reporter = CcsReporter(self.net, d, device=device, num_variants=v) train_loss = reporter.fit(first_train_h) - + if self.net.norm == "leace": (_, v, k, _) = first_train_h.shape reporter.platt_scale( diff --git a/tests/test_burns_norm.py b/tests/test_burns_norm.py index ac885673..4620ae0e 100644 --- a/tests/test_burns_norm.py +++ b/tests/test_burns_norm.py @@ -11,7 +11,7 @@ def correct_but_slow_normalization(x_all: Tensor, scale=True) -> Tensor: for x in xs: num_elements = x.shape[0] x_mean: Tensor = x - x.mean(dim=0) if num_elements > 1 else x - if scale == True: + if scale is True: std = torch.linalg.norm(x_mean, axis=0) / torch.sqrt( torch.tensor(x_mean.shape[0], dtype=torch.float32) ) @@ -21,6 +21,7 @@ def correct_but_slow_normalization(x_all: Tensor, scale=True) -> Tensor: return torch.stack(res, dim=1) + def test_BurnsNorm_3d_input(): x_all_3d = torch.randn((2, 13, 768)) expected_output_3d = correct_but_slow_normalization(x_all_3d) @@ -29,6 +30,7 @@ def test_BurnsNorm_3d_input(): diff = output_3d - expected_output_3d assert (diff == torch.zeros_like(diff)).all() + def test_BurnsNorm_4d_input(): x_all_4d = torch.randn((2, 13, 2, 768)) expected_output_4d = correct_but_slow_normalization(x_all_4d) From b15f55e94c4d0d5da0efeeb098d1dcdbba524315 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 18:49:58 +0000 Subject: [PATCH 27/41] pre-commit cleanup --- elk/training/burns_norm.py | 18 ++++++++++-------- elk/training/ccs_reporter.py | 6 +++--- elk/training/train.py | 2 +- tests/test_burns_norm.py | 4 +++- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index c337841f..e88616ba 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -1,16 +1,16 @@ -from dataclasses import dataclass import torch from torch import Tensor, nn + class BurnsNorm(nn.Module): """Burns et al. style normalization. 
Minimal changes from the original code."""
-
-    def __init__(self, scale:bool=True):
+
+    def __init__(self, scale: bool = True):
         super().__init__()
         self.scale: bool = scale
 
     def forward(self, x: Tensor) -> Tensor:
-        num_elements = x.shape[0]
+        num_elements = x.shape[0]
         x_normalized: Tensor = x - x.mean(dim=0) if num_elements > 1 else x
 
         if not self.scale:
@@ -21,12 +21,14 @@ def forward(self, x: Tensor) -> Tensor:
             )
             assert std.dim() == x.dim() - 1
 
-            # Compute the dimensions over which we want to compute the mean standard deviation
-            dims = tuple(range(1, std.dim()))  # exclude the first dimension (v)
+            # Compute the dimensions over which
+            # we want to compute the mean standard deviation
+            # exclude the first dimension (v)
+            dims = tuple(range(1, std.dim()))
 
             avg_norm = std.mean(dim=dims)
 
-            # Add a singleton dimension at the beginning to allow broadcasting.
+            # Add a singleton dimension at the beginning to allow broadcasting.
             # This compensates for the dimension we lost when computing the norm.
             avg_norm = avg_norm.unsqueeze(0)
@@ -35,4 +35,4 @@ def forward(self, x: Tensor) -> Tensor:
             for _ in range(1, x.dim() - 1):
                 avg_norm = avg_norm.unsqueeze(-1)
 
-            return x_normalized / avg_norm
\ No newline at end of file
+            return x_normalized / avg_norm
diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index 3193001e..76a93236 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -108,7 +108,7 @@ def __init__(
                     bias=cfg.bias,
                     device=device,
                 ),
-                nn.Sigmoid()
+                nn.Sigmoid(),
             )
         else:
             self.probe = nn.Sequential(
@@ -118,7 +118,7 @@ def __init__(
                     bias=cfg.bias,
                     device=device,
                 )
-            )
+            )
         if cfg.pre_ln:
             self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False))
 
@@ -175,7 +175,7 @@ def reset_parameters(self):
     def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
         assert self.norm is not None, "Must call fit() before forward()"
-
+
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
         return raw_scores
         breakpoint()
diff --git a/elk/training/train.py b/elk/training/train.py
index 822aa2c4..b6888ea7 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -82,7 +82,7 @@ def apply_to_layer(
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-
+
             if self.net.norm == "leace":
                 (_, v, k, _) = first_train_h.shape
                 reporter.platt_scale(
diff --git a/tests/test_burns_norm.py b/tests/test_burns_norm.py
index ac885673..4620ae0e 100644
--- a/tests/test_burns_norm.py
+++ b/tests/test_burns_norm.py
@@ -11,7 +11,7 @@ def correct_but_slow_normalization(x_all: Tensor, scale=True) -> Tensor:
     for x in xs:
         num_elements = x.shape[0]
         x_mean: Tensor = x - x.mean(dim=0) if num_elements > 1 else x
-        if scale == True:
+        if scale is True:
             std = torch.linalg.norm(x_mean, axis=0) / torch.sqrt(
                 torch.tensor(x_mean.shape[0], dtype=torch.float32)
             )
@@ -21,6 +21,7 @@ def correct_but_slow_normalization(x_all: Tensor, scale=True) -> Tensor:
 
     return torch.stack(res, dim=1)
 
+
 def test_BurnsNorm_3d_input():
     x_all_3d = torch.randn((2, 13, 768))
     expected_output_3d = correct_but_slow_normalization(x_all_3d)
@@ -29,6 +30,7 @@ def test_BurnsNorm_3d_input():
     diff = output_3d - expected_output_3d
     assert (diff == torch.zeros_like(diff)).all()
 
+
 def test_BurnsNorm_4d_input():
     x_all_4d = torch.randn((2, 13, 2, 768))
     expected_output_4d = correct_but_slow_normalization(x_all_4d)
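[Note] In the "burns" configuration introduced by patches 24-27, the probe ends in nn.Sigmoid and already emits values in (0, 1), so train.py skips Platt scaling, which would otherwise fit a scale and bias on the raw logits. A schematic of the distinction (hypothetical numbers, not the exact elk API):

    import torch

    logits = torch.randn(8)   # raw linear-probe outputs
    scale, bias = 1.7, -0.2   # values Platt scaling would learn from labels
    p_leace = torch.sigmoid(logits * scale + bias)  # "leace" path: calibrated
    p_burns = torch.sigmoid(logits)                 # "burns" path: fixed sigmoid

From 02c2a9ca0de8a238565b306d0d3a35c9807e63f0 Mon Sep 17 00:00:00 2001
From: Walter Laurito
Date: Fri, 28 Jul 2023 18:56:41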
+0000 Subject: [PATCH 28/41] remove code duplication --- elk/training/ccs_reporter.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 76a93236..d603be1f 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -100,25 +100,16 @@ def __init__( self.norm = None - if self.config.norm == "burns": - self.probe = nn.Sequential( - nn.Linear( - in_features, - 1 if cfg.num_layers < 2 else hidden_size, - bias=cfg.bias, - device=device, - ), - nn.Sigmoid(), - ) - else: - self.probe = nn.Sequential( - nn.Linear( - in_features, - 1 if cfg.num_layers < 2 else hidden_size, - bias=cfg.bias, - device=device, - ) - ) + self.probe = nn.Sequential( + nn.Linear( + in_features, + 1 if cfg.num_layers < 2 else hidden_size, + bias=cfg.bias, + device=device, + ), + *(nn.Sigmoid() if self.config.norm == "burns" else ()), + ) + if cfg.pre_ln: self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) From e172ab63ba80347446f03c949fbbfe818fe3c1f1 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 19:04:23 +0000 Subject: [PATCH 29/41] cleanup code duplication --- elk/training/ccs_reporter.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index d603be1f..03ede336 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -100,16 +100,18 @@ def __init__( self.norm = None - self.probe = nn.Sequential( - nn.Linear( - in_features, - 1 if cfg.num_layers < 2 else hidden_size, - bias=cfg.bias, - device=device, - ), - *(nn.Sigmoid() if self.config.norm == "burns" else ()), - ) + layers = [nn.Linear( + in_features, + 1 if cfg.num_layers < 2 else hidden_size, + bias=cfg.bias, + device=device, + )] + + if self.config.norm == "burns": + layers.append(nn.Sigmoid()) + self.probe = nn.Sequential(*layers) + if cfg.pre_ln: self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) From 76b5eb1617cfcdafbcde124b0f1594a18e6817fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Jul 2023 19:04:39 +0000 Subject: [PATCH 30/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/training/ccs_reporter.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 03ede336..6f9257f0 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -100,18 +100,20 @@ def __init__( self.norm = None - layers = [nn.Linear( - in_features, - 1 if cfg.num_layers < 2 else hidden_size, - bias=cfg.bias, - device=device, - )] + layers = [ + nn.Linear( + in_features, + 1 if cfg.num_layers < 2 else hidden_size, + bias=cfg.bias, + device=device, + ) + ] if self.config.norm == "burns": layers.append(nn.Sigmoid()) self.probe = nn.Sequential(*layers) - + if cfg.pre_ln: self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) From 96e05608439172f5946810a44460582bb8600fe3 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 19:15:37 +0000 Subject: [PATCH 31/41] add type hint --- elk/training/ccs_reporter.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 03ede336..966a45e5 100644 --- 
a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -100,18 +100,20 @@ def __init__( self.norm = None - layers = [nn.Linear( - in_features, - 1 if cfg.num_layers < 2 else hidden_size, - bias=cfg.bias, - device=device, - )] + layers: list[nn.Module] = [ + nn.Linear( + in_features, + 1 if cfg.num_layers < 2 else hidden_size, + bias=cfg.bias, + device=device, + ) + ] if self.config.norm == "burns": layers.append(nn.Sigmoid()) self.probe = nn.Sequential(*layers) - + if cfg.pre_ln: self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) From 001deb83dbf8846d5693ec14ec22ac5b031521c6 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 19:19:16 +0000 Subject: [PATCH 32/41] fix forward by checking for leace --- elk/training/ccs_reporter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 6f9257f0..25bd17d3 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -172,9 +172,10 @@ def forward(self, x: Tensor) -> Tensor: assert self.norm is not None, "Must call fit() before forward()" raw_scores = self.probe(self.norm(x)).squeeze(-1) - return raw_scores - breakpoint() - return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) + if self.config.norm == "leace": + return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) + else: + return raw_scores def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: """Return the loss of the reporter on the contrast pair (x0, x1). From e8872c8a28f2b5da878af840560b7533ec0d7b80 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 19:36:22 +0000 Subject: [PATCH 33/41] add annotation --- elk/training/burns_norm.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index e88616ba..331964c9 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -3,13 +3,20 @@ class BurnsNorm(nn.Module): - """Burns et al. style normalization. Minimal changes from the original code.""" + """Burns et al. style normalization. Minimal changes from the original code. + """ def __init__(self, scale: bool = True): super().__init__() self.scale: bool = scale def forward(self, x: Tensor) -> Tensor: + """Normalizes per template + Args: + x: input of dimension (n, v, c, d) or (n, v, d) + Returns: + x_normalized: normalized output + """ num_elements = x.shape[0] x_normalized: Tensor = x - x.mean(dim=0) if num_elements > 1 else x @@ -23,7 +30,8 @@ def forward(self, x: Tensor) -> Tensor: # Compute the dimensions over which # we want to compute the mean standard deviation - # exclude the first dimension (v) + # exclude the first dimension v, + # which is the template dimension dims = tuple(range(1, std.dim())) avg_norm = std.mean(dim=dims) From 0f35b96a6cc5b4c24fcf6e01d36a2f1e7be36f21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Jul 2023 19:40:46 +0000 Subject: [PATCH 34/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/training/burns_norm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index 331964c9..f9925f2f 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -3,8 +3,7 @@ class BurnsNorm(nn.Module): - """Burns et al. style normalization. 
Minimal changes from the original code. - """ + """Burns et al. style normalization. Minimal changes from the original code.""" def __init__(self, scale: bool = True): super().__init__() @@ -12,7 +11,7 @@ def __init__(self, scale: bool = True): def forward(self, x: Tensor) -> Tensor: """Normalizes per template - Args: + Args: x: input of dimension (n, v, c, d) or (n, v, d) Returns: x_normalized: normalized output @@ -30,7 +29,7 @@ def forward(self, x: Tensor) -> Tensor: # Compute the dimensions over which # we want to compute the mean standard deviation - # exclude the first dimension v, + # exclude the first dimension v, # which is the template dimension dims = tuple(range(1, std.dim())) From 94a498e2dc0d2dc3307aad27329e6517c1bfb166 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 19:42:46 +0000 Subject: [PATCH 35/41] add typing --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 25bd17d3..89b19f6b 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -100,7 +100,7 @@ def __init__( self.norm = None - layers = [ + layers: list[nn.Module] = [ nn.Linear( in_features, 1 if cfg.num_layers < 2 else hidden_size, From 99646e9a85b770d98a29af958576627dde638b4e Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 28 Jul 2023 19:52:14 +0000 Subject: [PATCH 36/41] rename to probe_layers --- elk/training/ccs_reporter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 89b19f6b..525c1490 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -100,7 +100,7 @@ def __init__( self.norm = None - layers: list[nn.Module] = [ + probe_layers: list[nn.Module] = [ nn.Linear( in_features, 1 if cfg.num_layers < 2 else hidden_size, @@ -110,9 +110,9 @@ def __init__( ] if self.config.norm == "burns": - layers.append(nn.Sigmoid()) + probe_layers.append(nn.Sigmoid()) - self.probe = nn.Sequential(*layers) + self.probe = nn.Sequential(*probe_layers) if cfg.pre_ln: self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) From aeb8b099e98f9dc52d5c02ec0d092f59c610a28d Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Thu, 3 Aug 2023 17:47:38 +0200 Subject: [PATCH 37/41] remove sigmoid from ccs reporter class + cleanup --- elk/training/ccs_reporter.py | 17 +++++++---------- elk/training/train.py | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 525c1490..f6cc25ff 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -99,20 +99,14 @@ def __init__( hidden_size = cfg.hidden_size or 4 * in_features // 3 self.norm = None - - probe_layers: list[nn.Module] = [ + self.probe: nn.Sequential( nn.Linear( in_features, 1 if cfg.num_layers < 2 else hidden_size, bias=cfg.bias, device=device, - ) - ] - - if self.config.norm == "burns": - probe_layers.append(nn.Sigmoid()) - - self.probe = nn.Sequential(*probe_layers) + ), + ) if cfg.pre_ln: self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) @@ -174,8 +168,11 @@ def forward(self, x: Tensor) -> Tensor: raw_scores = self.probe(self.norm(x)).squeeze(-1) if self.config.norm == "leace": return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) - else: + + elif self.config.norm == "burns": return raw_scores + else: + raise ValueError(f"Unknown normalization 
{self.config.norm}.") def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: """Return the loss of the reporter on the contrast pair (x0, x1). diff --git a/elk/training/train.py b/elk/training/train.py index b6888ea7..a7f0ef07 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -83,7 +83,7 @@ def apply_to_layer( reporter = CcsReporter(self.net, d, device=device, num_variants=v) train_loss = reporter.fit(first_train_h) - if self.net.norm == "leace": + if not self.net.norm == "burns": (_, v, k, _) = first_train_h.shape reporter.platt_scale( to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(), From f2ae1b7bf1cbc3a359231f1084a034f0be8f1f2b Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Thu, 3 Aug 2023 17:53:41 +0200 Subject: [PATCH 38/41] fix assignment of probe --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index f6cc25ff..472417f5 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -99,7 +99,7 @@ def __init__( hidden_size = cfg.hidden_size or 4 * in_features // 3 self.norm = None - self.probe: nn.Sequential( + self.probe = nn.Sequential( nn.Linear( in_features, 1 if cfg.num_layers < 2 else hidden_size, From 3d6042f13fde4c9791dc88d8f86cb7d6d24dfefa Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 11 Aug 2023 11:08:20 +0000 Subject: [PATCH 39/41] remove unnecessary tensor wrapping --- comparison-sweeps | 1 + elk/training/burns_norm.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) create mode 160000 comparison-sweeps diff --git a/comparison-sweeps b/comparison-sweeps new file mode 160000 index 00000000..f4ed884b --- /dev/null +++ b/comparison-sweeps @@ -0,0 +1 @@ +Subproject commit f4ed884b59c99012c80b972d2a02c660b39c90cb diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index f9925f2f..8c87efe0 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -22,9 +22,7 @@ def forward(self, x: Tensor) -> Tensor: if not self.scale: return x_normalized else: - std = torch.linalg.norm(x_normalized, dim=0) / torch.sqrt( - torch.tensor(x_normalized.shape[0], dtype=torch.float32) - ) + std = torch.linalg.norm(x_normalized, dim=0) / x_normalized.shape[0] ** 0.5 assert std.dim() == x.dim() - 1 # Compute the dimensions over which From ee090c810fc8735aa460ca08334fa0b10f8fb82a Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 11 Aug 2023 11:19:27 +0000 Subject: [PATCH 40/41] remove unnessary unsqueezes; using keepdim instead --- elk/training/burns_norm.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index 8c87efe0..5a0adc69 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -31,15 +31,6 @@ def forward(self, x: Tensor) -> Tensor: # which is the template dimension dims = tuple(range(1, std.dim())) - avg_norm = std.mean(dim=dims) - - # Add a singleton dimension at the beginning to allow broadcasting. - # This compensates for the dimension we lost when computing the norm. - avg_norm = avg_norm.unsqueeze(0) - - # Add singleton dimensions at the end to allow broadcasting. - # This compensates for the dimensions we lost when computing the mean. 
- for _ in range(1, x.dim() - 1): - avg_norm = avg_norm.unsqueeze(-1) + avg_norm = std.mean(dim=dims, keepdim=True) return x_normalized / avg_norm From d6b5ce6d157d8e48f31235a5d034918e15fcd2fd Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 11 Aug 2023 11:22:20 +0000 Subject: [PATCH 41/41] clearer commment for what template is used --- elk/training/burns_norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py index 5a0adc69..f116937e 100644 --- a/elk/training/burns_norm.py +++ b/elk/training/burns_norm.py @@ -10,7 +10,7 @@ def __init__(self, scale: bool = True): self.scale: bool = scale def forward(self, x: Tensor) -> Tensor: - """Normalizes per template + """Normalizes per prompt template Args: x: input of dimension (n, v, c, d) or (n, v, d) Returns: