forked from ultraeric/training
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUtils.py
105 lines (84 loc) · 2.92 KB
/
Utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Utility classes for training."""
import os
import operator
import time
from Parameters import ARGS
from libs.utils2 import Timer, d2s
from libs.vis2 import mi
import matplotlib.pyplot as plt
import numpy as np
import torch
class MomentCounter:
    """Signal each time at least ``n`` data moments have gone by."""

    def __init__(self, n):
        # Size of the window, measured in units of ``data_index.ctr``.
        self.n = n
        # Counter value at which the current window began.
        self.start = 0

    def step(self, data_index):
        """Return True when ``data_index.ctr`` has advanced ``n`` or more.

        On a True result the window restarts at the current counter value,
        so the next True needs another ``n`` moments to pass.
        """
        current = data_index.ctr
        window_elapsed = current - self.start >= self.n
        if not window_elapsed:
            return False
        self.start = current
        return True
class LossLog:
    """Keep track of loss; usable within an epoch or across epochs."""

    def __init__(self):
        self.log = []        # (ctr, loss) pairs in the order they were added
        self.ctr = 0         # number of add() calls so far
        self.total_loss = 0  # running sum of every loss passed to add()

    def add(self, ctr, loss):
        """Record ``loss`` under the caller-supplied counter value ``ctr``.

        ``ctr`` is only stored in the log; the internal ``self.ctr`` counts
        calls to this method independently of it.
        """
        self.log.append((ctr, loss))
        self.total_loss += loss
        self.ctr += 1

    def average(self):
        """Return the mean of all recorded losses.

        Returns NaN when nothing has been recorded yet, rather than raising
        ZeroDivisionError, so callers can report on an empty log safely.
        """
        if self.ctr == 0:
            return float('nan')
        # float() forces true division even for integer losses on Python 2.
        return self.total_loss / float(self.ctr)

    def export_csv(self, filename):
        """Write the (ctr, loss) log to ``filename`` as comma-separated text.

        The first line is the literal header ``Counter,Loss``; ``comments=''``
        stops numpy from prefixing that header with ``# ``.
        """
        np.savetxt(
            filename,
            np.array(self.log),
            header='Counter,Loss',
            delimiter=",",
            comments='')
class RateCounter:
    """Periodically print the processing rate in Hz."""

    def __init__(self):
        # Reporting period handed to Timer; also the divisor for the rate.
        self.rate_timer_interval = 10.0
        self.rate_timer = Timer(self.rate_timer_interval)
        # Number of step() calls since the last report.
        self.rate_ctr = 0

    def step(self):
        """Count one step and print the rate whenever the timer fires.

        The rate is ``ARGS.batch_size * steps / rate_timer_interval``, i.e.
        it divides by the nominal interval rather than a measured duration.
        Presumably ``Timer.check()`` turns truthy once the interval has
        elapsed -- confirm against libs.utils2.
        """
        self.rate_ctr += 1
        if not self.rate_timer.check():
            return
        items_processed = ARGS.batch_size * self.rate_ctr
        rate = items_processed / self.rate_timer_interval
        print('rate = ' + str(rate) + 'Hz')
        self.rate_timer.reset()
        self.rate_ctr = 0
def save_net(weights_file_name, net):
    """Save ``net``'s weights under ``ARGS.save_path`` in two files.

    ``<weights_file_name>.weights`` holds the plain ``state_dict``.
    ``<weights_file_name>.infer`` holds ``{'net': state_dict}`` with every
    entry moved to CUDA device 0, for use at inference time.
    """
    training_state = net.state_dict()
    training_path = os.path.join(ARGS.save_path,
                                 weights_file_name + '.weights')
    torch.save(training_state, training_path)

    # Inference copy: work on a shallow copy so replacing entries with their
    # GPU #0 versions does not touch the dict returned by state_dict().
    gpu_state = net.state_dict().copy()
    for entry_name in list(gpu_state):
        gpu_state[entry_name] = gpu_state[entry_name].cuda(device=0)
    inference_bundle = {'net': gpu_state}
    inference_path = os.path.join(ARGS.save_path,
                                  weights_file_name + '.infer')
    torch.save(inference_bundle, inference_path)
def display_sort_data_moment_loss(data_moment_loss_record, data):
    """Show the lowest- and highest-loss data moments, one second each.

    Args:
        data_moment_loss_record: dict mapping a key to its loss. Each key is
            indexed as ``key[0] == (run_code, seg_num, offset)``, with
            ``key[1]`` and ``key[2]`` being the two series plotted in red
            and green (presumably target and output -- confirm with caller).
        data: object exposing ``get_data(run_code, seg_num, offset)``; the
            result's ``['right'][0, :, :]`` slice is displayed via ``mi``.

    Entries are sorted by ascending loss. Up to 20 from the low end are
    shown first, then up to 19 from the high end, worst first.
    """
    sorted_data_moment_loss_record = sorted(data_moment_loss_record.items(),
                                            key=operator.itemgetter(1))
    num_records = len(sorted_data_moment_loss_record)

    # Build real lists: range objects cannot be concatenated on Python 3.
    # Clamp to the record count so short records do not raise IndexError.
    # NOTE(review): the high end covers 19 entries (-1..-19) versus 20 at
    # the low end, as in the original -- confirm the asymmetry is intended.
    low_loss_indices = list(range(min(20, num_records)))
    high_loss_indices = list(range(-1, -min(19, num_records) - 1, -1))

    for i in low_loss_indices + high_loss_indices:
        key, loss = sorted_data_moment_loss_record[i]
        run_code, seg_num, offset = key[0]
        t = key[1]
        o = key[2]
        sorted_data = data.get_data(run_code, seg_num, offset)

        plt.figure(22)
        plt.clf()
        plt.ylim(0, 1)
        plt.plot(t, 'r.')
        plt.plot(o, 'g.')
        # Horizontal reference line at 0.5 across x = 0..20.
        plt.plot([0, 20], [0.5, 0.5], 'k')
        # Image view titled with the loss; 23 is presumably the figure
        # number expected by mi -- confirm against libs.vis2.
        mi(sorted_data['right'][0, :, :], 23, img_title=d2s(loss))
        plt.pause(1)