-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
122 lines (85 loc) · 2.83 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import platform
import pwd
import subprocess
import time
import numpy as np
import glob
import os
import cPickle as pickle
# Largest finite value representable as a 32-bit float (an np.float32 scalar).
maxfloat = np.finfo('float32').max
def auto_make_dir(path):
    """Create directory *path* (including missing parents) if it is absent.

    Prints ``Created dir <path>`` only when this call actually created it.

    Raises:
        OSError: if the directory cannot be created, or *path* exists but is
            not a directory.
    """
    if os.path.isdir(path):
        return
    try:
        os.makedirs(path)
    except OSError:
        # Another process may have created the directory between the check
        # above and makedirs(); that is fine. Anything else (permissions,
        # a regular file in the way) is a real error.
        if not os.path.isdir(path):
            raise
    else:
        print('Created dir %s' % (path,))
def find_model_metadata(metadata_dir, config_name, best=False):
    """Return the path of the single metadata pickle for *config_name*.

    Searches *metadata_dir* for ``<config_name>-*-best.pkl`` when *best* is
    true, otherwise for ``<config_name>-*<digit>.pkl``. *config_name* is
    interpolated into a glob pattern, so glob metacharacters in it are
    interpreted, not matched literally.

    Raises:
        ValueError: if no file, or more than one file, matches.
    """
    template = '/%s-*-best.pkl' if best else '/%s-*[0-9].pkl'
    candidates = glob.glob(metadata_dir + template % config_name)
    count = len(candidates)
    if count == 0:
        raise ValueError('No metadata files for config %s' % config_name)
    if count > 1:
        raise ValueError('Multiple metadata files for config %s' % config_name)
    (found,) = candidates
    # Only the path is resolved here; the pickle itself is not read.
    print('Loaded model from %s' % (found,))
    return found
def get_train_valid_split(train_data_path):
    """Load the pickled train/validation split from ``valid_split.pkl``.

    The file name is relative, so it is resolved against the current working
    directory. *train_data_path* is accepted but currently unused.

    NOTE: the split is not generated here when the file is missing; an earlier
    (disabled) version built it from *train_data_path* via a
    ``create_validation_split`` helper.
    """
    split_file = 'valid_split.pkl'
    split = load_pkl(split_file)
    return split
def check_data_paths(data_path):
    """Raise ``ValueError`` unless *data_path* names an existing directory."""
    if os.path.isdir(data_path):
        return
    raise ValueError('path is not a directory ' + data_path)
def get_dir_path(dir_name, root_dir, no_name=False):
    """Return ``root_dir/dir_name/<username>``, creating it if needed.

    Args:
        dir_name: subdirectory name under *root_dir*.
        root_dir: base directory.
        no_name: when true, the user component is empty, so the returned
            path is ``root_dir/dir_name/`` (with a trailing slash).

    Returns:
        The directory path as a string (built by plain concatenation).

    Raises:
        OSError: if the directory cannot be created or a non-directory
            already occupies the path.
    """
    if no_name:
        username = ''
    else:
        # Login name for the process's real user id (Unix-only `pwd` module).
        username = pwd.getpwuid(os.getuid())[0]
    dir_path = root_dir + '/' + dir_name + '/%s' % username
    if not os.path.isdir(dir_path):
        try:
            os.makedirs(dir_path)
        except OSError:
            # Tolerate a concurrent creator winning the race between the
            # isdir() check and makedirs(); re-raise genuine failures.
            if not os.path.isdir(dir_path):
                raise
    return dir_path
def hms(seconds):
    """Format a duration in seconds as ``HH:MM:SS``.

    Fractional seconds are floored. Hours are not wrapped, so durations of
    100 hours or more produce three or more hour digits. Intended for
    non-negative input.
    """
    remaining = np.floor(seconds)
    units = []
    # Peel off seconds, then minutes; what is left over is whole hours.
    for _ in range(2):
        remaining, unit = divmod(remaining, 60)
        units.append(unit)
    units.append(remaining)
    return "%02d:%02d:%02d" % tuple(reversed(units))
def timestamp():
    """Return the current local time as ``YYYYmmdd-HHMMSS``."""
    layout = "%Y%m%d-%H%M%S"
    # strftime() without a time tuple formats time.localtime().
    return time.strftime(layout)
def hostname():
    """Return this machine's network name as reported by ``platform``."""
    # Field 1 of uname() is the node name, the same value platform.node() returns.
    return platform.uname()[1]
