Skip to content

Commit

Permalink
default meanonly norm
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Nov 14, 2023
1 parent e751405 commit 6e07819
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion ccs/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class CcsConfig(FitterConfig):
function 1.0*consistency_squared + 0.5*prompt_var.
"""
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
norm: Literal["leace", "burns"] = "leace"
norm: Literal["leace", "burns", "meanonly"] = "meanonly"
num_layers: int = 1
"""The number of layers in the MLP."""
pre_ln: bool = False
Expand Down Expand Up @@ -209,6 +209,8 @@ def fit(self, hiddens: Tensor) -> float:

if self.config.norm == "burns":
self.norm = BurnsNorm()
elif self.config.norm == "meanonly":
self.norm = BurnsNorm(scale=False)
else:
fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(
Expand Down

0 comments on commit 6e07819

Please sign in to comment.