From 942733482968837846878ff1bc8b3fb919d25796 Mon Sep 17 00:00:00 2001 From: frankstein Date: Wed, 15 Jan 2025 21:08:24 +0800 Subject: [PATCH] feat(sae): change input format of forward method --- src/lm_saes/evaluator.py | 13 +++---------- src/lm_saes/sae.py | 8 ++++++-- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/lm_saes/evaluator.py b/src/lm_saes/evaluator.py index f010ee1..c58ce48 100644 --- a/src/lm_saes/evaluator.py +++ b/src/lm_saes/evaluator.py @@ -59,7 +59,7 @@ def log_metric(metric: str, value: float) -> None: reconstructed = ( log_info.pop("reconstructed")[useful_token_mask] if "reconstructed" in log_info - else sae.forward({sae.cfg.hook_point_in: activation_in}) + else sae.forward(activation_in) ) # 3. Compute sparsity metrics @@ -158,18 +158,11 @@ def _evaluate_tokens( ) reconstructed_activations: Tensor | None = None if isinstance(sae.cfg, SAEConfig): - reconstructed_activations = sae.forward(cache).to( - cache[sae.cfg.hook_point_out].dtype - ) # shape: (seq_len, d_model) + reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in]) elif isinstance(sae.cfg, MixCoderConfig): assert isinstance(sae, MixCoder) - reconstructed_activations = sae.forward( - { - sae.cfg.hook_point_in: cache[sae.cfg.hook_point_in], - "tokens": input_ids, - } - ).to(cache[sae.cfg.hook_point_out].dtype) + reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids) assert reconstructed_activations is not None diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 059e5a5..359e0f2 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -384,12 +384,16 @@ def decode( def forward( self, - batch: dict[str, torch.Tensor], + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + **kwargs, ) -> Union[ Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"], ]: - feature_acts = self.encode(batch[self.cfg.hook_point_in]) + feature_acts = self.encode(x) reconstructed = self.decode(feature_acts) return reconstructed