-
Notifications
You must be signed in to change notification settings - Fork 129
/
utils.py
124 lines (95 loc) · 4.05 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
def batch_transform(batch, transform):
    """Apply ``transform`` independently to every sample of ``batch``.

    Keyword arguments:
    - batch (``Tensor``): a batch of samples; it is split along dimension 0.
    - transform (callable): a function/transform applied to each individual
    sample. It must return a tensor, and all returned tensors must have the
    same shape so they can be stacked.

    Returns:
    A tensor obtained by stacking the transformed samples along a new
    dimension 0, in their original order.
    """
    # torch.unbind splits the batch into a tuple of slices along dim 0;
    # map applies the transform slice by slice, and torch.stack re-assembles
    # the results into a single batch tensor.
    per_sample = map(transform, torch.unbind(batch, dim=0))
    return torch.stack(list(per_sample))
def imshow_batch(images, labels):
    """Display ``images`` and ``labels`` as two vertically stacked image grids.

    The top subplot shows a grid built from ``images`` and the bottom subplot
    a grid built from ``labels``; the figure is then displayed with
    ``plt.show()``.

    Keyword arguments:
    - images (``Tensor``): a 4D mini-batch tensor of shape (B, C, H, W).
    - labels (``Tensor``): a 4D mini-batch tensor of shape (B, C, H, W).
    """
    # Tile each batch into a single channel-first grid. ``.detach().cpu()``
    # lets tensors that live on a GPU, or that track gradients, be converted
    # to NumPy; ``Tensor.numpy()`` raises on both.
    images = torchvision.utils.make_grid(images).detach().cpu().numpy()
    labels = torchvision.utils.make_grid(labels).detach().cpu().numpy()

    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 7))
    # matplotlib's imshow expects channel-last arrays, so move the channel
    # axis from position 0 to position 2.
    ax1.imshow(np.transpose(images, (1, 2, 0)))
    ax2.imshow(np.transpose(labels, (1, 2, 0)))
    plt.show()
def save_checkpoint(model, optimizer, epoch, miou, args):
    """Save a training checkpoint and a text summary to ``args.save_dir``.

    Two files are written:
    - ``<save_dir>/<name>``: a ``torch.save`` checkpoint dict with the keys
      ``'epoch'``, ``'miou'``, ``'state_dict'`` (the model state) and
      ``'optimizer'`` (the optimizer state).
    - ``<save_dir>/<name>_summary.txt``: every attribute of ``args`` sorted
      by name, followed by the given epoch and mean IoU under a
      "BEST VALIDATION" heading.

    Keyword arguments:
    - model (``nn.Module``): The model whose state is saved.
    - optimizer (``torch.optim``): The optimizer whose state is saved.
    - epoch (``int``): The current epoch for the model.
    - miou (``float``): The mean IoU obtained by the model.
    - args: An object with ``name`` and ``save_dir`` attributes that also
      supports ``vars()`` — presumably the ``argparse.Namespace`` returned by
      ``parse_args()`` with the arguments used to train ``model``.

    Raises:
    - AssertionError: if ``args.save_dir`` is not an existing directory.
    """
    name = args.name
    save_dir = args.save_dir
    # Explicit check rather than ``assert`` so it still runs under
    # ``python -O``; AssertionError is kept so callers see the same type.
    if not os.path.isdir(save_dir):
        raise AssertionError(
            "The directory \"{0}\" doesn't exist.".format(save_dir))

    # Save model and optimizer state together with the training progress.
    model_path = os.path.join(save_dir, name)
    checkpoint = {
        'epoch': epoch,
        'miou': miou,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, model_path)

    # Save a human-readable summary of the arguments and the best result.
    summary_filename = os.path.join(save_dir, name + '_summary.txt')
    with open(summary_filename, 'w', encoding='utf-8') as summary_file:
        summary_file.write("ARGUMENTS\n")
        for arg in sorted(vars(args)):
            summary_file.write("{0}: {1}\n".format(arg, getattr(args, arg)))
        summary_file.write("\nBEST VALIDATION\n")
        summary_file.write("Epoch: {0}\n".format(epoch))
        summary_file.write("Mean IoU: {0}\n".format(miou))
def load_checkpoint(model, optimizer, folder_dir, filename, map_location=None):
    """Load a stored checkpoint into ``model`` and ``optimizer``.

    The file ``<folder_dir>/<filename>`` is read with ``torch.load`` and must
    hold a dict with the keys ``'state_dict'``, ``'optimizer'``, ``'epoch'``
    and ``'miou'`` (the layout ``save_checkpoint`` in this module writes).

    Keyword arguments:
    - model (``nn.Module``): The stored model state is copied to this model
      instance.
    - optimizer (``torch.optim``): The stored optimizer state is copied to
      this optimizer instance.
    - folder_dir (``string``): The path to the folder where the saved model
      state is located.
    - filename (``string``): The model filename.
    - map_location (optional): Forwarded to ``torch.load`` to remap tensor
      storage, e.g. ``'cpu'`` to load a GPU checkpoint on a CPU-only machine.
      Defaults to ``None``, which keeps ``torch.load``'s default behavior.

    Returns:
    The tuple ``(model, optimizer, epoch, miou)``. ``model`` and
    ``optimizer`` are the same instances that were passed in, with their
    state loaded in place; ``epoch`` and ``miou`` are the stored values.

    Raises:
    - AssertionError: if ``folder_dir`` is not a directory or the checkpoint
      file does not exist.
    """
    # Explicit checks rather than ``assert`` so they still run under
    # ``python -O``; AssertionError is kept so callers see the same type.
    if not os.path.isdir(folder_dir):
        raise AssertionError(
            "The directory \"{0}\" doesn't exist.".format(folder_dir))

    model_path = os.path.join(folder_dir, filename)
    if not os.path.isfile(model_path):
        raise AssertionError(
            "The model file \"{0}\" doesn't exist.".format(filename))

    # NOTE: torch.load is pickle-based; only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(model_path, map_location=map_location)

    # Copy the stored parameters/state into the caller's instances.
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    miou = checkpoint['miou']

    return model, optimizer, epoch, miou