-
Notifications
You must be signed in to change notification settings - Fork 18
/
engine_state.py
93 lines (71 loc) · 2.35 KB
/
engine_state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import queue
from heapq import heappush
class EngineStateListener:
    """Base class for objects that react to EngineState lifecycle events.

    Every hook here is a no-op that returns None. Subclasses override only
    the events they care about and are registered through
    ``EngineState.add_listener``.
    """

    def batch_completed(self):
        """Hook for the batch-completed event; does nothing by default."""

    def epoch_started(self):
        """Hook for the epoch-started event; does nothing by default."""

    def epoch_completed(self):
        """Hook for the epoch-completed event; does nothing by default."""

    def training_started(self):
        """Hook for the training-started event; does nothing by default."""

    def training_completed(self):
        """Hook for the training-completed event; does nothing by default."""

    def test_completed(self):
        """Hook for the test-completed event; does nothing by default."""

    def run_completed(self):
        """Hook for the run-completed event; does nothing by default."""
class EngineState(EngineStateListener):
    """Tracks training-loop progress and forwards lifecycle events to listeners.

    The most recently constructed instance is registered as the "current"
    engine state and can be retrieved with :meth:`current_engine_state`.

    Every lifecycle hook first notifies all registered listeners (those added
    with ``last=False``, then those added with ``last=True``) and only
    afterwards updates this object's own counters/flags, so listeners see the
    state as it was before the event was applied.
    """

    # Most recently constructed EngineState; None until one has been created.
    # NOTE(review): ``Optional`` is not imported in this module, so static
    # checkers will flag these type comments until it is.
    __main_engine_state = None  # type: Optional[EngineState]

    @classmethod
    def current_engine_state(cls):
        # type: () -> Optional[EngineState]
        """Return the most recently constructed EngineState, or None if none exists yet."""
        return cls.__main_engine_state

    def __init__(self, start_epoch, max_epoch):
        """Create the state and register it as the current engine state.

        Args:
            start_epoch: Epoch index to begin from; ``epoch`` starts here.
            max_epoch: Exclusive upper bound on the epoch index;
                ``epoch_completed`` never advances ``epoch`` past
                ``max_epoch - 1``.
        """
        self.start_epoch = start_epoch
        self.max_epoch = max_epoch
        self.epoch = start_epoch
        # Batch index within the current epoch; reset by epoch_started().
        self.batch = 0
        # Total batches completed across all epochs; never reset here.
        self.global_step = 0
        # Not updated by this class; presumably set by the caller — confirm.
        self.estimated_num_batches = 0
        # Last value passed to update_lr().
        self.lr = 0
        # True between training_started() and training_completed().
        self.is_training = False
        # Listeners notified first, in registration order.
        self.listeners = []
        # Listeners notified after every entry in ``listeners``.
        self.last_listeners = []
        EngineState.__main_engine_state = self

    def add_listener(self, listener, last=False):
        """Register *listener* to receive lifecycle events.

        Args:
            listener: Object implementing the EngineStateListener hooks.
            last: If True, the listener is notified after every listener
                registered with ``last=False``, regardless of when it was added.
        """
        # FIXME: two lists are a crude ordering mechanism; an explicit
        # priority would generalize better.
        if last:
            self.last_listeners.append(listener)
        else:
            self.listeners.append(listener)

    def _notify(self, hook_name):
        """Invoke the no-argument hook *hook_name* on every registered listener."""
        # The concatenation builds a fresh list (a snapshot), so listeners
        # registered while this event is being dispatched are not called for it.
        for listener in self.listeners + self.last_listeners:
            getattr(listener, hook_name)()

    def batch_completed(self):
        """Notify listeners, then advance both the in-epoch and global step counters."""
        self._notify("batch_completed")
        self.batch += 1
        self.global_step += 1

    def epoch_started(self):
        """Notify listeners, then reset the in-epoch batch counter."""
        self._notify("epoch_started")
        self.batch = 0

    def epoch_completed(self):
        """Notify listeners, then advance ``epoch`` unless it is already the last one."""
        self._notify("epoch_completed")
        # Stay on the final epoch index (max_epoch - 1) instead of running past
        # it. ``<`` rather than ``!=`` also stops an epoch that is already at or
        # beyond the bound (e.g. resumed with start_epoch >= max_epoch) from
        # incrementing without limit.
        if self.epoch < self.max_epoch - 1:
            self.epoch += 1

    def training_started(self):
        """Notify listeners, then mark training as in progress."""
        self._notify("training_started")
        self.is_training = True

    def training_completed(self):
        """Notify listeners, then mark training as no longer in progress."""
        self._notify("training_completed")
        self.is_training = False

    def test_completed(self):
        """Notify listeners of the test-completed event; no state changes here."""
        self._notify("test_completed")

    def run_completed(self):
        """Notify listeners of the run-completed event; no state changes here."""
        self._notify("run_completed")

    def update_lr(self, lr):
        """Record *lr* as the current learning rate (listeners are not notified)."""
        self.lr = lr