From c72351a63567c74d0493257b4c43289ac86f63b9 Mon Sep 17 00:00:00 2001 From: oliveradk Date: Wed, 21 Aug 2024 13:20:59 -0700 Subject: [PATCH 1/2] option for using raw scores rather than absolute value scores --- auto_circuit/prune.py | 10 +++++++--- auto_circuit/utils/tensor_ops.py | 13 +++++++++---- auto_circuit/visualize.py | 14 +++++++++++--- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/auto_circuit/prune.py b/auto_circuit/prune.py index da1370d..5d16d80 100644 --- a/auto_circuit/prune.py +++ b/auto_circuit/prune.py @@ -30,6 +30,7 @@ def run_circuits( patch_type: PatchType = PatchType.EDGE_PATCH, ablation_type: AblationType = AblationType.RESAMPLE, reverse_clean_corrupt: bool = False, + use_abs: bool = True, render_graph: bool = False, render_score_threshold: bool = False, render_file_path: Optional[str] = None, @@ -45,6 +46,7 @@ def run_circuits( patch_type: Whether to patch the circuit or the complement. ablation_type: The type of ablation to use. reverse_clean_corrupt: Reverse clean and corrupt (for input and patches). + use_abs: Whether to use the absolute value of the scores. render_graph: Whether to render the graph using `draw_seq_graph`. render_score_threshold: Edge score threshold, if `render_graph` is `True`. render_file_path: Path to save the rendered graph, if `render_graph` is `True`. @@ -83,19 +85,20 @@ def run_circuits( with patch_mode(model, patch_src_outs): for edge_count in (edge_pbar := tqdm(test_edge_counts)): edge_pbar.set_description_str(f"Running Circuit: {edge_count} Edges") - threshold = prune_scores_threshold(desc_ps, edge_count) + threshold = prune_scores_threshold(desc_ps, edge_count, use_abs) # When prune_scores are tied we can't prune exactly edge_count edges patch_edge_count = 0 for mod_name, patch_mask in prune_scores.items(): dest = module_by_name(model, mod_name) assert isinstance(dest, PatchWrapper) assert dest.is_dest and dest.patch_mask is not None + patch_mask = patch_mask.abs() if use_abs else patch_mask if patch_type == PatchType.EDGE_PATCH: - dest.patch_mask.data = (patch_mask.abs() >= threshold).float() + dest.patch_mask.data = (patch_mask >= threshold).float() patch_edge_count += dest.patch_mask.int().sum().item() else: assert patch_type == PatchType.TREE_PATCH - dest.patch_mask.data = (patch_mask.abs() < threshold).float() + dest.patch_mask.data = (patch_mask < threshold).float() patch_edge_count += (1 - dest.patch_mask.int()).sum().item() with t.inference_mode(): model_output = model(batch_input)[model.out_slice] @@ -107,6 +110,7 @@ def run_circuits( show_all_seq_pos=False, seq_labels=dataloader.seq_labels, file_path=render_file_path, + use_abs=use_abs, ) del patch_src_outs return circ_outs diff --git a/auto_circuit/utils/tensor_ops.py b/auto_circuit/utils/tensor_ops.py index 597c6f8..61da173 100644 --- a/auto_circuit/utils/tensor_ops.py +++ b/auto_circuit/utils/tensor_ops.py @@ -233,22 +233,26 @@ def flat_prune_scores(prune_scores: PruneScores) -> t.Tensor: return t.cat([ps.flatten() for _, ps in prune_scores.items()]) -def desc_prune_scores(prune_scores: PruneScores) -> t.Tensor: +def desc_prune_scores(prune_scores: PruneScores, use_abs: bool = True) -> t.Tensor: """ Flatten the prune scores into a single, 1-dimensional tensor and sort them in descending order. Args: prune_scores: The prune scores to flatten and sort. + use_abs: Whether to sort the absolute values of the prune scores. Returns: The flattened and sorted prune scores. """ - return flat_prune_scores(prune_scores).abs().sort(descending=True).values + flat_ps = flat_prune_scores(prune_scores) + if use_abs: + flat_ps = flat_ps.abs() + return flat_ps.sort(descending=True).values def prune_scores_threshold( - prune_scores: PruneScores | t.Tensor, edge_count: int + prune_scores: PruneScores | t.Tensor, edge_count: int, use_abs: bool = True ) -> t.Tensor: """ Return the minimum absolute value of the top `edge_count` prune scores. @@ -257,6 +261,7 @@ def prune_scores_threshold( Args: prune_scores: The prune scores to threshold. edge_count: The number of edges that should be above the threshold. + use_abs: Whether to use the absolute values of the prune scores. Returns: The threshold value. @@ -268,4 +273,4 @@ def prune_scores_threshold( assert prune_scores.ndim == 1 return prune_scores[edge_count - 1] else: - return desc_prune_scores(prune_scores)[edge_count - 1] + return desc_prune_scores(prune_scores, use_abs=use_abs)[edge_count - 1] diff --git a/auto_circuit/visualize.py b/auto_circuit/visualize.py index 62f8ece..c807f1b 100644 --- a/auto_circuit/visualize.py +++ b/auto_circuit/visualize.py @@ -43,6 +43,7 @@ def net_viz( seq_idx: Optional[int] = None, score_threshold: float = 1e-2, layer_spacing: bool = False, + use_abs: bool = True, orientation: Literal["h", "v"] = "h", ) -> Tuple[go.Sankey, int]: """ @@ -71,6 +72,8 @@ def net_viz( layer_spacing: If `True`, all nodes are spaced according to the layer they in. Otherwise, the Plotly automatic spacing is used and nodes in later layers may appear to the left of nodes in earlier layers. + use_abs: If `True`, the absolute value of the edge scores will be used to threshold + the edges. If `False`, the raw edge scores will be used. orientation: The orientation of the sankey diagram. Can be either `"h"` for horizontal or `"v"` for vertical. @@ -113,7 +116,7 @@ def net_viz( edge_score = prune_scores[e.dest.module_name][e.patch_idx].item() lbl = None - if abs(edge_score) < score_threshold: + if (abs(edge_score) if use_abs else edge_score) < score_threshold: continue color_idx = len(sources) % len(COLOR_PALETTE) @@ -202,6 +205,7 @@ def draw_seq_graph( show_all_seq_pos: bool = False, seq_labels: Optional[List[str]] = None, layer_spacing: bool = False, + use_abs: bool = True, orientation: Literal["h", "v"] = "h", display_ipython: bool = True, file_path: Optional[str] = None, @@ -222,6 +226,8 @@ def draw_seq_graph( prune_scores: The edge scores to use for the visualization. If `None`, the current activations and mask values of the model will be visualized instead. score_threshold: The minimum _absolute_ edge score to show in the diagram. + use_abs: If `True`, the absolute value of the edge scores will be used to threshold + the edges. If `False`, the raw edge scores will be used. show_all_seq_pos: If `True`, the diagram will show all token positions, even if they have no non-zero edge values. If `False`, only token positions with non-zero edge values will be shown. @@ -244,12 +250,13 @@ def draw_seq_graph( edge_scores = model.current_patch_masks_as_prune_scores().values() else: edge_scores = prune_scores.values() - ps = [t.clamp(v.abs() - score_threshold, min=0).sum().item() for v in edge_scores] + ps = [t.clamp((v.abs() if use_abs else v) - score_threshold, min=0).sum().item() for v in edge_scores] total_ps = max(sum(ps), 1e-2) if seq_len > 1: sankey_heights: Dict[Optional[int], float] = {} for patch_mask in edge_scores: - ps_seq_tots = t.clamp(patch_mask.abs() - score_threshold, min=0.0) + patch_mask = patch_mask.abs() if use_abs else patch_mask + ps_seq_tots = t.clamp(patch_mask - score_threshold, min=0.0) ps_seq_tots = ps_seq_tots.sum(dim=list(range(1, patch_mask.ndim))) for seq_idx, ps_seq_tot in enumerate(ps_seq_tots): if ps_seq_tot > 0 or show_all_seq_pos: @@ -288,6 +295,7 @@ def draw_seq_graph( score_threshold=score_threshold, layer_spacing=layer_spacing, orientation=orientation, + use_abs=use_abs, ) sankeys.append(viz) From d0400929bf82dabcb6ec59f862d9811ac034425c Mon Sep 17 00:00:00 2001 From: oliveradk Date: Wed, 21 Aug 2024 15:10:19 -0700 Subject: [PATCH 2/2] added missing use_abs to desc_prune_scores --- auto_circuit/prune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_circuit/prune.py b/auto_circuit/prune.py index 5d16d80..f1f2c48 100644 --- a/auto_circuit/prune.py +++ b/auto_circuit/prune.py @@ -58,7 +58,7 @@ def run_circuits( tensors. """ circ_outs: CircuitOutputs = defaultdict(dict) - desc_ps: t.Tensor = desc_prune_scores(prune_scores) + desc_ps: t.Tensor = desc_prune_scores(prune_scores, use_abs=use_abs) patch_src_outs: Optional[t.Tensor] = None if ablation_type.mean_over_dataset: