From 6e07819099c44f0ca00c197186e32d27fc1115c9 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Tue, 14 Nov 2023 19:53:03 +0000 Subject: [PATCH] default meanonly norm --- ccs/training/ccs_reporter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ccs/training/ccs_reporter.py b/ccs/training/ccs_reporter.py index ce20fee..11ed842 100644 --- a/ccs/training/ccs_reporter.py +++ b/ccs/training/ccs_reporter.py @@ -43,7 +43,7 @@ class CcsConfig(FitterConfig): function 1.0*consistency_squared + 0.5*prompt_var. """ loss_dict: dict[str, float] = field(default_factory=dict, init=False) - norm: Literal["leace", "burns"] = "leace" + norm: Literal["leace", "burns", "meanonly"] = "meanonly" num_layers: int = 1 """The number of layers in the MLP.""" pre_ln: bool = False @@ -209,6 +209,8 @@ def fit(self, hiddens: Tensor) -> float: if self.config.norm == "burns": self.norm = BurnsNorm() + elif self.config.norm == "meanonly": + self.norm = BurnsNorm(scale=False) else: fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device) fitter.update(