def generate_expid(arch_name):
    """Build an experiment id of the form ``<arch_name>-<YYYYmmdd-HHMMSS>``."""
    stamp = timestamp()
    return "%s-%s" % (arch_name, stamp)
def get_git_revision_hash():
    """Return the commit hash of ``HEAD`` for the current working directory.

    Returns:
        The stripped output of ``git rev-parse HEAD`` (``str`` on Python 2,
        ``bytes`` on Python 3), or the integer ``0`` when git is not
        installed or the directory is not inside a repository. The ``0``
        sentinel is kept for backward compatibility with existing callers.
    """
    with open(os.devnull, 'w') as devnull:
        try:
            # Silence git's own "fatal: not a git repository" message.
            output = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                             stderr=devnull)
        except (OSError, subprocess.CalledProcessError):
            # OSError: git executable missing; CalledProcessError: git failed.
            # Narrower than the old bare `except:`, which also swallowed
            # KeyboardInterrupt and SystemExit.
            return 0
    return output.strip()
def save_pkl(obj, path, protocol=None):
    """Pickle *obj* to the file at *path*, overwriting any existing file.

    Args:
        obj: object to serialize.
        path: destination file path.
        protocol: pickle protocol number. ``None`` (the default) uses the
            pickle module's default protocol, exactly as before; pass e.g.
            ``pickle.HIGHEST_PROTOCOL`` for smaller, faster files.
    """
    with open(path, 'wb') as f:
        if protocol is None:
            # Don't forward None: Python 2's cPickle.dump() requires an int.
            pickle.dump(obj, f)
        else:
            pickle.dump(obj, f, protocol)
def load_pkl(path):
    """Unpickle and return the object stored in the file at *path*.

    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)
def save_np(obj, path):
    """Save *obj* as a NumPy ``.npy`` file at *path* using ``np.save``.

    ``np.save`` appends ``.npy`` when a string *path* lacks that extension.
    """
    # fix_imports=True was already the default, only affected Python 2
    # pickles, and is deprecated as of NumPy 2.1, so it is no longer passed.
    np.save(file=path, arr=obj)
def savez_compressed_np(obj, path):
    """Write *obj* into a compressed ``.npz`` archive at *path* under key ``'arr'``."""
    members = {'arr': obj}
    np.savez_compressed(path, **members)
def load_np(path):
    """Load whatever ``np.load`` returns for *path* (an array or an ``.npz`` archive)."""
    contents = np.load(file=path)
    return contents
def copy(from_folder, to_folder):
    """Recursively copy *from_folder* into the directory *to_folder*.

    Runs ``cp -r <from_folder> <to_folder>/.`` and prints that command line
    first. Failures are reported but never raised, keeping the call
    best-effort as before.

    Returns:
        The exit status of ``cp`` (0 on success).
    """
    print("cp -r %s %s/." % (from_folder, to_folder))
    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are used verbatim instead of being split/interpreted.
    # NOTE(review): the old os.system() call also expanded globs and `~`
    # through the shell; confirm no caller relied on that.
    argv = ['cp', '-r', '%s' % (from_folder,), '%s/.' % (to_folder,)]
    status = subprocess.call(argv)
    if status != 0:
        print('copy failed with exit status %d' % status)
    return status
def current_learning_rate(schedule, idx):
    """Return the learning rate in effect at position *idx* of a schedule.

    Args:
        schedule: dict mapping a starting index (presumably an epoch or
            iteration count) to the learning rate used from that index on.
            It must contain key ``0``, the initial rate.
        idx: current position, compared against the schedule keys.

    Returns:
        The rate of the largest key ``<= idx``, or ``schedule[0]`` if no key
        qualifies.

    Raises:
        KeyError: if *schedule* has no key ``0``.
    """
    current_lr = schedule[0]
    # sorted() works on dicts in both Python 2 and 3, unlike keys().sort(),
    # which fails on Python 3's dict_keys view.
    for start in sorted(schedule):
        if idx < start:
            break  # keys ascend, so no later entry can apply either
        current_lr = schedule[start]
    return current_lr
def get_script_name(file_path):
    """Return the base name of *file_path* without a Python source extension.

    Only a trailing ``.py``/``.pyc``/``.pyo`` extension is removed, so dotted
    names such as ``my.python_cfg.py`` survive intact. The previous
    ``.replace('.py', '')`` stripped every occurrence of ``.py`` and turned
    ``model.pyc`` (e.g. a Python 2 ``__file__``) into ``modelc``.
    """
    base = os.path.basename(file_path)
    root, ext = os.path.splitext(base)
    if ext in ('.py', '.pyc', '.pyo'):
        return root
    return base