Add head only version #76

Open
wants to merge 13 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
ims/
out.txt
out2.txt
out1.txt
57 changes: 37 additions & 20 deletions acdc/TLACDCCorrespondence.py
@@ -98,7 +98,7 @@ def remove_edge(
child.parents.remove(parent)

@classmethod
def setup_from_model(cls, model, use_pos_embed=False):
def setup_from_model(cls, model, use_pos_embed=False, use_split_qkv=True):
correspondence = cls()

downstream_residual_nodes: List[TLACDCInterpNode] = []
@@ -167,34 +167,51 @@ def setup_from_model(cls, model, use_pos_embed=False):
safe=False,
)

for letter in "qkv":
hook_letter_name = f"blocks.{layer_idx}.attn.hook_{letter}"
hook_letter_slice = TorchIndex([None, None, head_idx])
hook_letter_node = TLACDCInterpNode(name=hook_letter_name, index=hook_letter_slice, incoming_edge_type=EdgeType.DIRECT_COMPUTATION)
correspondence.add_node(hook_letter_node)
if use_split_qkv:
for letter in "qkv":
hook_letter_name = f"blocks.{layer_idx}.attn.hook_{letter}"
hook_letter_slice = TorchIndex([None, None, head_idx])
hook_letter_node = TLACDCInterpNode(name=hook_letter_name, index=hook_letter_slice, incoming_edge_type=EdgeType.DIRECT_COMPUTATION)
correspondence.add_node(hook_letter_node)

hook_letter_input_name = f"blocks.{layer_idx}.hook_{letter}_input"
hook_letter_input_slice = TorchIndex([None, None, head_idx])
hook_letter_input_node = TLACDCInterpNode(
name=hook_letter_input_name, index=hook_letter_input_slice, incoming_edge_type=EdgeType.ADDITION
)
correspondence.add_node(hook_letter_input_node)
hook_letter_input_name = f"blocks.{layer_idx}.hook_{letter}_input"
hook_letter_input_slice = TorchIndex([None, None, head_idx])
hook_letter_input_node = TLACDCInterpNode(
name=hook_letter_input_name, index=hook_letter_input_slice, incoming_edge_type=EdgeType.ADDITION
)
correspondence.add_node(hook_letter_input_node)

correspondence.add_edge(
parent_node = hook_letter_node,
child_node = cur_head,
edge = Edge(edge_type=EdgeType.PLACEHOLDER),
safe = False,
)

correspondence.add_edge(
parent_node=hook_letter_input_node,
child_node=hook_letter_node,
edge=Edge(edge_type=EdgeType.DIRECT_COMPUTATION),
safe=False,
)

new_downstream_residual_nodes.append(hook_letter_input_node)

else:
hook_head_name = f"blocks.{layer_idx}.hook_attn_in"
hook_head_slice = TorchIndex([None, None, head_idx])
hook_head_node = TLACDCInterpNode(name=hook_head_name, index=hook_head_slice, incoming_edge_type=EdgeType.ADDITION)
correspondence.add_node(hook_head_node)

correspondence.add_edge(
parent_node = hook_letter_node,
parent_node = hook_head_node,
child_node = cur_head,
edge = Edge(edge_type=EdgeType.PLACEHOLDER),
safe = False,
)

correspondence.add_edge(
parent_node=hook_letter_input_node,
child_node=hook_letter_node,
edge=Edge(edge_type=EdgeType.DIRECT_COMPUTATION),
safe=False,
)
new_downstream_residual_nodes.append(hook_head_node)

new_downstream_residual_nodes.append(hook_letter_input_node)
downstream_residual_nodes.extend(new_downstream_residual_nodes)

if use_pos_embed:
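Note (a sketch, not part of the diff): with `use_split_qkv=True` each attention head gets three per-head input nodes (`hook_q_input`, `hook_k_input`, `hook_v_input`) feeding per-head `hook_q`/`hook_k`/`hook_v` nodes, while with `use_split_qkv=False` the head gets a single `hook_attn_in` node instead. A minimal way to build both correspondences, assuming a TransformerLens install that exposes `set_use_attn_in` (currently the `arthur-add-attn-in` branch):

```python
from transformer_lens.HookedTransformer import HookedTransformer
from acdc.TLACDCCorrespondence import TLACDCCorrespondence

model = HookedTransformer.from_pretrained("gpt2")
model.set_use_attn_result(True)
if "use_hook_mlp_in" in model.cfg.to_dict():
    model.set_use_hook_mlp_in(True)

# Split-QKV graph: per-head hook_{q,k,v}_input nodes feed per-head hook_{q,k,v} nodes.
model.set_use_split_qkv_input(True)
corr_split = TLACDCCorrespondence.setup_from_model(model, use_split_qkv=True)

# Head-only graph: a single hook_attn_in node per head replaces the three QKV input nodes.
model.set_use_attn_in(True)  # assumption: only available on the arthur-add-attn-in branch
corr_head_only = TLACDCCorrespondence.setup_from_model(model, use_split_qkv=False)
```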
26 changes: 19 additions & 7 deletions acdc/TLACDCExperiment.py
@@ -83,6 +83,7 @@ def __init__(
names_mode: Literal["normal", "reverse", "shuffle"] = "normal",
wandb_config: Optional[Namespace] = None,
early_exit: bool = False,
use_split_qkv: bool = True,
):
"""Initialize the ACDC experiment"""

@@ -91,6 +92,7 @@ def __init__(

model.reset_hooks()

self.use_split_qkv = use_split_qkv
self.remove_redundant = remove_redundant
self.indices_mode = indices_mode
self.names_mode = names_mode
@@ -110,7 +112,7 @@ def __init__(
warnings.warn("Never skipping edges, for now")
skip_edges = "no"

self.corr = TLACDCCorrespondence.setup_from_model(self.model, use_pos_embed=use_pos_embed)
self.corr = TLACDCCorrespondence.setup_from_model(self.model, use_pos_embed=use_pos_embed, use_split_qkv=self.use_split_qkv)

if early_exit:
return
@@ -186,7 +188,16 @@ def verify_model_setup(self):
if not self.model.cfg.attn_only and "use_hook_mlp_in" in self.model.cfg.to_dict():
assert self.model.cfg.use_hook_mlp_in, "Need to be able to see hook MLP inputs"
assert self.model.cfg.use_attn_result, "Need to be able to see split by head outputs"
assert self.model.cfg.use_split_qkv_input, "Need to be able to see split by head QKV inputs"

if self.use_split_qkv:
assert self.model.cfg.use_split_qkv_input, "Need to be able to see split by head QKV inputs"
else:
try:
assert self.model.cfg.use_attn_in, "Need to be able to see the per-head attention inputs (hook_attn_in)"
except AttributeError:
raise Exception("You need the `use_attn_in` version of the TransformerLens library, available at https://github.com/ArthurConmy/TransformerLens/tree/arthur-add-attn-in (this may have been merged into the main TransformerLens branch by the time you read this)")

def update_cur_metric(self, recalc_metric=True, recalc_edges=True, initial=False):
if recalc_metric:
@@ -644,6 +655,7 @@ def step(self, early_stop=False, testing=False):
print("Removing redundant node", self.current_node)
self.remove_redundant_node(self.current_node)

# TODO add back
if is_this_node_used and self.current_node.incoming_edge_type.value != EdgeType.PLACEHOLDER.value:
fname = f"ims/img_new_{self.step_idx}.png"
show(
@@ -776,10 +788,6 @@ def count_no_edges(self, verbose=False) -> int:
print("No edge", cnt)
return cnt

def reload_hooks(self):
old_corr = self.corr
self.corr = TLACDCCorrespondence.setup_from_model(self.model)

def save_subgraph(self, fpath: Optional[str]=None, return_it=False) -> None:
"""Saves the subgraph as a Dictionary of all the edges, so it can be reloaded (or return that)"""

@@ -806,7 +814,11 @@ def load_subgraph(self, subgraph: Subgraph):
receiver_name, receiver_torch_index, sender_name, sender_torch_index = tupl
receiver_index, sender_index = receiver_torch_index.hashable_tuple, sender_torch_index.hashable_tuple
set_of_edges.add((receiver_name, receiver_index, sender_name, sender_index))
assert set(subgraph.keys()) == set_of_edges, f"Ensure that the dictionary includes exactly the correct keys... e.g missing {list( set(set_of_edges) - set(subgraph.keys()) )[:1]} and has excess stuff { list(set(subgraph.keys()) - set_of_edges)[:1] }"

assert set(subgraph.keys()) <= set_of_edges, f"The saved subgraph may omit edges, but every key must be an edge of the current graph... e.g. it has excess key { list(set(subgraph.keys()) - set_of_edges)[:1] }"
if set(subgraph.keys()) != set_of_edges:
for edge in set_of_edges - set(subgraph.keys()):
subgraph[edge] = False

print("Editing all edges...")
for (receiver_name, receiver_index, sender_name, sender_index), is_present in subgraph.items():
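Note on the `load_subgraph` change above: the old code required the saved dictionary to list every edge exactly, while the new code only requires that every key be a real edge and fills in any edge the dictionary omits as not present. A small self-contained illustration of that fill-in behaviour (the edge keys below are made-up placeholders, not real hook names):

```python
# Each key is a (receiver_name, receiver_index, sender_name, sender_index) tuple.
all_edges = {
    ("recv_a", (None,), "send_x", (None,)),
    ("recv_a", (None,), "send_y", (None,)),
}
subgraph = {("recv_a", (None,), "send_x", (None,)): True}  # omits the second edge

assert set(subgraph.keys()) <= all_edges  # keys must still be real edges
for edge in all_edges - set(subgraph.keys()):
    subgraph[edge] = False  # unspecified edges default to "not present"

assert subgraph[("recv_a", (None,), "send_y", (None,))] is False
```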
13 changes: 8 additions & 5 deletions acdc/acdc_graphics.py
@@ -61,8 +61,8 @@ def get_node_name(node: TLACDCInterpNode, show_full_index=True):
relevant_letter = letter
name += "a" + node.name.split(".")[1] + "." + str(node.index.hashable_tuple[2]) + "_" + relevant_letter

# Handle attention hook_result
elif "hook_result" in node.name or any([qkv_substring in node.name for qkv_substring in qkv_substrings]):
# Handle attention hook_result or hook_attn_in
elif "hook_result" in node.name or any([qkv_substring in node.name for qkv_substring in qkv_substrings]) or node.name.endswith("attn_in"):
name = "a" + node.name.split(".")[1] + "." + str(node.index.hashable_tuple[2])

# Handle MLPs
@@ -76,7 +76,8 @@ def get_node_name(node: TLACDCInterpNode, show_full_index=True):
name += "resid_post"

else:
raise ValueError(f"Unrecognized node name {node.name}")
# TODO add a warning here? Names may be cursed
name = node.name

if show_full_index:
name += f"_{str(node.index.graphviz_index())}"
@@ -98,6 +99,7 @@ def show(
show_full_index: bool = True,
remove_self_loops: bool = True,
remove_qkv: bool = False,
show_effect_size_none: bool = False,
) -> pgv.AGraph:
"""
Colorscheme: a color for each node name, or a string corresponding to a cmapy color scheme
@@ -130,7 +132,7 @@
# Important this go after the qkv removal
continue

if edge.present and edge.effect_size is not None and edge.edge_type != EdgeType.PLACEHOLDER:
if edge.present and (show_effect_size_none or edge.effect_size is not None) and edge.edge_type != EdgeType.PLACEHOLDER:
for node_name in [parent_name, child_name]:
g.add_node(
node_name,
@@ -140,10 +142,11 @@
fontname="Helvetica"
)

cur_effect_size = edge.effect_size if edge.effect_size is not None else 0
g.add_edge(
parent_name,
child_name,
penwidth=str(max(minimum_penwidth, edge.effect_size)),
penwidth=str(max(minimum_penwidth, cur_effect_size)),
color=colors[parent_name],
)

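Usage sketch for the new `show_effect_size_none` flag (assuming `exp` is an already-built `TLACDCExperiment`; the output path is just an example): edges that are present but whose `effect_size` is still `None` are now drawn at `minimum_penwidth` instead of being dropped from the figure.

```python
from acdc.acdc_graphics import show

g = show(
    exp.corr,                    # assumes exp is a TLACDCExperiment built elsewhere
    "ims/full_graph.png",        # example output path
    show_full_index=False,
    show_effect_size_none=True,  # include present edges with no measured effect size yet
)
```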
4 changes: 4 additions & 0 deletions acdc/docstring/utils.py
@@ -60,6 +60,10 @@ def get_docstring_model(device="cuda"):
)
tl_model.set_use_attn_result(True)
tl_model.set_use_split_qkv_input(True)
try:
tl_model.set_use_attn_in(True)
except AttributeError:
raise Exception("You need the `use_attn_in` version of the TransformerLens library, available at https://github.com/ArthurConmy/TransformerLens/tree/arthur-add-attn-in (this may have been merged into the main TransformerLens branch by the time you read this)")
if "use_hook_mlp_in" in tl_model.cfg.to_dict(): # not strictly necessary, but good practice to keep compatibility with new *optional* transformerlens feature
tl_model.set_use_hook_mlp_in(True)
tl_model.to(device)
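The `try`/`except AttributeError` guard here (and the matching guards in the induction and IOI loaders below) probes whether the installed TransformerLens exposes `set_use_attn_in`, which currently only exists on the `arthur-add-attn-in` branch. A small alternative capability check, offered as a sketch rather than as part of this PR:

```python
def supports_attn_in(tl_model) -> bool:
    """True if this TransformerLens build exposes the per-head attention-input setter."""
    return hasattr(tl_model, "set_use_attn_in")

# Example: fail early with a pointer to the required branch.
if not supports_attn_in(tl_model):
    raise RuntimeError(
        "Install TransformerLens from "
        "https://github.com/ArthurConmy/TransformerLens/tree/arthur-add-attn-in "
        "to use the head-only (hook_attn_in) graph."
    )
tl_model.set_use_attn_in(True)
```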
6 changes: 5 additions & 1 deletion acdc/induction/utils.py
@@ -47,7 +47,11 @@ def get_model(device):

# standard ACDC options
tl_model.set_use_attn_result(True)
tl_model.set_use_split_qkv_input(True)
try:
tl_model.set_use_attn_in(True)
except AttributeError:
raise Exception("You need the `use_attn_in` version of the TransformerLens library, available at https://github.com/ArthurConmy/TransformerLens/tree/arthur-add-attn-in (this may have been merged into the main TransformerLens branch by the time you read this)")
tl_model.set_use_split_qkv_input(True)
if "use_hook_mlp_in" in tl_model.cfg.to_dict(): # not strictly necessary, but good practice to keep compatibility with new *optional* transformerlens feature
tl_model.set_use_hook_mlp_in(True)
return tl_model
18 changes: 12 additions & 6 deletions acdc/ioi/utils.py
@@ -22,21 +22,27 @@
import wandb
from transformer_lens.HookedTransformer import HookedTransformer

def get_gpt2_small(device="cuda") -> HookedTransformer:
def get_gpt2_small(device="cuda", split_qkv: bool = True) -> HookedTransformer:
tl_model = HookedTransformer.from_pretrained("gpt2")
tl_model = tl_model.to(device)
tl_model.set_use_attn_result(True)
tl_model.set_use_split_qkv_input(True)
if split_qkv:
tl_model.set_use_split_qkv_input(True)
else:
try:
tl_model.set_use_attn_in(True)
except AttributeError:
raise Exception("You need the `use_attn_in` version of the TransformerLens library, available at https://github.com/ArthurConmy/TransformerLens/tree/arthur-add-attn-in (this may have been merged into the main TransformerLens branch by the time you read this)")
if "use_hook_mlp_in" in tl_model.cfg.to_dict():
tl_model.set_use_hook_mlp_in(True)
return tl_model

def get_ioi_gpt2_small(device="cuda"):
def get_ioi_gpt2_small(device="cuda", split_qkv: bool = True):
"""For backwards compat"""
return get_gpt2_small(device=device)
return get_gpt2_small(device=device, split_qkv=split_qkv)

def get_all_ioi_things(num_examples, device, metric_name, kl_return_one_element=True):
tl_model = get_gpt2_small(device=device)
def get_all_ioi_things(num_examples, device, metric_name, kl_return_one_element=True, split_qkv: bool = True):
tl_model = get_gpt2_small(device=device, split_qkv=split_qkv)
ioi_dataset = IOIDataset(
prompt_type="ABBA",
N=num_examples*2,
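Usage sketch for the `split_qkv` flag threaded through the IOI helpers (argument values match those used elsewhere in this PR; `things.tl_model` is the loaded GPT-2 model):

```python
from acdc.ioi.utils import get_all_ioi_things

# Head-only mode: GPT-2 is configured with use_attn_in instead of split-by-head
# QKV inputs (requires the arthur-add-attn-in TransformerLens branch).
things = get_all_ioi_things(
    num_examples=100,
    device="cuda",
    metric_name="kl_div",
    split_qkv=False,
)
tl_model = things.tl_model
```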
25 changes: 18 additions & 7 deletions acdc/main.py
@@ -9,6 +9,7 @@
# <p>Janky code to do different setup when run in a Colab notebook vs VSCode (adapted from e.g <a href="https://github.com/neelnanda-io/TransformerLens/blob/5c89b7583e73ce96db5e46ef86a14b15f303dde6/demos/Activation_Patching_in_TL_Demo.ipynb">this notebook</a>)</p>

#%%

try:
import google.colab

@@ -163,24 +164,30 @@
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument("--max-num-epochs",type=int, default=100_000)
parser.add_argument('--single-step', action='store_true', help='Use single step, mostly for testing')
parser.add_argument("--dont-split-qkv", action="store_true", help="Dont splits qkv")
parser.add_argument("--abs-value-threshold", action='store_true', help='Use the absolute value of the result to check threshold')

if ipython is not None:
# we are in a notebook
# you can put the command you would like to run as the ... in r"""..."""
args = parser.parse_args(
[line.strip() for line in r"""--task=induction\
--zero-ablation\
--threshold=0.71\
--threshold=0.0175\
--metric=kl_div\
--indices-mode=reverse\
--first-cache-cpu=False\
--second-cache-cpu=False\
--max-num-epochs=100000""".split("\\\n")]
)
--max-num-epochs=100000\
--dont-split-qkv\
--using-wandb""".split("\\\n")]
)

else:
# read from command line
args = parser.parse_args()

print(args)

# process args

if args.torch_num_threads > 0:
@@ -216,6 +223,7 @@
DEVICE = args.device
RESET_NETWORK = args.reset_network
SINGLE_STEP = True if args.single_step else False
SPLIT_QKV = not args.dont_split_qkv

#%% [markdown]
# <h2>Setup Task</h2>
@@ -227,7 +235,7 @@
if TASK == "ioi":
num_examples = 100
things = get_all_ioi_things(
num_examples=num_examples, device=DEVICE, metric_name=args.metric
num_examples=num_examples, device=DEVICE, metric_name=args.metric, split_qkv=SPLIT_QKV,
)
elif TASK == "tracr-reverse":
num_examples = 6
@@ -337,22 +345,23 @@
add_receiver_hooks=False,
remove_redundant=False,
show_full_index=use_pos_embed,
use_split_qkv=SPLIT_QKV,
)

# %% [markdown]
# <h2>Run steps of ACDC: iterate over a NODE in the model's computational graph</h2>
# <p>WARNING! This will take a few minutes to run, but there should be rolling nice pictures too : )</p>

#%%

for i in range(args.max_num_epochs):
exp.step(testing=False)

# TODO add back
show(
exp.corr,
f"ims/img_new_{i+1}.png",
show_full_index=use_pos_embed,
)

if IN_COLAB or ipython is not None:
# so long as we're not running this as a script, show the image!
display(Image(f"ims/img_new_{i+1}.png"))
Expand Down Expand Up @@ -387,3 +396,5 @@
exp.save_subgraph(
return_it=True,
)

#%% HELLO
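For completeness, a sketch of exercising the new flag from the command line (assuming the script is run from the repository root with the attn-in TransformerLens branch installed; the other flag values are only examples):

```python
import subprocess

# Equivalent shell command:
#   python acdc/main.py --task=ioi --threshold=0.71 --metric=kl_div --dont-split-qkv
subprocess.run(
    [
        "python", "acdc/main.py",
        "--task=ioi",
        "--threshold=0.71",
        "--metric=kl_div",
        "--dont-split-qkv",  # new in this PR: build the head-only (hook_attn_in) graph
    ],
    check=True,
)
```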
17 changes: 8 additions & 9 deletions notebooks/roc_plot_generator.py
@@ -165,9 +165,7 @@ def get_col(df, col): # dumb util
parser.add_argument("--ignore-missing-score", action="store_true", help="Ignore runs that are missing score")

if IPython.get_ipython() is not None:
args = parser.parse_args("--task=tracr-reverse --metric=l2 --alg=acdc".split())
if "arthur" not in __file__:
__file__ = "/Users/adria/Documents/2023/ACDC/Automatic-Circuit-Discovery/notebooks/roc_plot_generator.py"
args = parser.parse_args("--task ioi --mode edges --metric logit_diff --alg acdc".split())
else:
args = parser.parse_args()

Expand Down Expand Up @@ -208,7 +206,7 @@ def get_col(df, col): # dumb util
OUT_FILE = OUT_DIR / f"{args.alg}-{args.task}-{args.metric}-{args.zero_ablation}-{args.reset_network}.json"

if OUT_FILE.exists():
print("File already exists, skipping")
print(f"File {str(OUT_FILE)} already exists, skipping")
sys.exit(0)
else:
OUT_FILE = None
@@ -316,11 +314,12 @@ def get_col(df, col): # dumb util
except KeyError:
pass
ACDC_PRE_RUN_FILTER = {
"$or": [
{"group": "reset-networks-neurips", **ACDC_PRE_RUN_FILTER},
{"group": "acdc-gt-ioi-redo", **ACDC_PRE_RUN_FILTER},
{"group": "acdc-spreadsheet2", **ACDC_PRE_RUN_FILTER},
]
"id": "0wfrojop",
# "$or": [
# {"group": "reset-networks-neurips", **ACDC_PRE_RUN_FILTER},
# {"group": "acdc-gt-ioi-redo", **ACDC_PRE_RUN_FILTER},
# {"group": "acdc-spreadsheet2", **ACDC_PRE_RUN_FILTER},
# ]
}

get_true_edges = partial(get_ioi_true_edges, model=things.tl_model)