diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 6ce3e951..a321387a 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -95,19 +95,17 @@ def __init__( def reset_parameters(self): """Reset the parameters of the probe. - Mathematically equivalent to the unusual initialization scheme used in the - original paper. They sample a random Gaussian vector of dim in_features + 1, - normalize to the unit sphere, then add an extra all-ones dimension to the - input and compute the inner product. Here, we use nn.Linear with an explicit - bias term, but use the same initialization. - If init is "spherical", use the spherical initialization scheme. If init is "default", use the default PyTorch initialization scheme for nn.Linear (Kaiming uniform). If init is "zero", initialize all parameters to zero. """ - if self.init == "spherical": + # Mathematically equivalent to the unusual initialization scheme used in + # the original paper. They sample a Gaussian vector of dim in_features + 1, + # normalize to the unit sphere, then add an extra all-ones dimension to the + # input and compute the inner product. Here, we use nn.Linear with an + # explicit bias term, but use the same initialization. assert len(self.probe) == 1, "Only linear probes can use spherical init" probe = cast(nn.Linear, self.probe[0]) # Pylance gets the type wrong here