Merge branch 'main' of https://github.com/OpenMOSS/Language-Model-SAEs into main
Showing 28 changed files with 697 additions and 797 deletions.
.github/workflows/checks.yml (new file, +52 lines)
name: Checks

on:
  push:
    branches:
      - main
      - dev
    paths:
      - "**" # Include all files by default
      - "!.devcontainer/**"
      - "!.vscode/**"
      - "!.git*"
      - "!*.md"
      - "!.github/**"
      - ".github/workflows/checks.yml" # Still include current workflow
  pull_request:
    branches:
      - main
      - dev
    paths:
      - "**"
      - "!.devcontainer/**"
      - "!.vscode/**"
      - "!.git*"
      - "!*.md"
      - "!.github/**"
      - ".github/workflows/checks.yml"
  # Allow this workflow to be called from other workflows
  workflow_call:
    inputs:
      # Requires at least one input to be valid, but in practice we don't need any
      dummy:
        type: string
        required: false

permissions:
  actions: write
  contents: write

jobs:
  code-checks:
    name: Code Checks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Setup PDM
        uses: pdm-project/setup-pdm@v4
        # You are now able to use PDM in your workflow
      - name: Install dependencies
        run: pdm install
      - name: Type check
        run: pdm run mypy .
New file (+35 lines): CUDA benchmarks for HookedTransformer.offload_params_after and HookedTransformer.run_with_cache_until.
from transformer_lens import HookedTransformer
import torch

MODEL = "solu-2l"

def time_diff(func):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    func()

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event)

@torch.no_grad()
def test_offload_params_after():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")
    allocated_before = torch.cuda.memory_allocated(0)
    model.offload_params_after("blocks.0.hook_resid_post", torch.tensor([[0]], device="cuda"))
    allocated_after = torch.cuda.memory_allocated(0)
    assert allocated_after < allocated_before * 0.55

@torch.no_grad()
def test_run_with_cache_until():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")
    def forward():
        model.run_with_cache("Hello, world!", names_filter=["blocks.0.hook_resid_post"])
    forward_time = time_diff(forward)
    def forward_until():
        model.run_with_cache_until("Hello, world!", names_filter=["blocks.0.hook_resid_post"])
    forward_fake_time = time_diff(forward_until)
    assert forward_fake_time < forward_time * 0.7
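Taken together, these benchmarks imply the intended workflow: offload_params_after drops the parameters that are no longer needed once the forward pass has produced a given hook's activation (the test expects allocated memory to fall below 55% of the baseline), and run_with_cache_until skips the rest of the forward pass once all requested activations are cached (the test expects at least a ~30% speedup). A minimal usage sketch, assuming only the call signatures exercised by the tests above:

# Usage sketch based only on the calls exercised in the tests above; the hook
# name, prompt, and model are illustrative, and the method semantics are
# inferred from the benchmark assertions rather than from documentation.
import torch
from transformer_lens import HookedTransformer

HOOK = "blocks.0.hook_resid_post"

model = HookedTransformer.from_pretrained("solu-2l", device="cuda")

# Free parameters that are not needed to compute activations up to HOOK.
model.offload_params_after(HOOK, torch.tensor([[0]], device="cuda"))

with torch.no_grad():
    # Run the model only as far as HOOK and cache that activation.
    _, cache = model.run_with_cache_until("Hello, world!", names_filter=[HOOK])

resid = cache[HOOK]  # residual-stream activation after block 0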
New file (+52 lines): unit tests for HookedRootModule.run_with_cache_until and HookedRootModule.offload_params_after on a small toy hooked module.
from transformer_lens.hook_points import HookedRootModule, HookPoint
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.subblock1 = nn.Linear(10, 10)
        self.subblock2 = nn.Linear(10, 10)
        self.activation = nn.ReLU()
        self.hook_mid = HookPoint()

    def forward(self, x):
        return self.subblock2(self.hook_mid(self.activation(self.subblock1(x))))

class TestModule(HookedRootModule):
    __test__ = False
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.ModuleList([Block() for _ in range(3)])
        self.embed = nn.Linear(1, 10)
        self.unembed = nn.Linear(10, 1)
        self.setup()

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)

def test_run_with_cache_until():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    out, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])

    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], out)

def test_offload_params_after():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]))

    model.offload_params_after("blocks.1.hook_mid", torch.tensor([1.]))
    assert model.blocks[0].subblock1.weight is not None
    assert model.blocks[1].subblock1.weight is not None
    assert model.blocks[2].subblock1.weight is None
    assert model.unembed.weight is None

    _, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])
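The assertions above pin down the observable contract: activations up to the last requested hook match run_with_cache, and the returned output is the activation at that final hook rather than the full model output. For readers unfamiliar with this kind of early exit, a generic way to get the same behaviour with plain PyTorch forward hooks is sketched below; the StopForward exception and run_until helper are hypothetical names used for illustration, not the implementation added in this commit.

# Generic early-exit sketch with plain PyTorch hooks. StopForward and run_until
# are hypothetical; they are not part of this repository or of transformer_lens.
import torch
import torch.nn as nn

class StopForward(Exception):
    """Raised inside a hook to abort the remainder of the forward pass."""

def run_until(model: nn.Module, stop_module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    cache = {}

    def hook(module, inputs, output):
        cache["out"] = output
        raise StopForward  # layers after stop_module never run

    handle = stop_module.register_forward_hook(hook)
    try:
        model(x)
    except StopForward:
        pass
    finally:
        handle.remove()
    return cache["out"]

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.Linear(8, 1))
mid = run_until(model, model[1], torch.randn(2, 4))  # stop right after the ReLU
print(mid.shape)  # torch.Size([2, 8])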