
Commit

…into main
Hzfinfdu committed Jun 13, 2024
2 parents 2d54a91 + e62876a commit 40e795d
Showing 28 changed files with 697 additions and 797 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/checks.yml
@@ -0,0 +1,52 @@
name: Checks

on:
  push:
    branches:
      - main
      - dev
    paths:
      - "**" # Include all files by default
      - "!.devcontainer/**"
      - "!.vscode/**"
      - "!.git*"
      - "!*.md"
      - "!.github/**"
      - ".github/workflows/checks.yml" # Still include current workflow
  pull_request:
    branches:
      - main
      - dev
    paths:
      - "**"
      - "!.devcontainer/**"
      - "!.vscode/**"
      - "!.git*"
      - "!*.md"
      - "!.github/**"
      - ".github/workflows/checks.yml"
  # Allow this workflow to be called from other workflows
  workflow_call:
    inputs:
      # Requires at least one input to be valid, but in practice we don't need any
      dummy:
        type: string
        required: false

permissions:
  actions: write
  contents: write

jobs:
  code-checks:
    name: Code Checks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Setup PDM
        uses: pdm-project/setup-pdm@v4
        # You are now able to use PDM in your workflow
      - name: Install dependencies
        run: pdm install
      - name: Type check
        run: pdm run mypy .
4 changes: 2 additions & 2 deletions README.md
@@ -3,7 +3,7 @@
This repo aims to provide a general codebase for conducting dictionary-learning-based mechanistic interpretability research on Language Models (LMs). It powers a configurable pipeline for training and evaluating GPT-2 dictionaries, and provides a set of tools (mainly a React-based webpage) for analyzing and visualizing the learned dictionaries.

The design of the pipeline (including the configuration and some training detail) is highly inspired by the [mats_sae_training
-](https://github.com/jbloomAus/mats_sae_training) project. We thank the authors for their great work.
+](https://github.com/jbloomAus/mats_sae_training) project and heavily relies on the [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) library. We thank the authors for their great work.

## Getting Started with Mechanistic Interpretability and Dictionary Learning

@@ -16,7 +16,7 @@ If you are new to the concept of mechanistic interpretability and dictionary lea
- [Towards Monosemanticity: Decomposing Language Models With Dictionary Learning](https://transformer-circuits.pub/2023/monosemantic-features/index.html)
- [Sparse Autoencoders Find Highly Interpretable Features in Language Models](https://arxiv.org/abs/2309.08600)

-Furthermore, to dive deeper into the inner activations of LMs, it's recommended to get familiar with the [TransformerLens](https://github.com/neelnanda-io/TransformerLens/tree/main) library.
+Furthermore, to dive deeper into the inner activations of LMs, it's recommended to get familiar with the [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) library.

## Installation

35 changes: 35 additions & 0 deletions TransformerLens/tests/acceptance/test_offloading.py
@@ -0,0 +1,35 @@
from transformer_lens import HookedTransformer
import torch

MODEL = "solu-2l"

def time_diff(func):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    func()

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event)

@torch.no_grad()
def test_offload_params_after():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")
    allocated_before = torch.cuda.memory_allocated(0)
    model.offload_params_after("blocks.0.hook_resid_post", torch.tensor([[0]], device="cuda"))
    allocated_after = torch.cuda.memory_allocated(0)
    assert allocated_after < allocated_before * 0.55

@torch.no_grad()
def test_run_with_cache_until():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")
    def forward():
        model.run_with_cache("Hello, world!", names_filter=["blocks.0.hook_resid_post"])
    forward_time = time_diff(forward)
    def forward_until():
        model.run_with_cache_until("Hello, world!", names_filter=["blocks.0.hook_resid_post"])
    forward_fake_time = time_diff(forward_until)
    assert forward_fake_time < forward_time * 0.7

52 changes: 52 additions & 0 deletions TransformerLens/tests/integration/test_offloading.py
@@ -0,0 +1,52 @@
from transformer_lens.hook_points import HookedRootModule, HookPoint
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.subblock1 = nn.Linear(10, 10)
        self.subblock2 = nn.Linear(10, 10)
        self.activation = nn.ReLU()
        self.hook_mid = HookPoint()

    def forward(self, x):
        return self.subblock2(self.hook_mid(self.activation(self.subblock1(x))))

class TestModule(HookedRootModule):
    __test__ = False
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.ModuleList([Block() for _ in range(3)])
        self.embed = nn.Linear(1, 10)
        self.unembed = nn.Linear(10, 1)
        self.setup()

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)

def test_run_with_cache_until():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    out, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])

    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], out)

def test_offload_params_after():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]))

    model.offload_params_after("blocks.1.hook_mid", torch.tensor([1.]))
    assert model.blocks[0].subblock1.weight is not None
    assert model.blocks[1].subblock1.weight is not None
    assert model.blocks[2].subblock1.weight is None
    assert model.unembed.weight is None

    _, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])
126 changes: 123 additions & 3 deletions TransformerLens/transformer_lens/hook_points.py
@@ -26,7 +26,7 @@
import torch.nn as nn
import torch.utils.hooks as hooks

-from transformer_lens.utils import Slice, SliceInput
+from transformer_lens.utils import Slice, SliceInput, set_nested_attr


@dataclass
@@ -51,8 +51,7 @@ class LensHandle:
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions."""

-    def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]:
-        ...
+    def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]: ...


HookFunction = _HookFunctionProtocol # Callable[..., _HookFunctionProtocol]
@@ -776,5 +775,126 @@ def run_with_ref_cache(

        return model_out, cache_dict

    def offload_params_after(self, last_hook: str, *model_args, **model_kwargs):
        """
        Set parameters that are not used after a certain hook to None.
        This does not guarantee that all parameters are offloaded, but it should offload most of them.
        Specifically, the direct parameters of the ancestor modules of the last hook are not offloaded,
        since there is no way to know whether they are used before or after the last hook.

        Args:
            last_hook (str): The name of the last hook.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model.
        """
        pass_module_list: List[nn.Module] = []
        fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
        hook_handles: List[hooks.RemovableHandle] = []

        def pass_hook(module: nn.Module, module_input: Any, module_output: Any):
            # Record every module whose forward pass completes before the last hook fires.
            pass_module_list.append(module)

        def convert_hook(tensor: torch.Tensor, hook: HookPoint):
            # Keep the direct parameters of modules that already ran and of the hook's
            # ancestor modules; set every other parameter to None.
            pass_param_set = set()
            hook_ancestors = [module for module in self.modules() if module == self or hook.name.startswith(module.name)]
            for module in pass_module_list + hook_ancestors:
                for param_name, parameters in module.named_parameters():
                    if "." not in param_name:
                        name = f"{module.name}.{param_name}" if module != self else param_name
                        pass_param_set.add(name)

            fake_param_set = set([name for name, _ in self.named_parameters()]).difference(pass_param_set)

            for name in fake_param_set:
                set_nested_attr(self, name, None)

            # Abort the (fake) forward pass; everything needed has been recorded.
            raise StopIteration

        for _, module in self.named_modules():
            hook_handles.append(module.register_forward_hook(pass_hook))

        # Trace the model with fake tensors so no real computation or memory is used.
        with fake_mode:
            with self.hooks(fwd_hooks=[(last_hook, convert_hook)]):
                try:
                    self(*model_args, **model_kwargs)
                except StopIteration:
                    pass

        for handle in hook_handles:
            handle.remove()

    def run_with_cache_until(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        until: str = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            until (str, optional): The name of the hook to stop caching at. Defaults to None, which means
                stop caching at the last hook.
            device (str or torch.Device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice (Slice or SliceInput, optional): The slice to apply to the cache output. Defaults to
                None (no slicing).
            **model_kwargs: Keyword arguments for the model.

        Returns:
            tuple: A tuple containing the model output and a Cache object.
        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, _ = self.get_caching_hooks(
            names_filter,
            False,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        if until is None:
            until = fwd[-1][0]

        class ModuleStop(Exception):
            def __init__(self, tensor: torch.Tensor):
                self.tensor = tensor

        def stop_hook(tensor: torch.Tensor, hook: HookPoint):
            # Raise to stop the forward pass once the `until` hook has been reached.
            if hook.name == until:
                raise ModuleStop(tensor)

        with self.hooks(
            fwd_hooks=fwd + [(until, stop_hook)],
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            try:
                model_out = self(*model_args, **model_kwargs)
            except ModuleStop as e:
                # The activation at `until` is returned as the output of the truncated run.
                model_out = e.tensor

        return model_out, cache_dict


# %%
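For orientation, here is a minimal usage sketch that combines the two methods added above, following the patterns in the new tests of this commit. It is not part of the diff; the model name ("gpt2"), the prompt, and the hook point ("blocks.3.hook_mlp_out") are illustrative assumptions.

from transformer_lens import HookedTransformer

# Hypothetical example; model, prompt, and hook point are assumptions for illustration.
model = HookedTransformer.from_pretrained("gpt2")

# Set parameters that are only needed after the hook to None, freeing memory.
model.offload_params_after("blocks.3.hook_mlp_out", "Hello, world!")

# Run the forward pass only up to the hook and cache its activation;
# later layers (whose parameters were offloaded) are never executed.
out, cache = model.run_with_cache_until(
    "Hello, world!",
    names_filter="blocks.3.hook_mlp_out",
    until="blocks.3.hook_mlp_out",
)
mlp_out = cache["blocks.3.hook_mlp_out"]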
2 changes: 1 addition & 1 deletion examples/analyze.py
@@ -27,7 +27,7 @@
hook_points = ["blocks.3.hook_mlp_out"],

# SAEConfig
-**SAEConfig.get_hyperparameters("L3M", "results", "final.pt"), # Load the hyperparameters from the trained model.
+**SAEConfig.from_pretrained("result/L3M").to_dict(), # Load the hyperparameters from the trained model.

# LanguageModelSAEAnalysisConfig
total_analyzing_tokens = 20_000_000,
2 changes: 1 addition & 1 deletion examples/train.py
@@ -26,7 +26,7 @@
store_batch_size = 32, # The batch size for loading the corpus.

# ActivationStoreConfig
hook_points = ["blocks.3.hook_mlp_out"], # The hook point to extract the activations, i.e. the layer output of which is used for training/evaluating the dictionary.
hook_points = ["blocks.3.hook_mlp_out"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`)
n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary.

