bigram.py
import torch
import torch.nn as nn
from torch.nn import functional as F

block_size = 8     # length of each training sequence (maximum context)
batch_size = 32    # number of sequences processed in parallel
eval_iter = 100    # number of batches averaged when estimating the loss
max_iters = 10000  # total training steps
# data: https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
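# the file can be fetched with, e.g. (assuming wget is available):
#   wget -O input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt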
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# character-level vocabulary and the mappings between characters and integer ids
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {char: i for i, char in enumerate(chars)}
itos = {i: char for i, char in enumerate(chars)}
encoder = lambda s: [stoi[char] for char in s]
decoder = lambda l: [itos[i] for i in l]

# encode the full text and split it 90/10 into train and validation data
data = torch.tensor(encoder(text))
data_size = len(data)
train_data = data[:int(data_size * 0.9)]
val_data = data[int(data_size * 0.9):]
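# quick sanity check of the round trip (an illustrative example; all of these
# characters occur in the Shakespeare text):
#   ''.join(decoder(encoder("To be"))) == "To be"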

# targets are the inputs shifted one position to the right, so each position t
# predicts the next character from a context of length 1 up to block_size
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = [data[i:i + block_size] for i in ix]
    y = [data[i + 1:i + block_size + 1] for i in ix]
    return torch.stack(x), torch.stack(y)
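# x and y both come back with shape (batch_size, block_size) = (32, 8);
# for every batch b and position t, y[b, t] is the character that follows x[b, t].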

@torch.no_grad()
def estimate_loss():
    train_val_loss = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iter)
        for i in range(eval_iter):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[i] = loss
        train_val_loss[split] = losses.mean()
    model.train()  # restore training mode before returning
    return train_val_loss
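# averaging over eval_iter random batches gives a far less noisy estimate than a
# single-batch loss, and @torch.no_grad() avoids building autograd graphs while
# evaluating.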

# bigram language model
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # essentially a lookup table: each of the vocab_size tokens (65 here) maps
        # straight to a vector of vocab_size logits (a sort of one-hot encoding)
        self.token_embedding = nn.Embedding(vocab_size, vocab_size)

    def forward(self, input, targets=None):
        logits = self.token_embedding(input)  # (B, T, C)
        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension second: (B, C, T)
            logits = logits.permute(0, 2, 1)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
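    # note: when targets are given, the logits returned above are the permuted
    # (B, C, T) tensor; generate() below calls forward() without targets, so it
    # receives logits in the original (B, T, C) layout.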

    def generate(self, input, max_tokens):
        for _ in range(max_tokens):
            logits, loss = self.forward(input)
            # use only the logits of the last position
            logits = logits[:, -1, :]  # (B, C)
            prob = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(prob, 1)  # sample one token id, (B, 1)
            input = torch.cat((input, idx_next), dim=1)  # append it to the running sequence
        return input
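    # because this is a pure bigram model, each next-token prediction depends only
    # on the current token, so feeding the whole growing sequence back in is
    # redundant here; it keeps generate() general enough for models that do use a
    # longer context.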

model = BigramLanguageModel(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# training loop
for steps in range(max_iters):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if steps % 100 == 0:
        losses = estimate_loss()
        print(f"step: {steps} train loss: {losses['train']} val loss: {losses['val']}")

# generate 500 characters, starting from a single token with index 0
print(''.join(decoder(model.generate(torch.zeros((1, 1), dtype=torch.long), max_tokens=500)[0].tolist())))
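
# rough expectation (not a measured result): with the 65-character vocabulary,
# uniform predictions would give a loss of -ln(1/65) ≈ 4.17, so training should
# start near that value, and a bigram model of this kind typically plateaus
# around 2.5.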