-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_step4.py
57 lines (37 loc) · 1.8 KB
/
train_step4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import division
from data.dataset_factory import DatasetFactory
from models.model_factory import ModelsFactory
from options.train_options import TrainOptions
class Train:
    """Train the "RegModel" model on the "RegDataset" dataset.

    Constructing an instance does everything: it parses the command-line
    options, builds the dataset and model through their factories, and then
    runs the full training loop.
    """

    # Save the model every this many epochs; the final epoch is always saved too.
    _SAVE_EVERY_N_EPOCHS = 20

    def __init__(self):
        self._opt = TrainOptions().parse()
        self._dataset_train = DatasetFactory.get_by_name("RegDataset", self._opt)
        self._dataset_train_size = len(self._dataset_train)
        print('#train images = %d' % self._dataset_train_size)
        self._model = ModelsFactory.get_by_name("RegModel", self._opt, is_train=True)
        self._train()

    def _train(self):
        """Run epochs ``load_reg_epoch + 1`` .. ``total_epoch`` (inclusive).

        The model is saved every ``_SAVE_EVERY_N_EPOCHS`` epochs and after the
        last epoch.

        Raises:
            ValueError: if the dataset holds fewer samples than one batch, in
                which case no optimisation step would ever run.
        """
        # Floor division: a trailing partial batch does not count as a step.
        self._steps_per_epoch = int(self._dataset_train_size // self._opt.batch_size)
        if self._steps_per_epoch < 1:
            # Without this check every epoch would be a silent no-op, and an
            # untrained model would still be saved.
            raise ValueError(
                'dataset has %d samples, fewer than batch_size=%d; '
                'no training steps would run'
                % (self._dataset_train_size, self._opt.batch_size))

        last_epoch = self._opt.total_epoch
        for i_epoch in range(self._opt.load_reg_epoch + 1, last_epoch + 1):
            self._train_epoch(i_epoch)
            # Save periodically, and always at the end so the epochs after the
            # last multiple of the interval are not lost.
            if i_epoch % self._SAVE_EVERY_N_EPOCHS == 0 or i_epoch == last_epoch:
                print('saving the model at the end of epoch %d' % i_epoch)
                self._model.save(i_epoch)

    def _train_epoch(self, i_epoch):
        """Run ``self._steps_per_epoch`` optimisation steps for epoch ``i_epoch``."""
        for step in range(1, self._steps_per_epoch + 1):
            # get_batch() is expected to return an (inputs, labels) pair.
            inputs, labels = self._dataset_train.get_batch()
            self._model.set_input(inputs, labels)
            self._model.optimize_parameters()
            self._display_terminal_train(i_epoch, step)

    def _display_terminal_train(self, i_epoch, i_train_batch):
        """Print the epoch/step position followed by the model's current errors.

        ``get_current_errors()`` must return a mapping of name -> number; each
        entry is printed with three decimals.
        """
        errors = self._model.get_current_errors()
        message = '(epoch: %d, it: %d/%d) ' % (i_epoch, i_train_batch, self._steps_per_epoch)
        message += ''.join('%s:%.3f ' % (k, v) for k, v in errors.items())
        print(message)
def main():
    """Script entry point: constructing Train parses options and runs training."""
    Train()


if __name__ == "__main__":
    main()