# attn_intdecoding.py
# system imports
import time
import json
import itertools
import random
# external imports
from transformers import GPTNeoXForCausalLM, AutoModelForCausalLM, AutoTokenizer, OlmoForCausalLM
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
# local imports
from src.pythia_intermediate_decoder import PythiaIntermediateDecoder
from src.olmo2_intermediate_decoder import Olmo2IntermediateDecoder
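# note: both decoder wrappers are assumed to share the interface used below
# (a .model with .base_model.layers, a .tokenizer, .reset_all(), and
# .decode_all_layers()), since only one of them is instantiated per run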
# environment setup
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.mps.manual_seed(42)
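# note: on recent PyTorch, torch.manual_seed already seeds CUDA and MPS; the
# explicit per-backend calls are defensive and are no-ops when that backend
# is absent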
# -------------------------Start of Script------------------------- #
# attempt to auto-recognize the device!
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device {device}")
# model_id = "allenai/OLMo-7B-0724-hf"
# model_id = "EleutherAI/pythia-160m-deduped"
model_id = "allenai/OLMo-2-1124-7B-Instruct"
print(f"Loading {model_id}...")
if "pythia" in model_id:
int_decoder = PythiaIntermediateDecoder(model_id=model_id)
elif "OLMo" in model_id:
int_decoder = Olmo2IntermediateDecoder(model_id=model_id)
else:
raise TypeError("Could not recognise model type and associated intermediate decoder")
dataset = load_dataset("lighteval/MATH")["train"]
print(dataset)
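# the MATH train split exposes "problem" and "solution" columns; only the
# first n problems are decoded below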
n = 5
first_n = dataset[:n]
num_layers = len(int_decoder.model.base_model.layers)
all_probabilities = []
avg_probabilities = {i: [] for i in range(num_layers)}
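# all_probabilities is assumed to collect one per-layer probability curve per
# question, while avg_probabilities pools values per layer index across
# questions (the aggregation and plotting happen further down the file)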
plotted_counter = 0
for i in tqdm(range(n), dynamic_ncols=True):
    question = first_n["problem"][i]
    # wrap the raw problem in the model's chat template, ending with the
    # assistant generation prefix
    chat = [
        {"role": "user", "content": question}
    ]
    question = int_decoder.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    answer = first_n["solution"][i]
    prompt = question  # the chat-formatted question is used as the prompt verbatim
    int_decoder.reset_all()
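    # decode_all_layers is assumed to logit-lens each transformer block's
    # output, returning per-block (label, top-k (token, probability)) pairs;
    # the print_* flags select which intermediate streams get printed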
    block_activations = int_decoder.decode_all_layers(prompt,
                                                      topk=5,
                                                      printing=False,
                                                      print_attn_mech=False,
                                                      print_intermediate_res=False,
                                                      print_mlp=False,
                                                      print_block=True)
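    # each entry is assumed to look like ("block 3", [(token, prob), ...]):
    # split()[1] extracts the layer index and the top-k list becomes a dict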
    block_numbers = []
    probabilities = []
    for block_activation in block_activations:
        block_num = int(block_activation[0].split()[1])
        token_probs = dict(block_activation[1])