Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

option for using raw scores rather than absolute value scores #10

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions auto_circuit/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def run_circuits(
patch_type: PatchType = PatchType.EDGE_PATCH,
ablation_type: AblationType = AblationType.RESAMPLE,
reverse_clean_corrupt: bool = False,
use_abs: bool = True,
render_graph: bool = False,
render_score_threshold: bool = False,
render_file_path: Optional[str] = None,
Expand All @@ -45,6 +46,7 @@ def run_circuits(
patch_type: Whether to patch the circuit or the complement.
ablation_type: The type of ablation to use.
reverse_clean_corrupt: Reverse clean and corrupt (for input and patches).
use_abs: Whether to use the absolute value of the scores.
render_graph: Whether to render the graph using `draw_seq_graph`.
render_score_threshold: Edge score threshold, if `render_graph` is `True`.
render_file_path: Path to save the rendered graph, if `render_graph` is `True`.
Expand All @@ -56,7 +58,7 @@ def run_circuits(
tensors.
"""
circ_outs: CircuitOutputs = defaultdict(dict)
desc_ps: t.Tensor = desc_prune_scores(prune_scores)
desc_ps: t.Tensor = desc_prune_scores(prune_scores, use_abs=use_abs)

patch_src_outs: Optional[t.Tensor] = None
if ablation_type.mean_over_dataset:
Expand All @@ -83,19 +85,20 @@ def run_circuits(
with patch_mode(model, patch_src_outs):
for edge_count in (edge_pbar := tqdm(test_edge_counts)):
edge_pbar.set_description_str(f"Running Circuit: {edge_count} Edges")
threshold = prune_scores_threshold(desc_ps, edge_count)
threshold = prune_scores_threshold(desc_ps, edge_count, use_abs)
# When prune_scores are tied we can't prune exactly edge_count edges
patch_edge_count = 0
for mod_name, patch_mask in prune_scores.items():
dest = module_by_name(model, mod_name)
assert isinstance(dest, PatchWrapper)
assert dest.is_dest and dest.patch_mask is not None
patch_mask = patch_mask.abs() if use_abs else patch_mask
if patch_type == PatchType.EDGE_PATCH:
dest.patch_mask.data = (patch_mask.abs() >= threshold).float()
dest.patch_mask.data = (patch_mask >= threshold).float()
patch_edge_count += dest.patch_mask.int().sum().item()
else:
assert patch_type == PatchType.TREE_PATCH
dest.patch_mask.data = (patch_mask.abs() < threshold).float()
dest.patch_mask.data = (patch_mask < threshold).float()
patch_edge_count += (1 - dest.patch_mask.int()).sum().item()
with t.inference_mode():
model_output = model(batch_input)[model.out_slice]
Expand All @@ -107,6 +110,7 @@ def run_circuits(
show_all_seq_pos=False,
seq_labels=dataloader.seq_labels,
file_path=render_file_path,
use_abs=use_abs,
)
del patch_src_outs
return circ_outs
13 changes: 9 additions & 4 deletions auto_circuit/utils/tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,22 +233,26 @@ def flat_prune_scores(prune_scores: PruneScores) -> t.Tensor:
return t.cat([ps.flatten() for _, ps in prune_scores.items()])


def desc_prune_scores(prune_scores: PruneScores) -> t.Tensor:
def desc_prune_scores(prune_scores: PruneScores, use_abs: bool = True) -> t.Tensor:
"""
Flatten the prune scores into a single, 1-dimensional tensor and sort them in
descending order.

Args:
prune_scores: The prune scores to flatten and sort.
use_abs: Whether to sort the absolute values of the prune scores.

Returns:
The flattened and sorted prune scores.
"""
return flat_prune_scores(prune_scores).abs().sort(descending=True).values
flat_ps = flat_prune_scores(prune_scores)
if use_abs:
flat_ps = flat_ps.abs()
return flat_ps.sort(descending=True).values


def prune_scores_threshold(
prune_scores: PruneScores | t.Tensor, edge_count: int
prune_scores: PruneScores | t.Tensor, edge_count: int, use_abs: bool = True
) -> t.Tensor:
"""
Return the minimum absolute value of the top `edge_count` prune scores.
Expand All @@ -257,6 +261,7 @@ def prune_scores_threshold(
Args:
prune_scores: The prune scores to threshold.
edge_count: The number of edges that should be above the threshold.
use_abs: Whether to use the absolute values of the prune scores.

Returns:
The threshold value.
Expand All @@ -268,4 +273,4 @@ def prune_scores_threshold(
assert prune_scores.ndim == 1
return prune_scores[edge_count - 1]
else:
return desc_prune_scores(prune_scores)[edge_count - 1]
return desc_prune_scores(prune_scores, use_abs=use_abs)[edge_count - 1]
14 changes: 11 additions & 3 deletions auto_circuit/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def net_viz(
seq_idx: Optional[int] = None,
score_threshold: float = 1e-2,
layer_spacing: bool = False,
use_abs: bool = True,
orientation: Literal["h", "v"] = "h",
) -> Tuple[go.Sankey, int]:
"""
Expand Down Expand Up @@ -71,6 +72,8 @@ def net_viz(
layer_spacing: If `True`, all nodes are spaced according to the layer they in.
Otherwise, the Plotly automatic spacing is used and nodes in later layers
may appear to the left of nodes in earlier layers.
use_abs: If `True`, the absolute value of the edge scores will be used to threshold
the edges. If `False`, the raw edge scores will be used.
orientation: The orientation of the sankey diagram. Can be either `"h"` for
horizontal or `"v"` for vertical.

Expand Down Expand Up @@ -113,7 +116,7 @@ def net_viz(
edge_score = prune_scores[e.dest.module_name][e.patch_idx].item()
lbl = None

if abs(edge_score) < score_threshold:
if (abs(edge_score) if use_abs else edge_score) < score_threshold:
continue

color_idx = len(sources) % len(COLOR_PALETTE)
Expand Down Expand Up @@ -202,6 +205,7 @@ def draw_seq_graph(
show_all_seq_pos: bool = False,
seq_labels: Optional[List[str]] = None,
layer_spacing: bool = False,
use_abs: bool = True,
orientation: Literal["h", "v"] = "h",
display_ipython: bool = True,
file_path: Optional[str] = None,
Expand All @@ -222,6 +226,8 @@ def draw_seq_graph(
prune_scores: The edge scores to use for the visualization. If `None`, the
current activations and mask values of the model will be visualized instead.
score_threshold: The minimum _absolute_ edge score to show in the diagram.
use_abs: If `True`, the absolute value of the edge scores will be used to threshold
the edges. If `False`, the raw edge scores will be used.
show_all_seq_pos: If `True`, the diagram will show all token positions, even if
they have no non-zero edge values. If `False`, only token positions with
non-zero edge values will be shown.
Expand All @@ -244,12 +250,13 @@ def draw_seq_graph(
edge_scores = model.current_patch_masks_as_prune_scores().values()
else:
edge_scores = prune_scores.values()
ps = [t.clamp(v.abs() - score_threshold, min=0).sum().item() for v in edge_scores]
ps = [t.clamp((v.abs() if use_abs else v) - score_threshold, min=0).sum().item() for v in edge_scores]
total_ps = max(sum(ps), 1e-2)
if seq_len > 1:
sankey_heights: Dict[Optional[int], float] = {}
for patch_mask in edge_scores:
ps_seq_tots = t.clamp(patch_mask.abs() - score_threshold, min=0.0)
patch_mask = patch_mask.abs() if use_abs else patch_mask
ps_seq_tots = t.clamp(patch_mask - score_threshold, min=0.0)
ps_seq_tots = ps_seq_tots.sum(dim=list(range(1, patch_mask.ndim)))
for seq_idx, ps_seq_tot in enumerate(ps_seq_tots):
if ps_seq_tot > 0 or show_all_seq_pos:
Expand Down Expand Up @@ -288,6 +295,7 @@ def draw_seq_graph(
score_threshold=score_threshold,
layer_spacing=layer_spacing,
orientation=orientation,
use_abs=use_abs,
)
sankeys.append(viz)

Expand Down