train_equity.py
import os

import toml
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter

import data
from agents.model import *  # provides the DF network
from utils import *  # provides save_checkpoint


def main(config_path):
    cudnn.benchmark = True
    config = toml.load(config_path)
    writer = SummaryWriter()

    # Model (DF is defined in agents.model).
    model_crt = DF(18, 6)

    # Optionally resume from a checkpoint.
    if config["model"]["load"]:
        checkpoint = torch.load(config["model"]["load_path"])
        model_crt.load_state_dict(checkpoint["model"])
        print("successfully loaded model")
    model_crt.cuda()

    # Data: each sample holds hole cards, public cards, action history, and a label.
    dataset_train = data.Equity_DATASET(config["general"]["data_path"])
    dataloader_train = Data.DataLoader(dataset_train, batch_size=30000, shuffle=True,
                                       pin_memory=True, num_workers=6, drop_last=False)

    # Criterion: KL divergence between predicted and target equity distributions.
    criterion = nn.KLDivLoss(reduction="batchmean")

    # Optimizer and learning-rate schedule.
    params = [
        {"params": model_crt.parameters(), "lr": config["hyperparameters"]["lr"]}
    ]
    optimizer = optim.Adam(params, betas=(config["hyperparameters"]["betas"], 0.999),
                           weight_decay=config["hyperparameters"]["decay"])
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.7)

    os.makedirs(config["model"]["save_path"], exist_ok=True)

    train_length = len(dataloader_train)
    for epoch in range(1, 1000):
        for i, data_in in enumerate(dataloader_train):
            train_package = [
                data_in,
                model_crt,
                criterion,
                optimizer,
                writer,
                config,
                epoch,
                i,
                train_length,
            ]
            train(train_package)
        # Checkpoint every 10 epochs.
        if epoch % 10 == 0:
            save_checkpoint({
                "model": model_crt.state_dict(),
            }, os.path.join(config["model"]["save_path"], "checkpoint_{}.pt".format(epoch)))
            print("saved model successfully")
        # Decay the learning rate by 0.7 every 20 epochs
        # (scheduler.step() no longer takes an explicit epoch index).
        if epoch % 20 == 0:
            lr_scheduler.step()


def train(package):
    # Renamed the first element from "data" to "batch" so it does not shadow
    # the imported data module.
    [batch,
     model,
     criterion,
     optimizer,
     writer,
     config,
     epoch,
     i,
     train_length] = package
    model.train()

    holes, pubs, history, labels = batch
    holes = holes.cuda()
    pubs = pubs.cuda()
    history = history.cuda()
    labels = labels.cuda()

    predict = model(holes, pubs, history)
    # KLDivLoss expects log-probabilities as input, so take the log of the
    # model's probability output.
    predict = torch.log(predict)
    loss = criterion(predict, labels)

    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()

    if i % 10 == 0:
        print("iteration: {} epoch: {} loss: {:.6f}".format(i, epoch, loss.item()))
        writer.add_scalars("train loss", {"loss": loss.item()}, i + epoch * train_length)


if __name__ == "__main__":
    main("./train.toml")