-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogit_lens_2digit_intdecoding.py
109 lines (93 loc) · 3.78 KB
/
logit_lens_2digit_intdecoding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# system imports
import itertools
import json
import os
import random
import time
# external imports
from transformers import GPTNeoXForCausalLM, AutoModelForCausalLM, AutoTokenizer, OlmoForCausalLM
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
# local imports
from src.pythia_intermediate_decoder import PythiaIntermediateDecoder
from src.olmo_intermediate_decoder import OlmoIntermediateDecoder
# environment setup: seed every RNG the script relies on so runs are reproducible.
# Python's `random` drives random.shuffle(), which chooses the evaluated prompts,
# so it must be seeded too (previously only torch was).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.mps.manual_seed(42)
# -------------------------Start of Script------------------------- #
# Pick the best available accelerator: CUDA first, then Apple MPS, else the CPU.
# NOTE(review): `device` is only printed here; it is never passed to the decoder.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"using device {device}")
# model_id = "allenai/OLMo-7B-0724-hf"
model_id = "EleutherAI/pythia-160m-deduped"
print(f"Loading {model_id}...")
# Ordered (substring of model_id, decoder class) pairs; the first match wins.
_DECODER_BY_FAMILY = (
    ("pythia", PythiaIntermediateDecoder),
    ("OLMo", OlmoIntermediateDecoder),
)
for family, decoder_cls in _DECODER_BY_FAMILY:
    if family in model_id:
        int_decoder = decoder_cls(model_id=model_id)
        break
else:
    # No known family matched the model id.
    raise TypeError("Could not recognise model type and associated intermediate decoder")
# Load the two-digit addition dataset and evaluate a random subset of it.
with open("datasets/2digit_sum_dataset.json", encoding="utf-8") as f:
    two_digit_dataset = json.load(f)
random.shuffle(two_digit_dataset)
# Evaluate at most 100 prompts. Clamp to the dataset size so `n`, which the plot
# legend reports as the number of prompts, matches what is actually evaluated.
n = min(100, len(two_digit_dataset))
first_n = two_digit_dataset[:n]
# One bucket per decoder block of the model being inspected.
num_layers = len(int_decoder.model.base_model.layers)
# Each dataset record is indexed as [question, answer].
questions = [record[0] for record in first_n]
answers = [record[1] for record in first_n]
# NOTE(review): all_probabilities is never filled in the visible script.
all_probabilities = []
# block number -> correct-answer probability for each evaluated prompt
avg_probabilities = {layer_idx: [] for layer_idx in range(num_layers)}
# How many prompts have been drawn individually so far.
plotted_counter = 0
# Figure setup: a colour cycle for the per-prompt lines, then the global style and canvas.
colors = itertools.cycle(sns.color_palette("tab10"))
sns.set(style="whitegrid")
plt.figure(figsize=(10, 6))
# Main experiment: run the logit lens on every selected prompt and record, for
# each decoder block, the probability the lens assigns to the correct answer.
for question, answer in tqdm(zip(questions, answers), total=len(questions), dynamic_ncols=True):
    prompt = f"Question: What is {question}? Answer: {question}="
    # NOTE(review): presumably clears decoder state left over from the previous prompt.
    int_decoder.reset_all()
    block_activations = int_decoder.decode_all_layers(
        prompt,
        topk=5,
        printing=False,
        print_attn_mech=False,
        print_intermediate_res=False,
        print_mlp=False,
        print_block=True,
    )
    # Kept as module-level names: the plotting code after the loop reads block_numbers.
    block_numbers = []
    probabilities = []
    for block_activation in block_activations:
        # Element 0 is a label whose second whitespace-separated word is the block index.
        block_num = int(block_activation[0].split()[1])
        # Element 1 becomes a token -> probability mapping (presumably only the
        # top-k entries, so an answer outside the top-k counts as 0).
        token_probs = dict(block_activation[1])
        # NOTE(review): exact-match lookup on `answer`; confirm the dataset's
        # answer type/format matches the decoded token keys.
        probability_correct = token_probs.get(answer, 0)
        block_numbers.append(block_num)
        probabilities.append(probability_correct)
        # setdefault avoids a KeyError if a reported block index falls outside
        # the pre-created range(num_layers) buckets.
        avg_probabilities.setdefault(block_num, []).append(probability_correct)
    # Only the first 10 prompts are drawn individually to keep the figure readable.
    if plotted_counter < 10:
        sns.lineplot(
            x=block_numbers,
            y=probabilities,
            marker='o',
            label=f"Probability of token '{answer}' for question {question}",
            color=next(colors),
            alpha=0.3,
        )
        plotted_counter += 1
# Mean correct-answer probability per block across all evaluated prompts.
# NOTE(review): block_numbers comes from the last prompt; every prompt is assumed
# to report the same set of blocks.
avg_prob_values = [np.mean(avg_probabilities[block]) for block in block_numbers]
sns.lineplot(x=block_numbers, y=avg_prob_values, marker='o', color="black", label=f"Average Probability Across {n} prompts", linewidth=3)
plt.xlabel("Block Number", fontsize=17)
plt.ylabel("Probability (%)", fontsize=17)
plt.title(f"Logit Lens Probability of Correct Token Across Decoder Blocks of {model_id}", fontsize=18)
# NOTE(review): the 0-100 axis and PercentFormatter assume the decoder reports
# probabilities in percent; confirm the scale.
plt.ylim(0, 100)
# Show only the blocks from num_layers // 5 onwards.
x_min = num_layers // 5
plt.xlim(x_min, num_layers)
# Choose ticks by block value, not list position, so they match the xlim range.
plt.xticks([block for block in block_numbers if block >= x_min], fontsize=14)
plt.yticks(fontsize=14)
plt.gca().yaxis.set_major_formatter(PercentFormatter())
plt.grid(visible=True, which='both', axis='both', linestyle='--', linewidth=0.7)
plt.legend(fontsize=13, loc='upper left')
# model_id carries an organisation prefix (e.g. "EleutherAI/"), which adds a
# directory level to the path. savefig does not create missing directories.
figure_path = f"figures/2digit_accuracy_intdecoding/{model_id}.pdf"
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path)
plt.show()