Skip to content

Commit

Permalink
feat(sae): change input format of forward method
Browse files Browse the repository at this point in the history
  • Loading branch information
frankstein authored and dest1n1s committed Jan 15, 2025
1 parent d9d4611 commit 9427334
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
13 changes: 3 additions & 10 deletions src/lm_saes/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def log_metric(metric: str, value: float) -> None:
  reconstructed = (
  log_info.pop("reconstructed")[useful_token_mask]
  if "reconstructed" in log_info
- else sae.forward({sae.cfg.hook_point_in: activation_in})
+ else sae.forward(activation_in)
  )

# 3. Compute sparsity metrics
Expand Down Expand Up @@ -158,18 +158,11 @@ def _evaluate_tokens(
)
  reconstructed_activations: Tensor | None = None
  if isinstance(sae.cfg, SAEConfig):
- reconstructed_activations = sae.forward(cache).to(
- cache[sae.cfg.hook_point_out].dtype
- ) # shape: (seq_len, d_model)
+ reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in])

  elif isinstance(sae.cfg, MixCoderConfig):
  assert isinstance(sae, MixCoder)
- reconstructed_activations = sae.forward(
- {
- sae.cfg.hook_point_in: cache[sae.cfg.hook_point_in],
- "tokens": input_ids,
- }
- ).to(cache[sae.cfg.hook_point_out].dtype)
+ reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids)

  assert reconstructed_activations is not None

Expand Down
8 changes: 6 additions & 2 deletions src/lm_saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,16 @@ def decode(

  def forward(
  self,
- batch: dict[str, torch.Tensor],
+ x: Union[
+ Float[torch.Tensor, "batch d_model"],
+ Float[torch.Tensor, "batch seq_len d_model"],
+ ],
+ **kwargs,
  ) -> Union[
  Float[torch.Tensor, "batch d_model"],
  Float[torch.Tensor, "batch seq_len d_model"],
  ]:
- feature_acts = self.encode(batch[self.cfg.hook_point_in])
+ feature_acts = self.encode(x)
  reconstructed = self.decode(feature_acts)
  return reconstructed

Expand Down

0 comments on commit 9427334

Please sign in to comment.