-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_plot.py
78 lines (62 loc) · 2.74 KB
/
generate_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import utilities as utils
import matplotlib.pyplot as plt
# Command-line interface. Options are parsed as soon as the module is loaded,
# and generate_plot() reads args.update_check for its x-axis label.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-f", "--file",
    default="",
    type=str,
    help="Path to pickle file of losses",
)
parser.add_argument(
    "-s", "--save_file",
    default="",
    type=str,
    help="file to save as. Will save under results",
)
parser.add_argument(
    "-uc", "--update_check",
    default=1000,
    type=int,
    help="how many iterations per update",
)
parser.add_argument(
    "-t", "--title",
    default="",
    type=str,
    help="Specify title of plot",
)
args = parser.parse_args()
def generate_plot(file, save_file="", title=""):
    """
    Plot the training and validation losses stored in a pickle file and save the figure.

    The losses are loaded with ``utils.load_files(file)`` and are read through the keys
    ``'train'`` and ``'valid'``. When both series have the same length they are drawn
    together on one plot; otherwise each series is drawn in its own subplot. The figure
    is saved as a PNG and then displayed with ``plt.show()``.

    The x-axis label reports the module-level ``args.update_check`` value.

    :param file: if ran stand alone it is from args.file, path to the pickled losses;
        the path must end in ``.p``
    :type file: str
    :param save_file: if ran stand alone, it is from args.save_file, name (without
        extension) to save the plot as inside ``results/``. If empty, the plot is saved
        next to ``file``, with the trailing ``.p`` replaced by ``.png``
    :type save_file: str
    :param title: title of the combined plot; when empty, "Loss over Training Cycle" is
        used. Ignored when the two series differ in length (the subplots keep their own
        fixed titles)
    :type title: str
    :return: -1 if ``file`` is empty (nothing is plotted), otherwise None
    :raises AssertionError: if ``file`` or ``save_file`` is not a str, or ``file`` does
        not end in ``.p``
    """
    import os  # local import: only needed to prepare the output directory

    assert isinstance(file, str), "file must be a str"
    assert isinstance(save_file, str), "save_file must be a str"
    if file == "":
        print("Please specify file path")
        return -1
    # Must specify that the file is a pickled file with the .p!
    assert file.endswith('.p'), "file must be a pickle file ending in '.p'"

    losses = utils.load_files(file)
    train_losses = losses['train']
    valid_losses = losses['valid']
    epochs_train = len(train_losses)
    epochs_val = len(valid_losses)
    update_label = "Every %d Training Iterations" % args.update_check

    if epochs_train == epochs_val:
        # Same number of points: both curves share one set of axes.
        plt.plot(range(epochs_train), train_losses, label='Train Loss')
        plt.plot(range(epochs_val), valid_losses, label='Valid Loss')
        plt.xlabel(update_label)
        plt.ylabel("Loss")
        plt.title("Loss over Training Cycle" if title == "" else title)
        plt.legend(loc='upper right')
    else:
        # Different number of points: the x-axes are not comparable, so stack two subplots.
        plt.subplot(2, 1, 1)
        plt.plot(range(epochs_train), train_losses, label='Train Loss')
        plt.xlabel("Training Iterations")
        plt.ylabel("Loss")
        plt.title("Train Loss over Training Cycle")
        plt.subplot(2, 1, 2)
        plt.plot(range(epochs_val), valid_losses, label='Valid Loss')
        plt.xlabel(update_label)
        plt.ylabel("Loss")
        plt.title("Valid Loss over Training Cycle")
        plt.legend(loc='upper right')

    if save_file == "":
        # Save alongside the input: strip the '.p' suffix and use '.png' instead.
        out_path = file[:-2] + '.png'
    else:
        out_path = 'results/' + save_file + '.png'
        # savefig does not create missing directories, so make sure the target exists.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()
if __name__ == "__main__":
    # Stand-alone entry point: plot using the options parsed from the command line.
    generate_plot(file=args.file, save_file=args.save_file, title=args.title)