diff --git a/comparison-sweeps b/comparison-sweeps
new file mode 160000
index 00000000..f4ed884b
--- /dev/null
+++ b/comparison-sweeps
@@ -0,0 +1 @@
+Subproject commit f4ed884b59c99012c80b972d2a02c660b39c90cb
diff --git a/elk/training/burns_norm.py b/elk/training/burns_norm.py
index f9925f2f..f116937e 100644
--- a/elk/training/burns_norm.py
+++ b/elk/training/burns_norm.py
@@ -10,7 +10,7 @@ def __init__(self, scale: bool = True):
         self.scale: bool = scale
 
     def forward(self, x: Tensor) -> Tensor:
-        """Normalizes per template
+        """Normalizes per prompt template
         Args:
             x: input of dimension (n, v, c, d) or (n, v, d)
         Returns:
@@ -22,9 +22,7 @@ def forward(self, x: Tensor) -> Tensor:
         if not self.scale:
             return x_normalized
         else:
-            std = torch.linalg.norm(x_normalized, dim=0) / torch.sqrt(
-                torch.tensor(x_normalized.shape[0], dtype=torch.float32)
-            )
+            std = torch.linalg.norm(x_normalized, dim=0) / x_normalized.shape[0] ** 0.5
             assert std.dim() == x.dim() - 1
 
             # Compute the dimensions over which
@@ -33,15 +31,6 @@ def forward(self, x: Tensor) -> Tensor:
             # which is the template dimension
             dims = tuple(range(1, std.dim()))
 
-            avg_norm = std.mean(dim=dims)
-
-            # Add a singleton dimension at the beginning to allow broadcasting.
-            # This compensates for the dimension we lost when computing the norm.
-            avg_norm = avg_norm.unsqueeze(0)
-
-            # Add singleton dimensions at the end to allow broadcasting.
-            # This compensates for the dimensions we lost when computing the mean.
-            for _ in range(1, x.dim() - 1):
-                avg_norm = avg_norm.unsqueeze(-1)
+            avg_norm = std.mean(dim=dims, keepdim=True)
 
             return x_normalized / avg_norm
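Both hunks are behavior-preserving refactors: dividing the column norms by `n ** 0.5` equals dividing by `torch.sqrt(torch.tensor(n, dtype=torch.float32))`, and `keepdim=True` yields shape `(v, 1, 1)`, which broadcasts against `(n, v, c, d)` exactly like the manually unsqueezed `(1, v, 1, 1)`. A standalone sanity-check sketch (not part of the diff; the concrete sizes are arbitrary and the shapes are assumed from the docstring):

```python
import torch

n, v, c, d = 128, 3, 2, 16
x = torch.randn(n, v, c, d)
x_normalized = x - x.mean(dim=0)

# Old std: divide by sqrt(n) materialized as a float32 tensor.
std_old = torch.linalg.norm(x_normalized, dim=0) / torch.sqrt(
    torch.tensor(x_normalized.shape[0], dtype=torch.float32)
)
# New std: plain Python float exponent, same value.
std_new = torch.linalg.norm(x_normalized, dim=0) / x_normalized.shape[0] ** 0.5
torch.testing.assert_close(std_old, std_new)

# Old broadcasting: mean over all dims but v, then manually
# re-insert singleton dims -> shape (1, v, 1, 1).
dims = tuple(range(1, std_old.dim()))
avg_old = std_old.mean(dim=dims).unsqueeze(0)
for _ in range(1, x.dim() - 1):
    avg_old = avg_old.unsqueeze(-1)

# New broadcasting: keepdim=True -> shape (v, 1, 1).
avg_new = std_new.mean(dim=dims, keepdim=True)
torch.testing.assert_close(x_normalized / avg_old, x_normalized / avg_new)
```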