Commit

Merge pull request #90 from OpenMOSS/update-source-dtype-fix
fix(activation): preserve tokens type during dtype conversion
Hzfinfdu authored Feb 14, 2025
2 parents 83acd31 + 6b84c49 commit 743ff47
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/lm_saes/activation/processors/cached_activation.py
@@ -237,7 +237,9 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:
                 device=self.device,
             )
             if self.dtype is not None:
-                activations = {k: v.to(self.dtype) for k, v in activations.items()}
+                for k, v in activations.items():
+                    if k in self.hook_points:
+                        activations[k] = v.to(self.dtype)
             yield activations
 # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
 # I wrote this utils function as I notice it is used multiple times in this repo. Do we need to apply it elsewhere?

@@ -259,4 +261,4 @@ def __getitem__(self, chunk_idx):
         return self.activation_loader.load_chunk_for_hooks(
             chunk_idx,
             self.hook_chunks,
-        )
+        )
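
The first hunk can be illustrated with a minimal sketch of what it appears to do: cast only the entries whose key is a hook point, and leave everything else (such as integer token IDs) at its original dtype. The names `FakeTensor`, `cast_hook_activations`, and the hook point `blocks.0.hook_resid_post` are hypothetical and not taken from the repository. `FakeTensor` is a stand-in that only tracks a dtype name so the example runs without PyTorch; with real tensors, `v.to(dtype)` would be `torch.Tensor.to`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Iterable, Optional


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: it only records a dtype name."""

    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        # Like Tensor.to(dtype), return a converted copy instead of mutating.
        return FakeTensor(dtype=dtype)


def cast_hook_activations(
    activations: dict[str, Any],
    hook_points: Iterable[str],
    dtype: Optional[Any],
) -> dict[str, Any]:
    """Cast only hook-point entries to `dtype`; pass other entries through."""
    if dtype is None:
        # No target dtype configured: return an unchanged shallow copy.
        return dict(activations)
    hooks = set(hook_points)
    return {k: (v.to(dtype) if k in hooks else v) for k, v in activations.items()}


# Assumed batch layout: one hook-point activation plus integer token IDs.
batch = {
    "blocks.0.hook_resid_post": FakeTensor("float32"),
    "tokens": FakeTensor("int64"),
}
converted = cast_hook_activations(batch, ["blocks.0.hook_resid_post"], "bfloat16")
print({k: v.dtype for k, v in converted.items()})
# → {'blocks.0.hook_resid_post': 'bfloat16', 'tokens': 'int64'}
```

The committed code updates the dictionary in place while iterating over `items()`, which is safe in Python because only values of existing keys are reassigned. The sketch builds a new dictionary instead, so the input batch is left untouched.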
