-
Notifications
You must be signed in to change notification settings - Fork 0
/
log_utils.py
100 lines (91 loc) · 4.2 KB
/
log_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.colors import Normalize
import matplotlib.cm as cm
import yaml
class LogWriter:
    """Write per-phase metrics and diagnostic figures to TensorBoard.

    One ``SummaryWriter`` is created per phase, under ``<log_dir_name>/train``
    and ``<log_dir_name>/val``. The ``phase`` argument of the logging methods
    selects the writer. Call :meth:`close`, or use the instance as a context
    manager, so that buffered events are flushed to disk.
    """

    def __init__(self, log_dir_name):
        """Create the per-phase TensorBoard writers.

        Args:
            log_dir_name: Root directory for this run's logs. The
                ``configs.yaml`` file from :meth:`log_configuration` is also
                written here.
        """
        self.writer = {
            "train": SummaryWriter(os.path.join(log_dir_name, "train")),
            "val": SummaryWriter(os.path.join(log_dir_name, "val")),
        }
        self.log_dir_name = log_dir_name

    def close(self):
        """Flush and close every underlying ``SummaryWriter``."""
        for writer in self.writer.values():
            writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions from the ``with`` body

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, move it to the CPU, and return it as a NumPy array."""
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _mid_slice(volume):
        """Return the central slice of ``volume`` along its first axis."""
        return volume[volume.shape[0] // 2]

    def loss_per_epoch(self, loss_type, loss_arr, phase, epoch):
        """Log the epoch mean of ``loss_arr`` as a scalar and print it.

        Args:
            loss_type: Tag prefix; the scalar is logged as
                ``'<loss_type>/per_epoch'``.
            loss_arr: Per-step loss values for the epoch, averaged with
                ``np.mean``.
            phase: Key into ``self.writer`` (``"train"`` or ``"val"``).
            epoch: Epoch index, used as the TensorBoard global step.
        """
        loss = np.mean(loss_arr)
        self.writer[phase].add_scalar(loss_type + '/per_epoch', loss, epoch)
        print(f'{phase} / epoch {epoch:03d} / {loss_type} = {loss:.5f}')

    def plot_per_epoch(self, data, prediction, ncc_loss,
                       mind_loss, dice_loss,
                       grad_loss,
                       loss, phase, epoch):
        """Render a 7-panel figure for one sample and log it to TensorBoard.

        The panels are the central first-axis slices of:

        - the moving image and label (index 0 of ``data``);
        - the fixed image and label (index 1 of ``data``);
        - the warped label and image from ``prediction``;

        plus a subsampled quiver plot of the predicted flow. The figure is
        logged under the tag ``'<phase> / epoch <epoch:03d>'``.

        Args:
            data: Mapping with ``"image"`` and ``"label"`` entries. Each of
                ``[0][0]`` and ``[1][0]`` must yield a tensor whose first
                axis is sliced (expected to be a 3-D volume).
                NOTE(review): presumably indexed as [pair index][batch index];
                confirm against the data loader.
            prediction: Mapping with ``"image"`` and ``"label"`` entries
                (``[0][0]`` is sliced as above) and a ``"flow"`` entry.
                ``["flow"][0]`` must be 4-D with the component axis first.
                It is moved to the last axis, and components ``1:`` of a
                2x-subsampled middle slice are plotted.
            ncc_loss, mind_loss, dice_loss, grad_loss, loss: Values shown
                in the figure title, formatted with ``.2f``.
            phase: Key into ``self.writer``.
            epoch: Epoch index, embedded in the figure tag.
        """
        to_numpy = self._to_numpy
        panels = [
            ("moving image", to_numpy(data["image"][0][0])),
            ("moving label", to_numpy(data["label"][0][0])),
            ("fixed image", to_numpy(data["image"][1][0])),
            ("fixed label", to_numpy(data["label"][1][0])),
            ("warped label", to_numpy(prediction["label"][0][0])),
            ("warped image", to_numpy(prediction["image"][0][0])),
        ]
        # Move the component axis from first to last: (C, ...) -> (..., C).
        pred_flow = to_numpy(prediction["flow"][0]).transpose(1, 2, 3, 0)

        fig = plt.figure(figsize=(12, 4), dpi=180, facecolor='w', edgecolor='k')
        fig.suptitle(f'loss: {loss:.2f}, '
                     f'ncc_loss: {ncc_loss:.2f}, '
                     f'mind_loss: {mind_loss:.2f}, '
                     f'dice_loss: {dice_loss:.2f}, '
                     f'grad_loss: {grad_loss:.2f}')

        num_plots = len(panels) + 1  # image/label panels + the flow panel
        for position, (title, volume) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, num_plots, position)
            ax.set_title(title)
            # Pass the colormap per call rather than via plt.set_cmap. That
            # call mutated the global default and (through plt.gcf) could
            # create a stray, never-closed figure.
            ax.imshow(self._mid_slice(volume), cmap='gray')

        ax = fig.add_subplot(1, num_plots, num_plots)
        # plot_flow sets this panel's title itself.
        self.plot_flow(ax, self._mid_slice(pred_flow)[::2, ::2, 1:])

        # add_figure closes the figure after rendering (its default close=True).
        self.writer[phase].add_figure(f'{phase} / epoch {epoch:03d}', fig)

    def plot_flow(self, ax, pred_flow, img_indexing=True, quiver_width=None, scale=1):
        """Draw a displacement field on ``ax`` as a quiver plot colored by direction.

        Args:
            ax: Matplotlib axes to draw into.
            pred_flow: Array whose last axis holds the arrow components.
                Channel 0 is used as ``u`` and channel 1 as ``v``.
                ``plot_per_epoch`` passes a 2-D grid of shape (H, W, 2).
            img_indexing: If True, reverse the row order (``np.flipud``) so
                row 0 is drawn at the top, as in ``imshow``. The ``v``
                component itself is not negated.
            quiver_width: Shaft width forwarded to ``ax.quiver``
                (``None`` = matplotlib default).
            scale: Forwarded to ``ax.quiver``. Arrows use ``units='xy'``
                and ``angles='xy'``.
        """
        if img_indexing:
            pred_flow = np.flipud(pred_flow)
        ax.set_title("pred flow")
        u, v = pred_flow[..., 0], pred_flow[..., 1]
        # Direction angle per arrow. Only relative values matter, because
        # the angles are min-max normalized before colormapping.
        angles = np.arctan2(u, v)
        angles[np.isnan(angles)] = 0
        norm = Normalize()
        norm.autoscale(angles)
        ax.quiver(u, v,
                  color=cm.winter(norm(angles).flatten()),
                  angles='xy',
                  units='xy',
                  width=quiver_width,
                  scale=scale)
        ax.axis('equal')
        ax.axis('off')

    def time_per_epoch(self, epoch_step_time, phase, epoch):
        """Print the mean step time and total epoch time.

        Args:
            epoch_step_time: Per-step durations in seconds.
            phase: Phase name, used only in the printed message.
            epoch: Epoch index, used only in the printed message.
        """
        step_time = np.mean(epoch_step_time)
        epoch_time = np.sum(epoch_step_time)
        print(f'{phase} / epoch {epoch:03d} / {step_time:0.4f} sec/step / {epoch_time:0.4f} sec/epoch')

    def log_configuration(self, **kwargs):
        """Dump the given keyword arguments to ``<log_dir_name>/configs.yaml``.

        Any existing file at that path is overwritten.
        """
        os.makedirs(self.log_dir_name, exist_ok=True)
        config_path = os.path.join(self.log_dir_name, "configs.yaml")
        with open(config_path, 'w', encoding='utf-8') as yamlfile:
            yaml.dump(kwargs, yamlfile)