Skip to content

Commit

Permalink
Merge pull request #258 from EleutherAI/add_per_prompt_norm
Browse files Browse the repository at this point in the history
Add per prompt norm
  • Loading branch information
lauritowal authored Aug 11, 2023
2 parents d10b4c2 + d6b5ce6 commit 7e60fa7
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 18 deletions.
1 change: 1 addition & 0 deletions comparison-sweeps
Submodule comparison-sweeps added at f4ed88
36 changes: 36 additions & 0 deletions elk/training/burns_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from torch import Tensor, nn


class BurnsNorm(nn.Module):
    """Normalization in the style of Burns et al., applied per prompt template.

    The input is centered over its leading dimension. When ``scale`` is
    enabled, every template slice is additionally divided by that template's
    average per-feature root-mean-square value.
    """

    def __init__(self, scale: bool = True):
        super().__init__()
        # When False, forward() only centers and never rescales.
        self.scale: bool = scale

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` separately for each prompt template.

        Args:
            x: input of dimension (n, v, c, d) or (n, v, d), where dim 1 (v)
                is the template dimension.

        Returns:
            The centered tensor, divided by a per-template scalar when
            ``self.scale`` is set. The shape matches ``x``.
        """
        # Only center when there is more than one example along dim 0; a lone
        # example minus its own mean would be all zeros.
        if x.shape[0] > 1:
            centered = x - x.mean(dim=0)
        else:
            centered = x

        if not self.scale:
            return centered

        # Root-mean-square over dim 0 (the standard deviation whenever the
        # data was centered above). Reducing dim 0 drops one dimension.
        rms = torch.linalg.norm(centered, dim=0) / centered.shape[0] ** 0.5
        assert rms.dim() == x.dim() - 1

        # `rms` now leads with the template dimension v; average over every
        # remaining dimension so each template ends up with a single scalar,
        # kept broadcastable against `centered` via keepdim.
        trailing_dims = tuple(range(1, rms.dim()))
        per_template_scale = rms.mean(dim=trailing_dims, keepdim=True)

        return centered / per_template_scale
38 changes: 25 additions & 13 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ..parsing import parse_loss
from ..utils.typing import assert_type
from .burns_norm import BurnsNorm
from .common import FitterConfig
from .losses import LOSSES
from .platt_scaling import PlattMixin
Expand Down Expand Up @@ -41,6 +42,7 @@ class CcsConfig(FitterConfig):
function 1.0*consistency_squared + 0.5*prompt_var.
"""
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
norm: Literal["leace", "burns"] = "leace"
num_layers: int = 1
"""The number of layers in the MLP."""
pre_ln: bool = False
Expand Down Expand Up @@ -105,6 +107,7 @@ def __init__(
device=device,
),
)

if cfg.pre_ln:
self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False))

Expand Down Expand Up @@ -163,7 +166,13 @@ def forward(self, x: Tensor) -> Tensor:
assert self.norm is not None, "Must call fit() before forward()"

raw_scores = self.probe(self.norm(x)).squeeze(-1)
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
if self.config.norm == "leace":
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)

elif self.config.norm == "burns":
return raw_scores
else:
raise ValueError(f"Unknown normalization {self.config.norm}.")

def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
"""Return the loss of the reporter on the contrast pair (x0, x1).
Expand Down Expand Up @@ -193,18 +202,21 @@ def fit(self, hiddens: Tensor) -> float:
n, v, d = x_neg.shape
prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1)

fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(
x=x_neg,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
)
fitter.update(
x=x_pos,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
)
self.norm = fitter.eraser
if self.config.norm == "burns":
self.norm = BurnsNorm()
else:
fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(
x=x_neg,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
)
fitter.update(
x=x_pos,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
)
self.norm = fitter.eraser

x_neg, x_pos = self.norm(x_neg), self.norm(x_pos)

Expand Down
11 changes: 6 additions & 5 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ def apply_to_layer(
reporter = CcsReporter(self.net, d, device=device, num_variants=v)
train_loss = reporter.fit(first_train_h)

(_, v, k, _) = first_train_h.shape
reporter.platt_scale(
to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
rearrange(first_train_h, "n v k d -> (n v k) d"),
)
if not self.net.norm == "burns":
(_, v, k, _) = first_train_h.shape
reporter.platt_scale(
to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
rearrange(first_train_h, "n v k d -> (n v k) d"),
)

elif isinstance(self.net, EigenFitterConfig):
fitter = EigenFitter(
Expand Down
40 changes: 40 additions & 0 deletions tests/test_burns_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from torch import Tensor

from elk.training.burns_norm import BurnsNorm


def correct_but_slow_normalization(x_all: Tensor, scale=True) -> Tensor:
    """Reference normalization that handles one prompt template at a time.

    Each slice along dim 1 of ``x_all`` is centered over its leading
    dimension (skipped when that dimension has at most one element) and,
    if ``scale`` is truthy, divided by the mean of its per-feature
    root-mean-square values.

    Args:
        x_all: tensor whose dim 1 indexes the slices to normalize
            independently; the tests in this file pass (n, v, d) and
            (n, v, c, d) inputs.
        scale: whether to rescale each slice after centering.

    Returns:
        The normalized slices stacked back along dim 1, so the result has
        the same shape as ``x_all``.
    """
    normalized = []

    for x in x_all.unbind(dim=1):
        centered: Tensor = x - x.mean(dim=0) if x.shape[0] > 1 else x
        if scale:
            # Divide by a plain Python float so the divisor is not pinned to
            # float32 precision when the input is float64.
            std = torch.linalg.norm(centered, dim=0) / centered.shape[0] ** 0.5
            centered = centered / std.mean()
        normalized.append(centered)

    return torch.stack(normalized, dim=1)


def test_BurnsNorm_3d_input():
    """BurnsNorm on an (n, v, d) input agrees with the per-template reference."""
    x_all_3d = torch.randn((2, 13, 768))
    expected_output_3d = correct_but_slow_normalization(x_all_3d)
    bn = BurnsNorm()
    output_3d = bn(x_all_3d)
    # The two implementations reduce in different orders, so compare with a
    # floating-point tolerance instead of demanding bitwise equality.
    torch.testing.assert_close(output_3d, expected_output_3d)


def test_BurnsNorm_4d_input():
    """BurnsNorm on an (n, v, c, d) input agrees with the per-template reference."""
    x_all_4d = torch.randn((2, 13, 2, 768))
    expected_output_4d = correct_but_slow_normalization(x_all_4d)
    bn = BurnsNorm()
    output_4d = bn(x_all_4d)
    # The two implementations reduce in different orders, so compare with a
    # floating-point tolerance instead of demanding bitwise equality.
    torch.testing.assert_close(output_4d, expected_output_4d)

0 comments on commit 7e60fa7

Please sign in to comment.