forked from YangtaoWANG95/TokenCut
-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualizations.py
executable file
·72 lines (67 loc) · 2.19 KB
/
visualizations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Vis utilities. Code adapted from LOST: https://github.com/valeoai/LOST
"""
import cv2
import torch
import skimage.io
import numpy as np
import torch.nn as nn
from PIL import Image
import scipy
import matplotlib.pyplot as plt
def visualize_img(image, vis_folder, im_name):
    """
    Save the given image array to disk as `<vis_folder>/<im_name>`.

    The array is converted with `Image.fromarray`, written to the target
    path, and the path is reported on stdout. Nothing is returned.
    """
    pltname = f"{vis_folder}/{im_name}"
    pil_image = Image.fromarray(image)
    pil_image.save(pltname)
    print(f"Original image saved at {pltname}.")
def visualize_predictions(img, pred, vis_folder, im_name, save=True):
    """
    Draw the predicted box on a copy of the image and optionally save it.

    `pred` is indexed as (pred[0], pred[1]) for one corner and
    (pred[2], pred[3]) for the opposite corner; each value is truncated to
    int before drawing. The input `img` is left untouched.

    When `save` is true the annotated copy is written to
    `<vis_folder>/<im_name>_TokenCut_pred.jpg` and the path is printed.
    The annotated copy is returned in every case.
    """
    annotated = np.copy(img)

    # Corner coordinates of the predicted box, as integers for OpenCV.
    x_min, y_min, x_max, y_max = (int(pred[k]) for k in range(4))
    box_color = (255, 0, 0)
    box_thickness = 3
    cv2.rectangle(annotated, (x_min, y_min), (x_max, y_max), box_color, box_thickness)

    if not save:
        return annotated

    pltname = f"{vis_folder}/{im_name}_TokenCut_pred.jpg"
    Image.fromarray(annotated).save(pltname)
    print(f"Predictions saved at {pltname}.")
    return annotated
def visualize_predictions_gt(img, pred, gt, vis_folder, im_name, dim, scales, save=True):
    """
    Visualization of the predicted box together with the ground-truth boxes.

    Args:
        img: image array; it is copied, so the caller's array is not modified.
        pred: predicted box, indexed as (pred[0], pred[1]) for one corner and
            (pred[2], pred[3]) for the opposite corner.
        gt: iterable of ground-truth boxes, each indexed like `pred`. May be
            empty or None, in which case only the prediction is drawn.
        vis_folder: folder the output image is written to.
        im_name: prefix of the output file name.
        dim: unused; kept for interface compatibility.
        scales: unused; kept for interface compatibility.
        save: when true, write the annotated image to
            `<vis_folder>/<im_name>_TokenCut_BBOX.jpg`.

    Returns:
        The annotated copy of `img`.
    """
    image = np.copy(img)

    # Plot the predicted box with colour (255, 0, 0).
    cv2.rectangle(
        image,
        (int(pred[0]), int(pred[1])),
        (int(pred[2]), int(pred[3])),
        (255, 0, 0), 3,
    )

    # Plot every ground-truth box with colour (0, 0, 255).
    # The original guard was `len(gt>1)`: on arrays it only tested that gt was
    # non-empty, and on plain lists it raised TypeError. Iterating directly
    # covers the empty case and works for any sequence of boxes.
    if gt is not None:
        for box in gt:
            cv2.rectangle(
                image,
                (int(box[0]), int(box[1])),
                (int(box[2]), int(box[3])),
                (0, 0, 255), 3,
            )

    if save:
        pltname = f"{vis_folder}/{im_name}_TokenCut_BBOX.jpg"
        Image.fromarray(image).save(pltname)
    return image
def visualize_eigvec(eigvec, vis_folder, im_name, dim, scales, save=True):
    """
    Visualization of the second smallest eigenvector.

    The eigenvector map is resized by `scales` with `scipy.ndimage.zoom`
    (order=0, mode='nearest') and written with the 'cividis' colormap to
    `<vis_folder>/<im_name>_TokenCut_attn.jpg`.

    Args:
        eigvec: array holding the eigenvector values to render.
        vis_folder: folder the output image is written to.
        im_name: prefix of the output file name.
        dim: unused; kept for interface compatibility.
        scales: zoom factor(s) forwarded to `scipy.ndimage.zoom`.
        save: when false the function does nothing.

    Returns:
        None.
    """
    # The resized map is only ever used for saving, so skip the work otherwise.
    if not save:
        return

    # Import the submodule explicitly: a bare `import scipy` does not load
    # `scipy.ndimage` on older SciPy releases.
    from scipy import ndimage

    eigvec = ndimage.zoom(eigvec, scales, order=0, mode='nearest')
    pltname = f"{vis_folder}/{im_name}_TokenCut_attn.jpg"
    plt.imsave(fname=pltname, arr=eigvec, cmap='cividis')
    print(f"Eigen attention saved at {pltname}.")