diff --git a/auto_circuit/experiments.py b/auto_circuit/experiments.py deleted file mode 100644 index 8e8c800..0000000 --- a/auto_circuit/experiments.py +++ /dev/null @@ -1,272 +0,0 @@ -#%% -from collections import defaultdict -from pathlib import Path -from typing import List - -import torch as t - -from auto_circuit.metrics.completeness_metrics.same_under_knockouts import ( - TaskCompletenessScores, - measure_same_under_knockouts, - same_under_knockouts_fig, -) -from auto_circuit.metrics.completeness_metrics.train_same_under_knockouts import ( - train_same_under_knockouts, -) -from auto_circuit.metrics.official_circuits.measure_roc import measure_roc -from auto_circuit.metrics.official_circuits.roc_plot import roc_plot -from auto_circuit.metrics.prune_metrics.measure_prune_metrics import ( - measure_prune_metrics, - measurement_figs, -) -from auto_circuit.metrics.prune_metrics.prune_metrics import ( - ANSWER_PROB_METRIC, - CLEAN_KL_DIV_METRIC, - CORRECT_ANSWER_GREATER_THAN_INCORRECT_PERCENT_METRIC, - CORRECT_ANSWER_PERCENT_METRIC, - LOGIT_DIFF_PERCENT_METRIC, - PruneMetric, -) -from auto_circuit.metrics.prune_scores_similarity import prune_score_similarities_plotly -from auto_circuit.prune_algos.prune_algos import ( - CIRCUIT_TREE_PROBING_PRUNE_ALGO, - GROUND_TRUTH_PRUNE_ALGO, - LOGIT_DIFF_GRAD_PRUNE_ALGO, - PRUNE_ALGO_DICT, - RANDOM_PRUNE_ALGO, - PruneAlgo, - run_prune_algos, -) -from auto_circuit.tasks import ( - DOCSTRING_TOKEN_CIRCUIT_TASK, - SPORTS_PLAYERS_TOKEN_CIRCUIT_TASK, - TASK_DICT, - Task, -) -from auto_circuit.types import ( - AblationType, - PatchType, - TaskMeasurements, - TaskPruneScores, -) -from auto_circuit.utils.misc import load_cache, repo_path_to_abs_path, save_cache -from auto_circuit.utils.tensor_ops import prune_scores_threshold -from auto_circuit.visualize import draw_seq_graph - -figs = [] - -# ------------------------------------ Prune Scores ------------------------------------ - -compute_prune_scores = False -save_prune_scores = False -load_prune_scores = True - -task_prune_scores: TaskPruneScores = defaultdict(dict) -cache_folder_name = ".prune_scores_cache" -if compute_prune_scores: - TASKS: List[Task] = [ - # Token Circuits - # SPORTS_PLAYERS_TOKEN_CIRCUIT_TASK, - # IOI_TOKEN_CIRCUIT_TASK, - DOCSTRING_TOKEN_CIRCUIT_TASK, - # Component Circuits - # SPORTS_PLAYERS_COMPONENT_CIRCUIT_TASK, - # IOI_COMPONENT_CIRCUIT_TASK, - # DOCSTRING_COMPONENT_CIRCUIT_TASK, - # GREATERTHAN_COMPONENT_CIRCUIT_TASK, - # Autoencoder Component Circuits - # IOI_GPT2_AUTOENCODER_COMPONENT_CIRCUIT_TASK, - # GREATERTHAN_GPT2_AUTOENCODER_COMPONENT_CIRCUIT_TASK - # ANIMAL_DIET_GPT2_AUTOENCODER_COMPONENT_CIRCUIT_TASK, - # CAPITAL_CITIES_PYTHIA_70M_AUTOENCODER_COMPONENT_CIRCUIT_TASK, - ] - PRUNE_ALGOS: List[PruneAlgo] = [ - GROUND_TRUTH_PRUNE_ALGO, - # ACT_MAG_PRUNE_ALGO, - RANDOM_PRUNE_ALGO, - # EDGE_ATTR_PATCH_PRUNE_ALGO, - # ACDC_PRUNE_ALGO, - # INTEGRATED_EDGE_GRADS_LOGIT_DIFF_PRUNE_ALGO, - # LOGPROB_GRAD_PRUNE_ALGO, - # LOGPROB_DIFF_GRAD_PRUNE_ALGO, - LOGIT_DIFF_GRAD_PRUNE_ALGO, # Fast implementation of Edge Attribution Patchng - # LOGIT_MSE_GRAD_PRUNE_ALGO, - # SUBNETWORK_EDGE_PROBING_PRUNE_ALGO, - # CIRCUIT_PROBING_PRUNE_ALGO, - # SUBNETWORK_TREE_PROBING_PRUNE_ALGO, - CIRCUIT_TREE_PROBING_PRUNE_ALGO, - # MSE_CIRCUIT_TREE_PROBING_PRUNE_ALGO, - ] - task_prune_scores = run_prune_algos(TASKS, PRUNE_ALGOS) -if load_prune_scores: - # 2000 epoch IOI Docstring tensor prune_scores post-kv-cache-fix - # batch_size=128, batch_count=2, default seed (for both) - filename = "task-prune-scores-16-02-2024_23-27-49.pkl" - - # 1000 epoch Sport Players tensor prune_scores post-kv-cache-fix - # batch_size=(10, 20), batch_count=(10, 5), default seed - # filename = "task-prune-scores-16-02-2024_22-22-43.pkl" - - loaded_cache = load_cache(cache_folder_name, filename) - task_prune_scores = {k: v | task_prune_scores[k] for k, v in loaded_cache.items()} -if save_prune_scores: - base_filename = "task-prune-scores" - save_cache(task_prune_scores, cache_folder_name, base_filename) - -for task, algo_prune_scores in task_prune_scores.items(): - for algo, prune_scores in algo_prune_scores.items(): - for module_name, scores in prune_scores.items(): - # Convert dtype to float32 - task_prune_scores[task][algo][module_name] = scores.float() - -# docstring_key = DOCSTRING_TOKEN_CIRCUIT_TASK.key -# task_prune_scores = {docstring_key: task_prune_scores[docstring_key]} - -# -------------------------------- Draw Circuit Graphs --------------------------------- - -if False: - for task_key, algo_prune_scores in task_prune_scores.items(): - # if not task_key.startswith("Docstring"): - # continue - task = TASK_DICT[task_key] - if ( - task.key != SPORTS_PLAYERS_TOKEN_CIRCUIT_TASK.key - or task.true_edge_count is None - ): - continue - for algo_key, ps in algo_prune_scores.items(): - algo = PRUNE_ALGO_DICT[algo_key] - # keys = [GROUND_TRUTH_PRUNE_ALGO.key, CIRCUIT_TREE_PROBING_PRUNE_ALGO.key] - keys = [GROUND_TRUTH_PRUNE_ALGO.key] - if algo_key not in keys: - continue - th = prune_scores_threshold(ps, task.true_edge_count) - circ_edges = dict([(d, (m.abs() >= th).float()) for d, m in ps.items()]) - print("circ_edge_count", sum([m.sum() for m in circ_edges.values()])) - circ = dict( - [(d, t.where(m.abs() >= th, m, t.zeros_like(m))) for d, m in ps.items()] - ) - print("task:", task.name, "algo:", algo.name) - draw_seq_graph( - model=task.model, - prune_scores=circ, - seq_labels=task.test_loader.seq_labels, - show_all_edges=False, - ) - -# ------------------------------ Prune Scores Similarity ------------------------------- - -if True: - prune_scores_similartity_fig = prune_score_similarities_plotly( - task_prune_scores, [], ground_truths=True - ) - figs.append(prune_scores_similartity_fig) - -# ------------------------------------ Completeness ------------------------------------ - -compute_task_completeness_scores = False -save_task_completeness_scores = False -load_task_completeness_scores = False -completeness_prune_scores: TaskPruneScores = {} - -faithfulness_target = "kl_div" -if compute_task_completeness_scores: - completeness_prune_scores: TaskPruneScores = train_same_under_knockouts( - task_prune_scores, - algo_keys=["Official Circuit", "Random"], - learning_rate=0.02, - epochs=300, - regularize_lambda=0, - faithfulness_target=faithfulness_target, - ) - -cache_folder_name = ".completeness_scores" -if save_task_completeness_scores: - base_filename = f"{faithfulness_target}-task-completeness-prune-scores" - save_cache(completeness_prune_scores, cache_folder_name, base_filename) - -if load_task_completeness_scores: - # for task-prune-scores-16-02-2024_23-27-49.pkl (IOI Docstring 2000 epochs) - # IOI Docstring 100 epoch KL completeness - filename = "task-completeness-prune-scores-20-02-2024_19-15-29.pkl" - # IOI Docstring 300 epoch KL completeness - filename = "task-completeness-prune-scores-23-02-2024_19-54-47.pkl" - - # for task-prune-scores-16-02-2024_22-22-43.pkl (Sports Players 1000 epochs) - # Sports Players 100 epoch completeness - # filename = "task-completeness-prune-scores-20-02-2024_22-55-47.pkl" - completeness_prune_scores = load_cache(cache_folder_name, filename) - -if completeness_prune_scores: - task_completeness_scores: TaskCompletenessScores = measure_same_under_knockouts( - circuit_ps=task_prune_scores, - knockout_ps=completeness_prune_scores, - ) - completeness_fig = same_under_knockouts_fig(task_completeness_scores) - figs.append(completeness_fig) - -# ---------------------------------------- ROC ----------------------------------------- - -if False: - roc_measurements: TaskMeasurements = measure_roc(task_prune_scores) - roc_fig = roc_plot(roc_measurements) - figs.append(roc_fig) - -# ----------------------------- Prune Metric Measurements ------------------------------ - -compute_prune_metric_measurements = True -save_prune_metric_measurements = False -load_prune_metric_measurements = False - -cache_folder_name = ".measurement_cache" -prune_metric_measurements = None -if compute_prune_metric_measurements: - ABLATION_TYPES: List[AblationType] = [ - AblationType.RESAMPLE, # Docstring - AblationType.TOKENWISE_MEAN_CORRUPT, # IOI and Sports Players - ] - PRUNE_METRICS: List[PruneMetric] = [ - CLEAN_KL_DIV_METRIC, - # CORRUPT_KL_DIV_METRIC, - ANSWER_PROB_METRIC, - # ANSWER_LOGIT_METRIC, - # WRONG_ANSWER_LOGIT_METRIC, - # LOGIT_DIFF_METRIC, - CORRECT_ANSWER_PERCENT_METRIC, # Docstring - LOGIT_DIFF_PERCENT_METRIC, # IOI - CORRECT_ANSWER_GREATER_THAN_INCORRECT_PERCENT_METRIC, # Sports Players - ] - prune_metric_measurements = measure_prune_metrics( - ablation_types=ABLATION_TYPES, - metrics=PRUNE_METRICS, - task_prune_scores=task_prune_scores, - patch_type=PatchType.TREE_PATCH, - reverse_clean_corrupt=False, - ) - if save_prune_metric_measurements: - assert prune_metric_measurements is not None - base_filename = "seq-circuit" - save_cache(prune_metric_measurements, cache_folder_name, base_filename) -if load_prune_metric_measurements: - # 2000 epoch IOI Docstring tensor prune_scores post-kv-cache-fix - filename = "seq-circuit-22-02-2024_20-14-25.pkl" - # filename = "seq-circuit-18-02-2024_17-20-57.pkl" - - # 1000 epoch Sport Players tensor prune_scores post-kv-cache-fix - # batch_size=(10, 20), batch_count=(10, 5), default seed - # filename="seq-circuit-18-02-2024_22-09-26.pkl" - - prune_metric_measurements = load_cache(cache_folder_name, filename) - -if prune_metric_measurements is not None: - figs += list(measurement_figs(prune_metric_measurements)) - -# -------------------------------------- Figures --------------------------------------- - -for i, fig in enumerate(figs): - fig.show() - folder: Path = repo_path_to_abs_path("figures-12") - # Save figure as pdf in figures folder - # fig.write_image(str(folder / f"new {i}.pdf")) - -#%% diff --git a/auto_circuit/language_rotations.py b/auto_circuit/language_rotations.py deleted file mode 100644 index 7c9132d..0000000 --- a/auto_circuit/language_rotations.py +++ /dev/null @@ -1,354 +0,0 @@ -#%% -from collections import defaultdict -from typing import Tuple - -import plotly.express as px -import torch as t -import transformer_lens as tl -from einops import einsum -from word2word import Word2word - -from auto_circuit.utils.custom_tqdm import tqdm -from auto_circuit.utils.misc import ( - get_most_similar_embeddings, - remove_hooks, - repo_path_to_abs_path, -) - -#%% - -model = tl.HookedTransformer.from_pretrained_no_processing("bloom-3b") -device = model.cfg.device -#%% - -en2fr = Word2word("en", "fr") -en2es = Word2word("en", "es") -n_toks = model.cfg.d_vocab_out -print("n_toks:", n_toks) -en_toks, fr_toks, es_toks = [], [], [] -en_strs, fr_strs, es_strs = [], [], [] -for tok in range(n_toks): - en_tok_str = model.to_string([tok]) - assert type(en_tok_str) == str - if len(en_tok_str) < 7: - continue - if en_tok_str[0] != " ": - continue - try: - fr_tok_str = " " + en2fr(en_tok_str[1:], n_best=1)[0] - # es_tok_str = " " + en2es(en_tok_str[1:], n_best=1)[0] - except Exception: - continue - # if en_tok_str.lower() == fr_tok_str.lower() - # or en_tok_str.lower() == es_tok_str.lower(): - if en_tok_str.lower() == fr_tok_str.lower(): - # if en_tok_str.lower() == es_tok_str.lower(): - continue - try: - fr_tok = model.to_single_token(fr_tok_str) - # es_tok = model.to_single_token(es_tok_str) - except Exception: - continue - en_toks.append(tok) - fr_toks.append(fr_tok) - # es_toks.append(es_tok) - en_strs.append(en_tok_str) - fr_strs.append(fr_tok_str) - # es_strs.append(es_tok_str) - -en_toks = t.tensor(en_toks, device=device) -print(en_toks.shape) -fr_toks = t.tensor(fr_toks, device=device) -es_toks = t.tensor(es_toks, device=device) -#%% -d_model = model.cfg.d_model -# en_embeds = t.nn.functional.layer_norm( -# model.embed.W_E[en_toks].detach().clone(), [d_model] -# ) -# fr_embeds = t.nn.functional.layer_norm( -# model.embed.W_E[fr_toks].detach().clone(), [d_model] -# ) -# es_embeds = t.nn.functional.layer_norm( -# model.embed.W_E[es_toks].detach().clone(), [d_model]) -en_embeds = model.embed.W_E[en_toks].detach().clone() -fr_embeds = model.embed.W_E[fr_toks].detach().clone() -# es_embeds = model.embed.W_E[es_toks].detach().clone() - -# dataset = t.utils.data.TensorDataset(en_embeds, fr_embeds, es_embeds) -dataset = t.utils.data.TensorDataset(en_embeds, fr_embeds) -# dataset = t.utils.data.TensorDataset(en_embeds, es_embeds) -train_set, test_set = t.utils.data.random_split(dataset, [0.99, 0.01]) -train_loader = t.utils.data.DataLoader(train_set, batch_size=512, shuffle=True) -test_loader = t.utils.data.DataLoader(test_set, batch_size=512, shuffle=True) - -#%% - -# translate = t.zeros([d_model], device=device, requires_grad=True) -# translate_2 = t.zeros([d_model], device=device, requires_grad=True) -learned_rotation = t.nn.Linear(d_model, d_model, bias=False, device=device) -linear_map = t.nn.utils.parametrizations.orthogonal(learned_rotation, "weight") -# optim = t.optim.Adam(list(learned_rotation.parameters()) + [translate], lr=0.0002) -# optim = t.optim.Adam(list(linear_map.parameters()) + [translate], lr=0.01) -optim = t.optim.Adam(list(learned_rotation.parameters()), lr=0.0002) -# optim = t.optim.Adam([translate], lr=0.0002) - - -def word_pred_from_embeds(embeds: t.Tensor, lerp: float = 1.0) -> t.Tensor: - # return learned_rotation(embeds + translate) - translate - return learned_rotation(embeds) - # return embeds + translate - - -def word_distance_metric(a: t.Tensor, b: t.Tensor) -> t.Tensor: - return -t.nn.functional.cosine_similarity(a, b) - # return (a - b) ** 2 - - -n_epochs = 50 -loss_history = [] -for epoch in (epoch_pbar := tqdm(range(n_epochs))): - for batch_idx, (en_embed, fr_embed) in enumerate(train_loader): - en_embed.to(device) - fr_embed.to(device) - - optim.zero_grad() - pred = word_pred_from_embeds(en_embed) - loss = word_distance_metric(pred, fr_embed).mean() - loss_history.append(loss.item()) - loss.backward() - optim.step() - epoch_pbar.set_description(f"Loss: {loss.item():.3f}") - -px.line(y=loss_history, title="Loss History").show() -# %% -cosine_sims = [] -for batch_idx, (en_embed, fr_embed) in enumerate(test_loader): - en_embed.to(device) - fr_embed.to(device) - pred = word_pred_from_embeds(en_embed) - cosine_sim = word_distance_metric(pred, fr_embed) - cosine_sims.append(cosine_sim) - -print("Test Accuracy:", t.cat(cosine_sims).mean().item()) - -correct_count = 0 -for batch_idx, (en_embed, fr_embed) in enumerate(test_loader): - en_embed.to(device) - fr_embed.to(device) - pred = word_pred_from_embeds(en_embed) - for i in range(30): - print() - print() - logits = einsum(en_embed[i], model.embed.W_E, "d_model, vocab d_model -> vocab") - en_str = model.to_single_str_token(logits.argmax().item()) # type: ignore - logits = einsum(fr_embed[i], model.embed.W_E, "d_model, vocab d_model -> vocab") - fr_str = model.to_single_str_token(logits.argmax().item()) # type: ignore - logits = einsum(pred[i], model.embed.W_E, "d_model, vocab d_model -> vocab") - pred_str = model.to_single_str_token(logits.argmax().item()) # type: ignore - if correct := (fr_str == pred_str): - correct_count += 1 - print("English:", en_str, "French:", fr_str) - print("English to French rotation", "✅" if correct else "❌") - get_most_similar_embeddings( - model, - pred[i], - top_k=4, - apply_embed=True, - ) -print() -print("Correct percentage:", correct_count / len(test_loader.dataset) * 100) -# %% -# -------------- GATHER FR EN EMBED DATA ---------------- -en_file = "/home/dev/europarl/europarl-v7.fr-en.en" -fr_file = "/home/dev/europarl/europarl-v7.fr-en.fr" -batch_size = 2 - -en_strs = [] -fr_strs = [] -# Read the first 5000 lines of the files (excluding the first line) -with open(en_file, "r") as f: - en_strs = [f.readline()[:-1] + " " + f.readline()[:-1] for _ in range(5001)][1:] -with open(fr_file, "r") as f: - fr_strs = [f.readline()[:-1] + " " + f.readline()[:-1] for _ in range(5001)][1:] - -model.tokenizer.padding_side = "right" # type: ignore -en_tknzd = model.tokenizer(en_strs, padding=True, return_tensors="pt") # type: ignore -fr_tknzd = model.tokenizer(fr_strs, padding=True, return_tensors="pt") # type: ignore -en_toks, en_attn_mask = en_tknzd["input_ids"], en_tknzd["attention_mask"] -fr_toks, fr_attn_mask = fr_tknzd["input_ids"], fr_tknzd["attention_mask"] - -dataset = t.utils.data.TensorDataset(en_toks, en_attn_mask, fr_toks, fr_attn_mask) -loader = t.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) - -#%% -en_embeds, fr_embeds = defaultdict(list), defaultdict(list) -lyrs = [20, 25, 27, 29] -for en_batch, en_attn_mask, fr_batch, fr_attn_mask in tqdm(loader): - with t.inference_mode(): - _, en_cache = model.run_with_cache(en_batch, prepend_bos=True) - for lyr in lyrs: - resids = en_cache[f"blocks.{lyr}.hook_resid_pre"] - resids_flat = resids.flatten(start_dim=0, end_dim=1) - mask_flat = en_attn_mask.flatten(start_dim=0, end_dim=1) - en_embeds[lyr].append(resids_flat[mask_flat == 1].detach().clone().cpu()) - del en_cache - _, fr_cache = model.run_with_cache(fr_batch, prepend_bos=True) - for lyr in lyrs: - resids = fr_cache[f"blocks.{lyr}.hook_resid_pre"] - resids_flat = resids.flatten(start_dim=0, end_dim=1) - mask_flat = fr_attn_mask.flatten(start_dim=0, end_dim=1) - fr_embeds[lyr].append(resids_flat[mask_flat == 1].detach().clone().cpu()) - del fr_cache -# %% -en_resids = {lyr: t.cat(en_embeds[lyr]) for lyr in lyrs} -fr_resids = {lyr: t.cat(fr_embeds[lyr]) for lyr in lyrs} -# %% -cache_folder = repo_path_to_abs_path(".activation_cache") -filename_root = ( - f"europarl_v7_fr_en_double_prompt_all_toks-{model.cfg.model_name}-lyrs_{lyrs}" -) -# Save en_resids and fr_resids to cache with torch.save -t.save(en_resids, cache_folder / f"{filename_root}-en.pt") -t.save(fr_resids, cache_folder / f"{filename_root}-fr.pt") -# %% -# -------------- TRAIN FR EN EMBED ROTATION ---------------- -# train_en_resids = t.load("/home/dev/auto-circuit/.activation_cache/ -# europarl_v7_fr_en_double_prompt_final_tok-bloom-3b-lyrs_range(0, 30, 5)-train-en.pt") -# train_fr_resids = t.load("/home/dev/auto-circuit/.activation_cache/ -# europarl_v7_fr_en_double_prompt_final_tok-bloom-3b-lyrs_range(0, 30, 5)-train-fr.pt") -# test_en_resids = t.load("/home/dev/auto-circuit/.activation_cache/ -# europarl_v7_fr_en_double_prompt_final_tok-bloom-3b-lyrs_range(0, 30, 5)-test-en.pt") -# test_fr_resids = t.load("/home/dev/auto-circuit/.activation_cache/ -# europarl_v7_fr_en_double_prompt_final_tok-bloom-3b-lyrs_range(0, 30, 5)-test-fr.pt") -train_en_resids = t.load( - "/home/dev/auto-circuit/.activation_cache/europarl_v7_fr_en_double_prompt_all_toks" - + "-bloom-3b-lyrs_[20, 25, 27, 29]-en.pt" -) -train_fr_resids = t.load( - "/home/dev/auto-circuit/.activation_cache/europarl_v7_fr_en_double_prompt_all_toks" - + "-bloom-3b-lyrs_[20, 25, 27, 29]-fr.pt" -) -layer_idx = 29 -d_model = model.cfg.d_model -device = model.cfg.device -min_len = min(train_en_resids[layer_idx].shape[0], train_fr_resids[layer_idx].shape[0]) - -train_dataset = t.utils.data.TensorDataset( - # layer_norm(train_en_resids[layer_idx][:min_len], [d_model]), - # layer_norm(train_fr_resids[layer_idx][:min_len], [d_model]), - train_en_resids[layer_idx][:min_len], - train_fr_resids[layer_idx][:min_len], -) -train_loader = t.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True) -fr_to_en_mean_vec = ( - train_en_resids[layer_idx].mean(dim=0) - train_fr_resids[layer_idx].mean(dim=0) -).to(device) -del train_en_resids -del train_fr_resids -#%% - -# translate = t.zeros([d_model], device=device, requires_grad=True) -learned_rotation = t.nn.Linear(d_model, d_model, bias=False, device=device) -linear_map = t.nn.utils.parametrizations.orthogonal(learned_rotation, "weight") -# optim = t.optim.Adam(list(linear_map.parameters()) + [translate], lr=0.0002) -optim = t.optim.Adam(list(linear_map.parameters()), lr=0.0002) -# optim = t.optim.Adam(list(learned_rotation.parameters()) + [translate], lr=0.0002) -# optim = t.optim.Adam(list(learned_rotation.parameters()), lr=0.01) -# optim = t.optim.Adam([translate], lr=0.0002) - - -def pred_from_embeds(embeds: t.Tensor, lerp: float = 1.0) -> t.Tensor: - # return learned_rotation(embeds + translate) - translate - return learned_rotation(embeds) - # return embeds + translate - - -def distance_metric(a: t.Tensor, b: t.Tensor) -> t.Tensor: - return -t.nn.functional.cosine_similarity(a, b) - # return (a - b) ** 2 - - -n_epochs = 1 -loss_history = [] -for epoch in (epoch_pbar := tqdm(range(n_epochs))): - for batch_idx, (en_embed, fr_embed) in tqdm(enumerate(train_loader)): - en_embed = en_embed.to(device) - fr_embed = fr_embed.to(device) - - optim.zero_grad() - pred = pred_from_embeds(fr_embed) - loss = distance_metric(pred, en_embed).mean() - loss_history.append(loss.item()) - loss.backward() - optim.step() - epoch_pbar.set_description(f"Loss: {loss.item():.3f}") -px.line(y=loss_history, title="Loss History").show() -# %% -en_file = "/home/dev/europarl/europarl-v7.fr-en.en" -fr_file = "/home/dev/europarl/europarl-v7.fr-en.fr" - - -# define a pytorch forward hook function -def steering_hook( - module: t.nn.Module, input: Tuple[t.Tensor], output: t.Tensor -) -> t.Tensor: - prefix_toks, final_tok = input[0][:, :-1], input[0][:, -1] - # layernormed_final_tok = layer_norm(final_tok, [d_model]) - # rotated_final_tok = pred_from_embeds(layernormed_final_tok) - rotated_final_tok = pred_from_embeds(final_tok) - # rotated_final_tok = fr_to_en_mean_vec + layernormed_final_tok - # rotated_final_tok = fr_to_en_mean_vec + final_tok - # rotated_final_tok = t.zeros_like(rotated_final_tok) - out = t.cat([prefix_toks, rotated_final_tok.unsqueeze(1)], dim=1) - return out - - -test_en_strs = [] -test_fr_strs = [] -# Read the first 10000 lines of the files -with open(en_file, "r") as f: - for i in range(11001): - test_str = f.readline()[:-1] + " " + f.readline()[:-1] - if i > 10000: - test_en_strs.append(test_str) -with open(fr_file, "r") as f: - for i in range(11001): - test_str = f.readline()[:-1] + " " + f.readline()[:-1] - if i > 10000: - test_fr_strs.append(test_str) - -gen_length = 20 -for idx, (test_en_str, test_fr_str) in enumerate(zip(test_en_strs, test_fr_strs)): - print() - print("----------------------------------------------") - print("test_en_str:", test_en_str) - en_str_init_len = len(test_en_str) - logits = model(test_en_str, prepend_bos=True) - get_most_similar_embeddings(model, logits[0, -1], top_k=5) - for i in range(gen_length): - top_tok = model(test_en_str, prepend_bos=True)[:, -1].argmax(dim=-1) - top_tok_str = model.to_string(top_tok) - test_en_str += top_tok_str - print("result en str:", test_en_str[en_str_init_len:]) - print() - print("test_fr_str:", test_fr_str) - fr_str_init_len = len(test_fr_str) - with remove_hooks() as handles, t.inference_mode(): - handle = model.blocks[layer_idx].hook_resid_pre.register_forward_hook( - steering_hook - ) - handles.add(handle) - logits = model(test_fr_str, prepend_bos=True) - get_most_similar_embeddings(model, logits[0, -1], top_k=5) - for i in range(gen_length): - top_tok = model(test_fr_str, prepend_bos=True)[:, -1].argmax(dim=-1) - top_tok_str = model.to_string(top_tok) - test_fr_str += top_tok_str - print("result fr str:", test_fr_str[fr_str_init_len:]) - if idx > 5: - break - - -# %% -# FINDINGS -# See https://docs.google.com/document/d/1P_GDQb8L2rJBMtPJrm3gCmaOOvO2EtHtaIvP-HXMvWA diff --git a/auto_circuit/playground.py b/auto_circuit/playground.py deleted file mode 100644 index cebba10..0000000 --- a/auto_circuit/playground.py +++ /dev/null @@ -1,468 +0,0 @@ -#%% -import json -import os -from typing import Dict, Tuple - -import blobfile as bf -import plotly.express as px -import plotly.figure_factory as ff -import plotly.graph_objects as go -import torch as t -import torch.backends.mps -import transformer_lens as tl -from torch.nn.utils import parametrizations - -from auto_circuit.utils.misc import get_most_similar_embeddings - -# from auto_circuit.prune_functions.parameter_integrated_gradients import ( -# BaselineWeights, -# ) - - -def rotation_matrix( - x: t.Tensor, y: t.Tensor, lerp: float = 1.0 -) -> Tuple[t.Tensor, t.Tensor]: - # Based on: https://math.stackexchange.com/questions/598750/finding-the-rotation-matrix-in-n-dimensions - # Check that neither x or y is a zero vector - assert t.all(x != 0), "x is a zero vector" - assert t.all(y != 0), "y is a zero vector" - - assert x.device == y.device - device = x.device - - # Normalize x to get u - u = x / t.norm(x) - - # Project y onto the orthogonal complement of u and normalize to get v - v = y - t.dot(u, y) * u - v = v / t.norm(v) - - # Calculate cos(theta) and sin(theta) - cos_theta = t.dot(x, y) / (t.norm(x) * t.norm(y)) - sin_theta = t.sqrt(1 - cos_theta**2) - - # Interpolate between the identity matrix and the rotation matrix - if lerp != 1.0: - theta = t.atan2(sin_theta, cos_theta) - cos_theta = t.cos(theta * lerp) - sin_theta = t.sin(theta * lerp) - - # Rotation matrix in the plane spanned by u and v - R_theta = t.tensor([[cos_theta, -sin_theta], [sin_theta, cos_theta]], device=device) - - # Construct the full rotation matrix - uv = t.stack([u, v], dim=1) - identity = t.eye(len(x), device=device) - R = identity - t.outer(u, u) - t.outer(v, v) + uv @ R_theta @ uv.T - - return R, uv - - -# np.random.seed(0) -# torch.manual_seed(0) -# random.seed(0) -os.environ["TOKENIZERS_PARALLELISM"] = "False" - -device = t.device("cuda") if t.cuda.is_available() else t.device("cpu") -print("device", device) -model = tl.HookedTransformer.from_pretrained( - # "pythia-410m-deduped", - "gpt2", - fold_ln=True, - center_writing_weights=True, - center_unembed=False, - # "tiny-stories-2L-33M", - device=device - # "tiny-stories-33M", device=device -) -model.eval() -#%% -modelb16 = tl.HookedTransformer.from_pretrained( - # "pythia-410m-deduped", - "gpt2", - fold_ln=True, - center_writing_weights=True, - center_unembed=False, - # "tiny-stories-2L-33M", - device=device, - dtype="bfloat16" - # "tiny-stories-33M", device=device -) -model16 = tl.HookedTransformer.from_pretrained( - # "pythia-410m-deduped", - "gpt2", - fold_ln=True, - center_writing_weights=True, - center_unembed=False, - # "tiny-stories-2L-33M", - device=device, - dtype="float16" - # "tiny-stories-33M", device=device -) -model16.eval() -modelb16.eval() -#%% -test_prompt = "The sun rises in the" -tl.utils.test_prompt(test_prompt, "stone", model, top_k=5) -tl.utils.test_prompt(test_prompt, "stone", model16, top_k=5) -tl.utils.test_prompt(test_prompt, "stone", modelb16, top_k=5) -#%% - - -country_to_captial: Dict[str, str] = { - "country": "capital", - "France": "Paris", - "Hungary": "Budapest", - "China": "Beijing", - "Germany": "Berlin", - # "Italy": "Rome", - "Japan": "Tokyo", - # "Russia": "Moscow", - # 'Canada': 'Ottawa', - # 'Australia': 'Canberra', - # "Egypt": "Cairo", - # 'Turkey': 'Ankara', - # "Spain": "Madrid", - "Sweden": "Stockholm", - "Norway": "Oslo", - "Denmark": "Copenhagen", - "Finland": "Helsinki", - "Poland": "Warsaw", - "Indonesia": "Jakarta", - # "Thailand": "Bangkok", - "Cuba": "Havana", - "Chile": "Santiago", - "Greece": "Athens", - "Portugal": "Lisbon", - # "Austria": "Vienna", - # "Belgium": "Brussels", - "Philippines": "Manila", - "Peru": "Lima", - "Ireland": "Dublin", - "Israel": "Jerusalem", - # 'Switzerland': 'Bern', - "Netherlands": "Amsterdam", - "Singapore": "Singapore", # Interesting case - # "Pakistan": "Islamabad", - "Lebanon": "Beirut", -} -present_to_past: Dict[str, str] = { - "present": "past", - "is": "was", - "run": "ran", - "eat": "ate", - "drink": "drank", - "go": "went", - "see": "saw", - "hear": "heard", - "speak": "spoke", - "write": "wrote", - # "read": "read", - "do": "did", - "have": "had", - "give": "gave", - "take": "took", - "make": "made", - "know": "knew", - "think": "thought", - "find": "found", - "tell": "told", - "become": "became", - "leave": "left", - "feel": "felt", - # "put": "put", - "bring": "brought", - "begin": "began", - "keep": "kept", - "hold": "held", - "stand": "stood", - "play": "played", - "light": "lit", -} -male_to_female: Dict[str, str] = { - "male": "female", - "king": "queen", - "actor": "actress", - "brother": "sister", - "father": "mother", - "son": "daughter", - "nephew": "niece", - "uncle": "aunt", - "wizard": "witch", - "prince": "princess", - "husband": "wife", - "boy": "girl", - "man": "woman", - "hero": "heroine", - "lord": "lady", - "monk": "nun", - "groom": "bride", - "bull": "cow", - "god": "goddess", -} - -for k, v in country_to_captial.items(): - tl.utils.test_prompt(f"The capital of {k} is the city of", v, model, top_k=5) -#%% - -word_mapping = present_to_past -# for i, (k, v) in enumerate(word_mapping.items()): -# print("key:", model.to_str_tokens(" " + k, prepend_bos=False)) -# print("value:", model.to_str_tokens(" " + v, prepend_bos=False)) -key_toks = model.to_tokens( - [" " + s for s in word_mapping.keys()], prepend_bos=False -).squeeze() -val_toks = model.to_tokens( - [" " + s for s in word_mapping.values()], prepend_bos=False -).squeeze() - -# print("EMBEDDINGS") -# key_embeds = model.embed(key_toks).detach().clone() # [n_toks, embed_dim] -# val_embeds = model.embed(val_toks).detach().clone() # [n_toks, embed_dim] -# print("average key embedding norm:", key_embeds.norm(dim=-1).mean().item()) -# print("average val embedding norm:", val_embeds.norm(dim=-1).mean().item()) -# key_embeds = model.ln_final(model.embed(key_toks)).detach().clone() -# val_embeds = model.ln_final(model.embed(val_toks)).detach().clone() -# key_embeds = model.blocks[3].ln2(model.embed(key_toks)).detach().clone() -# val_embeds = model.blocks[3].ln2(model.embed(val_toks)).detach().clone() -# print("LAYERNORM NO BIAS NO SCALE") -key_embeds = ( - t.nn.functional.layer_norm(model.embed(key_toks), [model.cfg.d_model]) - .detach() - .clone() -) # [n_toks, embed_dim] -val_embeds = ( - t.nn.functional.layer_norm(model.embed(val_toks), [model.cfg.d_model]) - .detach() - .clone() -) # [n_toks, embed_dim] -# val_embeds = model.ln_final(model.embed(val_toks)).detach().clone() -print("average key embedding norm:", key_embeds.norm(dim=-1).mean().item()) -print("average val embedding norm:", val_embeds.norm(dim=-1).mean().item()) - -idxs = torch.randperm(len(word_mapping)) -train_len = key_embeds.shape[0] - 6 -concept_key_embed, train_key_embeds, test_key_embeds = ( - key_embeds[0], - key_embeds[idxs[1 : 1 + train_len]], - key_embeds[idxs[1 + train_len :]], -) -concept_val_embed, train_val_embeds, test_val_embeds = ( - val_embeds[0], - val_embeds[idxs[1 : 1 + train_len]], - val_embeds[idxs[1 + train_len :]], -) -# ln_final_bias = model.ln_final.b.detach().clone() - -# linear_map = t.zeros( -# [val_embeds.shape[1], key_embeds.shape[1]], device=device, requires_grad=True -# ) -translate = t.zeros([key_embeds.shape[1]], device=device, requires_grad=True) -scale = t.ones([key_embeds.shape[1]], device=device, requires_grad=True) -translate_2 = t.zeros([key_embeds.shape[1]], device=device, requires_grad=True) -rotation_vectors = t.rand([2, key_embeds.shape[1]], device=device, requires_grad=True) - -learned_rotation = t.nn.Linear( - key_embeds.shape[1], key_embeds.shape[1], bias=False, device=device -) - -linear_map = parametrizations.orthogonal(learned_rotation, "weight") - -# optim = t.optim.Adam(linear_map.parameters(), lr=0.01) -optim = t.optim.Adam(list(linear_map.parameters()) + [translate], lr=0.01) -# optim = t.optim.Adam([translate], lr=0.01) -# optim = t.optim.Adam([linear_map, translate], lr=0.01) -# optim = t.optim.Adam([rotation_vectors], lr=0.01) -# optim = t.optim.Adam([rotation_vectors, translate], lr=0.01) -# optim = t.optim.Adam([rotation_vectors, translate, translate_2], lr=0.01) -# optim = t.optim.Adam([linear_map, translate, translate_2], lr=0.01) -# optim = t.optim.Adam([rotation_vectors, translate, scale], lr=0.01) -# optim = t.optim.Adam([translate, scale], lr=0.01) -# optim = t.optim.Adam([scale], lr=0.01) - - -def pred_from_embeds(embeds: t.Tensor, lerp: float = 1.0) -> t.Tensor: - # linear_map, proj = rotation_matrix( - # rotation_vectors[0], rotation_vectors[1], lerp=lerp - # ) - # pred = learned_rotation(embeds) - pred = learned_rotation(embeds + translate) - translate - # pred = embeds @ linear_map - # pred = embeds + (translate * lerp) - # pred = embeds * scale - # pred = (embeds @ linear_map) + translate - # pred = ((embeds + translate_2) @ linear_map) + translate - # pred = ((embeds + translate) @ linear_map) - translate - # pred = ((embeds - ln_final_bias) @ linear_map) + ln_final_bias - # pred = ((embeds * scale) + translate) @ linear_map - # pred = ((embeds @ linear_map)* scale) + translate - # pred = (embeds * scale) + translate - return pred - - -def loss_fn(pred: t.Tensor, target: t.Tensor) -> t.Tensor: - # loss = (target - pred).pow(2).mean() - # loss = 1 - t.nn.functional.cosine_similarity(pred, target).mean() - loss = 1 - t.nn.functional.cosine_similarity(pred, target).mean() - return loss - - -losses = [] -for epoch in range(1000): - optim.zero_grad() - pred = pred_from_embeds(train_key_embeds) - loss = loss_fn(pred, train_val_embeds) - loss.backward() - optim.step() - losses.append(loss.item()) if epoch % 10 == 0 else None - -px.line(y=losses).show() - -# linear_map = rotation_matrix(rotation_vectors[0], rotation_vectors[1]) # type: ignore -test_pred = pred_from_embeds(test_key_embeds) -print("Test loss:", loss_fn(test_pred, test_val_embeds).item()) - -print("Train data example") -get_most_similar_embeddings( - model, train_key_embeds[0], top_k=5, apply_ln_final=False, apply_unembed=True -) -print() -train_pred_0 = pred_from_embeds(train_key_embeds[0]) -get_most_similar_embeddings( - model, train_pred_0, top_k=5, apply_ln_final=False, apply_unembed=True -) - -for i in range(5): - print("Test data example") - get_most_similar_embeddings( - model, test_key_embeds[i], top_k=5, apply_ln_final=False, apply_unembed=True - ) - print() - test_pred_i = pred_from_embeds(test_key_embeds[i]) - get_most_similar_embeddings( - model, test_pred_i, top_k=5, apply_ln_final=False, apply_unembed=True - ) - -#%% -get_most_similar_embeddings(model, train_key_embeds[0], apply_embed=True) -#%% -linear_map, proj = rotation_matrix(rotation_vectors[0], rotation_vectors[1], lerp=1.0) -projected_keys = ((key_embeds[1:]) @ proj).detach().clone().cpu() -projected_vals = ((val_embeds[1:]) @ proj).detach().clone().cpu() -projected_translate = (translate @ proj).detach().clone().cpu() - -projected_rotated_keys = ( - ((((translate + key_embeds[1:]) @ linear_map) - translate) @ proj) - .detach() - .clone() - .cpu() -) - -keys_x, keys_y = projected_keys[:, 0], projected_keys[:, 1] -vals_x, vals_y = projected_vals[:, 0], projected_vals[:, 1] -preds_x, preds_y = projected_rotated_keys[:, 0], projected_rotated_keys[:, 1] - -# Create a scatter plot -fig = go.Figure() - -# Adding scatter plot for keys -fig.add_trace( - go.Scatter( - x=keys_x, - y=keys_y, - mode="markers+text", - name="Keys", - text=list(word_mapping.keys())[1:], - marker=dict(size=10, color="blue"), - ) -) - -# Adding scatter plot for values -fig.add_trace( - go.Scatter( - x=vals_x, - y=vals_y, - mode="markers+text", - name="Vals", - text=list(word_mapping.values())[1:], - marker=dict(size=10, color="red"), - ) -) - -# Adding scatter plot for predictions -# fig.add_trace(go.Scatter(x=preds_x, y=preds_y, -# mode='markers+text', -# name='Predictions', -# text=list(word_mapping.values())[1:], -# marker=dict(size=10, color='orange'))) - - -# Adding lines connecting keys and values -for i in range(keys_x.shape[0]): - fig.add_trace( - go.Scatter( - x=[keys_x[i], vals_x[i]], - y=[keys_y[i], vals_y[i]], - mode="lines", - line=dict(color="grey", width=1), - showlegend=False, - ) - ) - -# # Adding lines connecting keys and predictions -# for i in range(keys_x.shape[0]): -# fig.add_trace(go.Scatter(x=[keys_x[i], preds_x[i]], -# y=[keys_y[i], preds_y[i]], -# mode='lines', -# line=dict(color='green', width=1), -# showlegend=False)) - -# fig.add_trace(go.Scatter(x=[projected_translate[0]], -# y=[projected_translate[1]], -# mode='markers+text', -# name='Translate', -# text=["Translate"], -# marker=dict(size=10, color='green'))) - -# Update layout for a better look -fig.update_layout( - title="2D Scatter plot of Keys and Vals with Connections", - xaxis_title="Dimension 1", - yaxis_title="Dimension 2", - legend_title="Legend", -) - -# Show plot -fig.show() - -# print("Interpolated test data example") -# get_most_similar_embeddings(model, test_key_embeds[0], top_k=3) -# for lerp in t.linspace(0, 1, 10): -# print(f"lerp: {lerp.item():.2f}") -# test_pred_i = pred_from_embeds(test_key_embeds[0], lerp=lerp.item()) -# get_most_similar_embeddings(model, test_pred_i, top_k=10) - -#%% -# Resid_delta_mlp, Layer 3 are the most interpretable -layer_index = 6 # in range(12) -autoencoder_input = ["mlp_post_act", "resid_delta_mlp"][0] -feature_idx = 25890 -base_url = "az://openaipublic/sparse-autoencoder/gpt2-small/" -weight_file = f"{autoencoder_input}/autoencoders/{layer_index}.pt" -feat_file = f"{autoencoder_input}/collated_activations/{layer_index}/{feature_idx}.json" -url_to_get = base_url + feat_file - -bf.stat(url_to_get) - -with bf.BlobFile(url_to_get, mode="rb") as f: - data = json.load(f) - -examples = data["most_positive_activation_records"] -# examples = data["random_sample"] -cell_vals = [d["activations"] for d in examples] -text = [d["tokens"] for d in examples] -fig = ff.create_annotated_heatmap(cell_vals, annotation_text=text, colorscale="Viridis") -prompt_len = min([len(acts) for acts in cell_vals]) -fig.layout.width = 75 * prompt_len # type: ignore -fig.layout.height = 32 * len(cell_vals) # type: ignore -fig.show() diff --git a/auto_circuit/pythia-2_8b-sports.py b/auto_circuit/pythia-2_8b-sports.py deleted file mode 100644 index 287d429..0000000 --- a/auto_circuit/pythia-2_8b-sports.py +++ /dev/null @@ -1,192 +0,0 @@ -#%% -from collections import defaultdict - -import plotly.express as px -import plotly.graph_objects as go -import torch as t -import transformer_lens as tl - -from auto_circuit.utils.custom_tqdm import tqdm -from auto_circuit.utils.misc import repo_path_to_abs_path - -MODEL_NAME = "pythia-2.8b-deduped" -device = "cuda" if t.cuda.is_available() else "cpu" -model = tl.HookedTransformer.from_pretrained(MODEL_NAME, device=device) - - -def read_players(filename: str): - filepath = repo_path_to_abs_path(f"datasets/sports-players/{filename}") - with open(filepath, "r") as file: - players = file.readlines() - return [player.strip() for player in players] - - -football_players = read_players("american-football-players.txt") -basketball_players = read_players("basketball-players.txt") -baseball_players = read_players("baseball-players.txt") - -bos = model.tokenizer.bos_token # type: ignore -template = ( - bos + "Fact: Tiger Woods plays the sport of golf\nFact: {} plays the sport of" -) - -football_prompts = [template.format(player) for player in football_players] -basketball_prompts = [template.format(player) for player in basketball_players] -baseball_prompts = [template.format(player) for player in baseball_players] - -football_valid_players, basketball_valid_players, baseball_valid_players = [], [], [] -ans_toks = [] - -for prompts, players, answer, valid_players in [ - (football_prompts, football_players, " football", football_valid_players), - (basketball_prompts, baseball_players, " basketball", basketball_valid_players), - (baseball_prompts, baseball_players, " baseball", baseball_valid_players), -]: - ans_tok = model.to_tokens(answer, padding_side="left", prepend_bos=False)[0][0] - ans_toks.append(ans_tok) - model.tokenizer.padding_side = "left" # type: ignore - out = model.tokenizer(prompts, padding=True, return_tensors="pt") # type: ignore - prompt_tokens = out.input_ids.to(device) - attn_mask = out.attention_mask.to(device) - print("prompt_tokens.shape", prompt_tokens.shape) - - with t.inference_mode(): - logits = model(prompt_tokens, attention_mask=attn_mask)[:, -1] - probs = t.softmax(logits, dim=-1) - correct_answer_idxs = t.where(probs[:, ans_tok] > 0.5)[0] - - correct_answer_names = [players[i.item()] for i in correct_answer_idxs] - print("correct_answer_names", correct_answer_names) - valid_players.extend(correct_answer_names) - -#%% - -min_name_count = min( - len(football_valid_players), - len(basketball_valid_players), - len(baseball_valid_players), -) -resids = defaultdict(list) -mean_acts = [] -for layer in range(19): - mean = [] - for sport, correct_players in [ - ("Football", football_valid_players), - ("Basketball", basketball_valid_players), - ("Baseball", baseball_valid_players), - ]: - player_prompts = [bos + name for name in correct_players] - # template = bos + "Fact: Tiger Woods plays the sport of golf\nFact: {}" - player_prompts = [template.format(name) for name in correct_players] - model.tokenizer.padding_side = "left" # type: ignore - out = model.tokenizer( - player_prompts, padding=True, return_tensors="pt" - ) # type: ignore - prompt_tokens = out.input_ids.to(device) - attn_mask = out.attention_mask.to(device) - - with t.inference_mode(): - logits = model( - prompt_tokens, - attention_mask=attn_mask, - stop_at_layer=layer + 1, # stop_at_layer is exclusive - ) - mean.append(logits[:min_name_count, -1].mean(dim=0)) - resids[sport].append(logits[:, -1]) - mean_acts.append(t.stack(mean).mean(dim=0)) - - -ans_toks = t.stack(ans_toks) if isinstance(ans_toks, list) else ans_toks -#%% - -ans_embeds = model.unembed.W_U[:, ans_toks] -probe = model.blocks[16].attn.W_V[20] @ model.blocks[16].attn.W_O[20] @ ans_embeds - -for ans_idx, (sport, sport_resids) in enumerate(resids.items()): - layers = list(range(2, 19)) - correct = [] - football_avgs, basketball_avgs, baseball_avgs = [], [], [] - for layer in layers: - embed = sport_resids[layer].detach().clone() - mean_act = mean_acts[layer].detach().clone() - with t.inference_mode(): - # x_normed = model.blocks[layer].ln2(x).detach().clone() - # x = x + model.blocks[layer].mlp(x_normed.unsqueeze(1)).squeeze(1) - probe_out = (embed - mean_act) @ probe - correct.append((probe_out.argmax(dim=-1) == ans_idx).float().mean().item()) - probs = t.nn.functional.softmax(probe_out, dim=-1) - mean_probs = probs.mean(dim=0) - football_avgs.append(mean_probs[0].item()) - basketball_avgs.append(mean_probs[1].item()) - baseball_avgs.append(mean_probs[2].item()) - - # Plot layers vs. average probability of each sport - fig = go.Figure() - fig.add_trace(go.Scatter(x=layers, y=football_avgs, name="Football")) - fig.add_trace(go.Scatter(x=layers, y=basketball_avgs, name="Basketball")) - fig.add_trace(go.Scatter(x=layers, y=baseball_avgs, name="Baseball")) - fig.add_trace(go.Scatter(x=layers, y=correct, name="Correct")) - fig.update_layout( - title="Average Probability of Each Sport", - xaxis_title="Layer", - yaxis_title="Probability", - ) - fig.show() - -#%% - -layer_2_resids = t.cat([r[2].detach().clone() for r in resids.values()]) -correct_dirs = t.cat( - [ - probe[:, ans_idx].detach().clone().repeat(len(r[2]), 1) - for ans_idx, r in enumerate(resids.values()) - ] -) -answer_idxs = t.cat( - [t.tensor([ans_idx] * len(r[2])) for ans_idx, r in enumerate(resids.values())] -) -d_model = model.cfg.d_model - -dataset = t.utils.data.TensorDataset(layer_2_resids, correct_dirs, answer_idxs) -train_set, test_set = t.utils.data.random_split(dataset, [0.9, 0.1]) -train_loader = t.utils.data.DataLoader(train_set, batch_size=32, shuffle=True) -test_loader = t.utils.data.DataLoader(test_set, batch_size=32, shuffle=True) - -translate = t.zeros([d_model], device=device, requires_grad=True) -translate_2 = t.zeros([d_model], device=device, requires_grad=True) -learned_rotation = t.nn.Linear(d_model, d_model, bias=False, device=device) -linear_map = t.nn.utils.parametrizations.orthogonal(learned_rotation, "weight") -optim = t.optim.Adam(list(linear_map.parameters()) + [translate, translate_2], lr=0.01) - - -def pred_from_embeds(embeds: t.Tensor, lerp: float = 1.0) -> t.Tensor: - return learned_rotation(embeds + translate) + translate_2 - - -n_epochs = 2000 -loss_history = [] -for epoch in (epoch_pbar := tqdm(range(n_epochs))): - for batch_idx, (resid, correct_dir, _) in enumerate(train_loader): - resid = resid.to(device) - correct_dir = correct_dir.to(device) - - optim.zero_grad() - pred = pred_from_embeds(resid) - loss = -t.nn.functional.cosine_similarity(pred, correct_dir).mean() - loss_history.append(loss.item()) - loss.backward() - optim.step() - epoch_pbar.set_description(f"Loss: {loss.item():.3f}") - -px.line(y=loss_history, title="Loss History").show() -#%% -test_corrects = [] -for batch_idx, (resid, _, answer_idx) in enumerate(test_loader): - resid = resid.to(device) - with t.inference_mode(): - # pred = pred_from_embeds(resid) - probe_out = resid @ probe - pred_ans = probe_out.argmax(dim=-1) - test_corrects.append((pred_ans == answer_idx.to(device)).float()) - -print("Test Accuracy:", t.cat(test_corrects).mean().item()) diff --git a/auto_circuit/tracr_experiments.py b/auto_circuit/tracr_experiments.py deleted file mode 100644 index c02a7bf..0000000 --- a/auto_circuit/tracr_experiments.py +++ /dev/null @@ -1,101 +0,0 @@ -#%% -from collections import defaultdict -from pathlib import Path -from typing import List - -from auto_circuit.metrics.official_circuits.measure_roc import measure_roc -from auto_circuit.metrics.official_circuits.roc_plot import task_roc_plot -from auto_circuit.prune_algos.prune_algos import ( - ACDC_PRUNE_ALGO, - GROUND_TRUTH_PRUNE_ALGO, - LOGIT_DIFF_GRAD_PRUNE_ALGO, - LOGIT_MSE_GRAD_PRUNE_ALGO, - MSE_ACDC_PRUNE_ALGO, - MSE_CIRCUIT_TREE_PROBING_PRUNE_ALGO, - MSE_SUBNETWORK_TREE_PROBING_PRUNE_ALGO, - PRUNE_ALGO_DICT, - RANDOM_PRUNE_ALGO, - SUBNETWORK_TREE_PROBING_PRUNE_ALGO, - PruneAlgo, - run_prune_algos, -) -from auto_circuit.tasks import ( - TASK_DICT, - TRACR_REVERSE_TOKEN_CIRCUIT_TASK, - TRACR_XPROPORTION_TOKEN_CIRCUIT_TASK, -) -from auto_circuit.types import ( - TaskMeasurements, - TaskPruneScores, -) -from auto_circuit.utils.misc import load_cache, repo_path_to_abs_path, save_cache -from auto_circuit.visualize import draw_seq_graph - -# ------------------------------------ Prune Scores ------------------------------------ -compute_prune_scores = True -save_prune_scores = False -load_prune_scores = False - -task_prune_scores: TaskPruneScores = defaultdict(dict) -cache_folder_name = ".prune_scores_cache" -if compute_prune_scores: - REVERSE_ALGOS: List[PruneAlgo] = [ - GROUND_TRUTH_PRUNE_ALGO, - RANDOM_PRUNE_ALGO, - ACDC_PRUNE_ALGO, - LOGIT_DIFF_GRAD_PRUNE_ALGO, # Fast implementation of Edge Attribution Patching - SUBNETWORK_TREE_PROBING_PRUNE_ALGO, - ] - reverse_ps = run_prune_algos([TRACR_REVERSE_TOKEN_CIRCUIT_TASK], REVERSE_ALGOS) - - XPROPORTION_ALGOS: List[PruneAlgo] = [ - GROUND_TRUTH_PRUNE_ALGO, - RANDOM_PRUNE_ALGO, - MSE_ACDC_PRUNE_ALGO, - LOGIT_MSE_GRAD_PRUNE_ALGO, # Fast implementation of EAP with MSE loss - # SUBNETWORK_EDGE_PROBING_PRUNE_ALGO, - MSE_SUBNETWORK_TREE_PROBING_PRUNE_ALGO, - MSE_CIRCUIT_TREE_PROBING_PRUNE_ALGO, - ] - xprop_ps = run_prune_algos( - [TRACR_XPROPORTION_TOKEN_CIRCUIT_TASK], XPROPORTION_ALGOS - ) - # task_prune_scores: TaskPruneScores = xprop_ps - task_prune_scores: TaskPruneScores = reverse_ps | xprop_ps -if load_prune_scores: - filename = "tracr-task-prune-scores-23-02-2024_00-13-23.pkl" - loaded_cache = load_cache(cache_folder_name, filename) - task_prune_scores = {k: v | task_prune_scores[k] for k, v in loaded_cache.items()} -if save_prune_scores: - base_filename = "tracr-task-prune-scores" - save_cache(task_prune_scores, cache_folder_name, base_filename) - - -# -------------------------------- Draw Circuit Graphs --------------------------------- -if False: - for task_key, algo_prune_scores in task_prune_scores.items(): - task = TASK_DICT[task_key] - for algo_key, prune_scores in algo_prune_scores.items(): - algo = PRUNE_ALGO_DICT[algo_key] - if not algo == LOGIT_MSE_GRAD_PRUNE_ALGO: - continue - print("task:", task.name, "algo:", algo.name) - draw_seq_graph( - model=task.model, - prune_scores=prune_scores, - seq_labels=task.test_loader.seq_labels, - show_all_edges=False, - ) - break - break - -# ---------------------------------------- ROC ----------------------------------------- - -roc_measurements: TaskMeasurements = measure_roc(task_prune_scores) -roc_fig = task_roc_plot(roc_measurements) -roc_fig.show() -# Save figure as pdf in figures folder -folder: Path = repo_path_to_abs_path("figures/figures-12") -roc_fig.write_image(str(folder / "tracr-roc.pdf")) -roc_fig.write_image(str(folder / "tracr-roc.svg")) -roc_fig.write_image(str(folder / "tracr-roc.png"), scale=4) diff --git a/auto_circuit/uniqueness_experiments.py b/auto_circuit/uniqueness_experiments.py deleted file mode 100644 index bf8fd4d..0000000 --- a/auto_circuit/uniqueness_experiments.py +++ /dev/null @@ -1,167 +0,0 @@ -#%% -from pathlib import Path -from typing import List - -import plotly.graph_objects as go -import torch as t - -from auto_circuit.metrics.official_circuits.measure_roc import measure_roc -from auto_circuit.metrics.official_circuits.roc_plot import task_roc_plot -from auto_circuit.metrics.prune_metrics.measure_prune_metrics import ( - measure_prune_metrics, - measurement_figs, -) -from auto_circuit.metrics.prune_metrics.prune_metrics import ( - ANSWER_PROB_METRIC, - LOGIT_DIFF_METRIC, -) -from auto_circuit.metrics.prune_scores_similarity import prune_score_similarities_plotly -from auto_circuit.prune_algos.prune_algos import ( - GROUND_TRUTH_PRUNE_ALGO, - OPPOSITE_TREE_PROBING_PRUNE_ALGO, - PRUNE_ALGO_DICT, - run_prune_algos, -) -from auto_circuit.tasks import ( - DOCSTRING_TOKEN_CIRCUIT_TASK, - IOI_TOKEN_CIRCUIT_TASK, - TASK_DICT, - Task, -) -from auto_circuit.types import ( - AblationType, - PatchType, - TaskMeasurements, - TaskPruneScores, -) -from auto_circuit.utils.misc import load_cache, repo_path_to_abs_path, save_cache -from auto_circuit.utils.tensor_ops import prune_scores_threshold -from auto_circuit.visualize import draw_seq_graph - -TASKS: List[Task] = [ - # Token Circuits - # SPORTS_PLAYERS_TOKEN_CIRCUIT_TASK, - IOI_TOKEN_CIRCUIT_TASK, - DOCSTRING_TOKEN_CIRCUIT_TASK, - # Component Circuits - # SPORTS_PLAYERS_COMPONENT_CIRCUIT_TASK, - # IOI_COMPONENT_CIRCUIT_TASK, - # DOCSTRING_COMPONENT_CIRCUIT_TASK, - # GREATERTHAN_COMPONENT_CIRCUIT_TASK, - # Autoencoder Component Circuits - # IOI_GPT2_AUTOENCODER_COMPONENT_CIRCUIT_TASK, - # GREATERTHAN_GPT2_AUTOENCODER_COMPONENT_CIRCUIT_TASK - # ANIMAL_DIET_GPT2_AUTOENCODER_COMPONENT_CIRCUIT_TASK, - # CAPITAL_CITIES_PYTHIA_70M_AUTOENCODER_COMPONENT_CIRCUIT_TASK, -] -figs: List[go.Figure] = [] - -# --------------------------------- Load Prune Scores ---------------------------------- - -# 2000 epoch IOI Docstring tensor prune_scores post-kv-cache-fix -# batch_size=128, batch_count=2, default seed (for both) -filename = "task-prune-scores-16-02-2024_23-27-49.pkl" - -# 1000 epoch Sport Players tensor prune_scores post-kv-cache-fix -# batch_size=(10, 20), batch_count=(10, 5), default seed -# filename = "task-prune-scores-16-02-2024_22-22-43.pkl" -cache_folder_name = ".prune_scores_cache" -task_prune_scores = load_cache(cache_folder_name, filename) - -# -------------------------------- Draw Circuit Graphs --------------------------------- - -if True: - for task_key, algo_prune_scores in task_prune_scores.items(): - # if not task_key.startswith("Docstring"): - # continue - task = TASK_DICT[task_key] - if task.key != IOI_TOKEN_CIRCUIT_TASK.key or task.true_edge_count is None: - continue - for algo_key, ps in algo_prune_scores.items(): - algo = PRUNE_ALGO_DICT[algo_key] - keys = [GROUND_TRUTH_PRUNE_ALGO.key] - if algo_key not in keys: - continue - th = prune_scores_threshold(ps, task.true_edge_count) - circ_edges = dict([(d, (m.abs() >= th).float()) for d, m in ps.items()]) - print("circ_edge_count", sum([m.sum() for m in circ_edges.values()])) - circ = dict( - [(d, t.where(m.abs() >= th, m, t.zeros_like(m))) for d, m in ps.items()] - ) - print("task:", task.name, "algo:", algo.name) - draw_seq_graph( - model=task.model, - prune_scores=circ, - seq_labels=task.test_loader.seq_labels, - show_all_edges=False, - ) - -# ------------------------------ Prune Scores Similarity ------------------------------- - -if True: - prune_scores_similartity_fig = prune_score_similarities_plotly( - task_prune_scores, [], ground_truths=True - ) - figs.append(prune_scores_similartity_fig) - -# ----------------------------- Opposite Task Prune Scores ----------------------------- - -compute_opposite_task_prune_scores = False -save_opposite_task_prune_scores = False -load_opposite_task_prune_scores = False -opposite_task_prune_scores: TaskPruneScores = {} -opposite_prune_scores_cache_folder_name = ".opposite_prune_scores_cache" -if compute_opposite_task_prune_scores: - opposite_task_prune_scores = run_prune_algos( - TASKS, [OPPOSITE_TREE_PROBING_PRUNE_ALGO] - ) -if save_opposite_task_prune_scores: - base_filename = "opposite-task-prune-scores" - save_cache( - opposite_task_prune_scores, - opposite_prune_scores_cache_folder_name, - base_filename, - ) -if load_opposite_task_prune_scores: - filename = "opposite-task-prune-scores-07-02-2024_17-34-33.pkl" - opposite_task_prune_scores = load_cache( - opposite_prune_scores_cache_folder_name, filename - ) -if opposite_task_prune_scores: - opposite_prune_metric_measurements = measure_prune_metrics( - [AblationType.RESAMPLE], - [ANSWER_PROB_METRIC, LOGIT_DIFF_METRIC], - opposite_task_prune_scores, - PatchType.TREE_PATCH, - ) - figs += list(measurement_figs(opposite_prune_metric_measurements, auc_plots=False)) - -# ---------------------------------------- ROC ----------------------------------------- - -compute_roc_measurements = False -save_roc_measurements = False -load_roc_measurements = False -roc_measurements: TaskMeasurements = {} -roc_cache_folder_name = ".roc_measurements" -if compute_roc_measurements: - roc_measurements: TaskMeasurements = measure_roc(task_prune_scores) -if save_roc_measurements: - base_filename = "roc-measurements" - save_cache(roc_measurements, cache_folder_name, base_filename) -if load_roc_measurements: - filename = "lala.pkl" - roc_measurements = load_cache(roc_cache_folder_name, filename) -if roc_measurements: - roc_fig = task_roc_plot(roc_measurements) - figs.append(roc_fig) - - -# -------------------------------------- Figures --------------------------------------- - -for i, fig in enumerate(figs): - fig.show() - folder: Path = repo_path_to_abs_path("figures-12") - # Save figure as pdf in figures folder - # fig.write_image(str(folder / f"new {i}.pdf")) - -#%%