"""What does any pytorch based training script look like:
- class ModelConfig
- class Model : nn.Module
- class DataConfig
- class Data : torch.utils.data.Dataset
- one for training: dstrain
- one for eval. : dstest
- class TrainerConfig
- batch_size
- lr
- n_step
...
- class Trainer
- train(): method to train
dl_train = DataLoader(ds_train)
dl_test = DataLoader(ds_test)
for input, global_step in zip(dl_train, range(n_training_steps)):
loss = model(input)
loss.backward()
optim.step()
if global_step % test_every_steps == 0:
test_loss = mean([
loss = model(test_input)
for test_input in dl_test
])
if test_loss < previous_lowest_loss:
save_model()
no_improvement_evals = 0
else:
no_improvement_evals += 1
if no_improvement_evals == patience:
break training
"""
import os
import re
import textract
import numpy as np
from glob import glob
import torch
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")
# load all the text
def read_pdf(filename) -> str:
    if os.name == "nt":
        # Windows fallback: read the PDF with PyPDF2, joining all pages
        import PyPDF2
        with open(filename, 'rb') as pdf_file:
            reader = PyPDF2.PdfFileReader(pdf_file)
            text = " ".join(
                reader.getPage(i).extractText()
                for i in range(reader.getNumPages())
            )
    elif os.name == "posix":
        # Linux/macOS fallback: read the PDF with textract (pdfminer backend)
        text = textract.process(filename, method='pdfminer')
        text = text.decode("utf-8")
    else:
        raise OSError(f"unsupported platform for PDF reading: {os.name}")
    return text
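# Note: PdfFileReader / getPage / extractText are PyPDF2's legacy (pre-3.0)
# API; with the modern pypdf package the same read is roughly (not used here):
#   from pypdf import PdfReader
#   text = " ".join(page.extract_text() for page in PdfReader(filename).pages)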
files = glob("./sample/*.pdf")
all_text = []
for f in files:
    text = read_pdf(f)
    text = re.sub(r"\s+", " ", text)  # collapse runs of whitespace
    all_text.append(text)
# embed the documents: run DistilBERT and sum the token states over the
# sequence dimension, one fixed-size vector per document (the script keeps
# the name `logits` for these pooled features)
with torch.no_grad():
    out = tokenizer(all_text, return_tensors="pt", padding="longest")
    output = model(
        # truncate every tensor to the model's maximum sequence length
        **{k: v[:, :model.config.max_position_embeddings] for k, v in out.items()}
    )
    logits = torch.sum(output.last_hidden_state, dim=1)
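# Sum pooling also adds the hidden states of [PAD] positions and scales with
# document length; a masked mean pool is a common alternative (sketch using
# the same `out` / `output` tensors as above):
#   mask = out["attention_mask"][:, :model.config.max_position_embeddings].unsqueeze(-1)
#   pooled = (output.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)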
# define classifier head: a single linear projection from the pooled
# features (size i) to c classes
class ClassifierHead(nn.Module):
    def __init__(self, i, c):
        super().__init__()
        self.w = nn.Parameter(data=torch.normal(mean=0, std=0.02, size=[i, c]), requires_grad=True)
        self.b = nn.Parameter(data=torch.zeros([c,]))

    def forward(self, x):
        return x @ self.w + self.b
# DistilBERT's hidden size is config.dim (config.hidden_dim is the FFN size)
c = ClassifierHead(model.config.dim, 3)
# train the classifier head (the DistilBERT features above stay fixed)
optim = torch.optim.Adam(c.parameters())
# dummy targets, one per document; labels must lie in [0, n_classes)
t = torch.tensor([0, 1, 2, 0, 1, 2, 1], dtype=torch.long)
for i in range(10):
    optim.zero_grad()  # clear gradients accumulated by the previous step
    out = c(logits)
    loss = F.cross_entropy(out, target=t)
    loss.backward()
    optim.step()
# stack weight [i, c] and bias (reshaped to [1, c]) into one [i + 1, c]
# array and save it for later reuse
ws = np.vstack([c.w.detach().numpy(), c.b.view(1, c.b.shape[-1]).detach().numpy()])
np.save("./params.npy", ws)
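# To reuse the saved parameters later, split the stacked array back into
# weight and bias (sketch; `features` stands for pooled document vectors
# built the same way as `logits` above):
#   ws = np.load("./params.npy")   # shape [dim + 1, n_classes]
#   w, b = ws[:-1], ws[-1]
#   preds = (features @ w + b).argmax(axis=-1)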