
Commit

Merge branch 'main' into experiments-notodai-no-probe-per-prompt
lauritowal committed Aug 11, 2023
2 parents 12277ec + 7e60fa7 commit 6284e74
Showing 2 changed files with 4 additions and 14 deletions.
1 change: 1 addition & 0 deletions comparison-sweeps
Submodule comparison-sweeps added at f4ed88
17 changes: 3 additions & 14 deletions elk/training/burns_norm.py
@@ -10,7 +10,7 @@ def __init__(self, scale: bool = True):
         self.scale: bool = scale

     def forward(self, x: Tensor) -> Tensor:
-        """Normalizes per template
+        """Normalizes per prompt template
         Args:
             x: input of dimension (n, v, c, d) or (n, v, d)
         Returns:
@@ -22,9 +22,7 @@ def forward(self, x: Tensor) -> Tensor:
         if not self.scale:
             return x_normalized
         else:
-            std = torch.linalg.norm(x_normalized, dim=0) / torch.sqrt(
-                torch.tensor(x_normalized.shape[0], dtype=torch.float32)
-            )
+            std = torch.linalg.norm(x_normalized, dim=0) / x_normalized.shape[0] ** 0.5
             assert std.dim() == x.dim() - 1

             # Compute the dimensions over which
@@ -33,15 +31,6 @@ def forward(self, x: Tensor) -> Tensor:
             # which is the template dimension
             dims = tuple(range(1, std.dim()))

-            avg_norm = std.mean(dim=dims)
-
-            # Add a singleton dimension at the beginning to allow broadcasting.
-            # This compensates for the dimension we lost when computing the norm.
-            avg_norm = avg_norm.unsqueeze(0)
-
-            # Add singleton dimensions at the end to allow broadcasting.
-            # This compensates for the dimensions we lost when computing the mean.
-            for _ in range(1, x.dim() - 1):
-                avg_norm = avg_norm.unsqueeze(-1)
+            avg_norm = std.mean(dim=dims, keepdim=True)

             return x_normalized / avg_norm
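The simplified `std` line is a pure refactor: for mean-centered activations, the L2 norm along the batch dimension divided by sqrt(n) is exactly the uncorrected (population) standard deviation, so dividing by `x_normalized.shape[0] ** 0.5` matches the old `torch.sqrt(torch.tensor(...))` form without allocating a throwaway 0-d tensor. A minimal sketch checking the equivalence (the shapes below are hypothetical, not taken from the repo):

import torch

n, v, c, d = 8, 3, 2, 16        # hypothetical (n, v, c, d) batch
x = torch.randn(n, v, c, d)
x_normalized = x - x.mean(dim=0)

# Old expression: builds a 0-d tensor just to take its square root.
old = torch.linalg.norm(x_normalized, dim=0) / torch.sqrt(
    torch.tensor(x_normalized.shape[0], dtype=torch.float32)
)
# New expression: plain Python sqrt of the batch size.
new = torch.linalg.norm(x_normalized, dim=0) / x_normalized.shape[0] ** 0.5

assert torch.allclose(old, new)
# Both equal the uncorrected standard deviation over the batch dimension.
assert torch.allclose(new, x.std(dim=0, unbiased=False), atol=1e-6)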
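The `keepdim=True` change removes the manual unsqueeze bookkeeping: keeping the reduced dimensions as singletons yields a shape that broadcasts against `x_normalized` exactly like the old result (the leading singleton the old code inserted is redundant for broadcasting). A small sketch under the same assumed shapes:

import torch

n, v, c, d = 8, 3, 2, 16           # hypothetical (n, v, c, d) shape
x_dim = 4                          # x.dim() for a (n, v, c, d) input
std = torch.rand(v, c, d) + 0.1    # stands in for std, which has shape x.shape[1:]
dims = tuple(range(1, std.dim()))  # (1, 2): average over everything except v

# Old approach: reduce, then manually re-insert singleton dimensions.
avg_old = std.mean(dim=dims)          # shape (v,)
avg_old = avg_old.unsqueeze(0)        # shape (1, v)
for _ in range(1, x_dim - 1):
    avg_old = avg_old.unsqueeze(-1)   # final shape (1, v, 1, 1)

# New approach: keepdim=True keeps the reduced dims as size-1 axes.
avg_new = std.mean(dim=dims, keepdim=True)  # shape (v, 1, 1)

# Same values; both broadcast identically against a (n, v, c, d) tensor.
assert torch.allclose(avg_old.squeeze(0), avg_new)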
