forked from HAHA-DL/MLDG
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
80 lines (55 loc) · 1.81 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import accuracy_score
def unfold_label(labels, classes):
    """One-hot encode `labels` into an (N, classes) int8 array.

    Labels are shifted by their minimum value, so a label set such as
    {1, ..., classes} maps to column indices {0, ..., classes-1}.

    NOTE(review): assumes the shifted labels are contiguous integers in
    [0, classes) — the assert only checks the *count* of unique values,
    so a gapped label set (e.g. {0, 2}) would index out of range. TODO
    confirm with callers.
    """
    labels = np.asarray(labels)
    assert len(np.unique(labels)) == classes
    # Shift so the smallest label becomes index 0 (same truncation as
    # the original per-element int() cast).
    indices = labels.astype(np.int64) - int(np.min(labels))
    # Row i of the identity matrix is the one-hot vector for class i;
    # fancy indexing builds all rows in one vectorized step.
    return np.eye(classes, dtype=np.int8)[indices]
def shuffle_data(samples, labels):
    """Shuffle `samples` and `labels` together with one shared random permutation,
    preserving the sample-to-label pairing. Returns (shuffled_samples, shuffled_labels).
    """
    order = np.random.permutation(np.arange(len(labels)))
    return samples[order], labels[order]
def shuffle_list(li):
    """Shuffle the sequence in place (numpy RNG) and hand the same object back."""
    np.random.shuffle(li)
    return li
def shuffle_list_with_ind(li):
    """Shuffle array `li` and return (shuffled, permutation_indices), where
    `shuffled[i] == li[permutation_indices[i]]`. Requires numpy fancy indexing,
    so `li` is expected to be an ndarray.
    """
    order = np.random.permutation(np.arange(len(li)))
    return li[order], order
def num_flat_features(x):
    """Number of elements per sample in tensor `x`: the product of every
    dimension except dim 0 (the batch dimension)."""
    total = 1
    for extent in x.size()[1:]:  # skip dim 0: the batch dimension
        total *= extent
    return total
def crossentropyloss():
    """Factory: return a fresh torch CrossEntropyLoss criterion."""
    return torch.nn.CrossEntropyLoss()
def mseloss():
    """Factory: return a fresh torch MSELoss criterion."""
    return torch.nn.MSELoss()
def sgd(parameters, lr, weight_decay=0.00005, momentum=0.9):
    """Build an SGD optimizer over `parameters` with this project's default
    momentum (0.9) and weight decay (5e-5)."""
    return optim.SGD(params=parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
def fix_seed():
    """Intentional no-op: deterministic seeding (torch.manual_seed /
    np.random.seed) was disabled by the authors. Re-enable by seeding
    both RNGs here if reproducible runs are needed."""
def write_log(log, log_path):
    """Append `log` (stringified) plus a trailing newline to the file at `log_path`.

    Uses a context manager so the handle is closed even if a write raises
    (the original open/close pair leaked the handle on error). Encoding is
    pinned to UTF-8 so output bytes don't vary with the platform default.
    """
    with open(log_path, mode='a', encoding='utf-8') as f:
        f.write(str(log))
        f.write('\n')
def compute_accuracy(predictions, labels):
    """Fraction of rows whose predicted class (argmax of `predictions`) matches
    the true class (argmax of one-hot `labels`).

    Both arguments are (N, classes)-shaped arrays; the comparison is over the
    last axis. Implemented with plain numpy — sklearn's accuracy_score is an
    unnecessary dependency for an argmax-match mean, and the result (a numpy
    float in [0, 1]) is identical.
    """
    pred_cls = np.argmax(predictions, axis=-1)
    true_cls = np.argmax(labels, axis=-1)
    # Mean of the boolean match vector == classification accuracy.
    return np.mean(pred_cls == true_cls)