-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutil.py
127 lines (102 loc) · 3.19 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import math
from datetime import datetime
import numpy as np
from PIL import Image
import cv2
####################
# miscellaneous
####################
def get_timestamp():
    '''Return the current local time formatted as ``YYMMDD-HHMMSS``.'''
    stamp_format = '%y%m%d-%H%M%S'
    now = datetime.now()
    return now.strftime(stamp_format)
def mkdir(path):
    '''Create directory ``path`` (including missing parents) unless something exists there.

    If any entry already exists at ``path`` (directory, file, or link), it is left
    untouched and no error is raised.

    Raises:
        OSError: for failures other than "already exists" (e.g. permission denied,
            or a parent component that is not a directory).
    '''
    try:
        os.makedirs(path)
    except FileExistsError:
        # Attempt the creation directly instead of checking os.path.exists() first:
        # another process may create the path between the check and the makedirs
        # call, which would otherwise raise here.
        pass
def mkdirs(paths):
    '''Create one directory, or every directory in an iterable of paths, via mkdir().

    A single ``str`` is treated as one path rather than iterated character by character.
    '''
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
def mkdir_and_rename(path):
    '''Create a fresh directory at ``path``, archiving whatever was there first.

    If ``path`` already exists, it is renamed to
    ``<path>_archived_<get_timestamp()>`` and a warning is printed before the new
    directory is created. Missing parent directories are created as needed.
    '''
    if not os.path.exists(path):
        os.makedirs(path)
        return
    archived = path + '_archived_' + get_timestamp()
    print('[Warning] Path [%s] already exists. Rename it to [%s]' % (path, archived))
    os.rename(path, archived)
    os.makedirs(path)
####################
# image convert
####################
def Tensor2np(tensor_list, rgb_range):
    '''Convert each tensor in ``tensor_list`` to a ``uint8`` numpy array.

    Every tensor is passed through quantize(tensor, rgb_range), converted with
    ``.numpy()`` (which needs a CPU tensor that does not require grad), then
    transposed with axes (1, 2, 0). That transpose requires 3-D tensors and
    presumably turns (C, H, W) into (H, W, C); confirm against callers.

    Returns:
        A list of arrays, in the same order as ``tensor_list``.
    '''
    arrays = []
    for tensor in tensor_list:
        quantized = quantize(tensor, rgb_range).numpy()
        arrays.append(quantized.transpose(1, 2, 0).astype(np.uint8))
    return arrays
def rgb2ycbcr(img, only_y=True):
    '''Convert an RGB image to YCbCr; same as matlab rgb2ycbcr.

    Args:
        img: numpy array whose last axis holds the R, G, B channels.
            uint8 input is expected in [0, 255]; float input in [0, 1].
            The caller's array is never modified.
        only_y: if True, return only the Y channel (the last axis is dropped).

    Returns:
        Array with the same dtype as ``img``. For in-range input, Y spans
        16..235 and Cb/Cr span 16..240 on the 8-bit scale; float outputs are
        that scale divided by 255.
    '''
    in_img_type = img.dtype
    # Work on a float64 copy. The scaling below is in place and must not touch the
    # caller's array; integer inputs also cannot take an in-place float multiply.
    img = img.astype(np.float64)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    '''Convert a YCbCr image back to RGB; same as matlab ycbcr2rgb.

    Args:
        img: numpy array whose last axis holds the Y, Cb, Cr channels.
            uint8 input is expected in [0, 255]; float input in [0, 1].
            The caller's array is never modified.

    Returns:
        Array with the same dtype as ``img``. uint8 results are saturated to
        [0, 255]; float results are not clipped.
    '''
    in_img_type = img.dtype
    # Work on a float64 copy. The scaling below is in place and must not touch the
    # caller's array; integer inputs also cannot take an in-place float multiply.
    img = img.astype(np.float64)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        # Out-of-gamut YCbCr can map below 0 or above 255. Saturate, because casting
        # such floats to uint8 would wrap around instead.
        rlt = np.clip(rlt.round(), 0, 255)
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def save_img_np(img_np, img_path, mode='RGB'):
    '''Write the numpy image ``img_np`` to ``img_path`` through PIL.

    A 2-D array is always saved with PIL mode 'L', regardless of ``mode``.
    Any other array uses ``mode`` as given (default 'RGB').
    '''
    pil_mode = 'L' if img_np.ndim == 2 else mode
    Image.fromarray(img_np, mode=pil_mode).save(img_path)
def quantize(img, rgb_range):
    '''Rescale a tensor to the 0..255 pixel scale and round it.

    The tensor is assumed to hold values in [0, rgb_range] (verify with callers).
    It is multiplied by ``255 / rgb_range``, clamped to [0, 255] and rounded.
    No integer cast is applied here, so the result stays a floating tensor on
    the 0..255 scale.
    '''
    scaled = img * (255. / rgb_range)
    return scaled.clamp(0, 255).round()
####################
# metric
####################
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    '''Compute PSNR in dB between ``sr`` and ``hr`` after shaving a border.

    Args:
        sr, hr: tensors indexed as (N, C, H, W). Values are presumably on the
            [0, rgb_range] scale; confirm with callers.
        scale: upscaling factor that sets how many border pixels are ignored.
        rgb_range: peak value. The difference is divided by it so the peak is 1.
        benchmark: if True, shave ``scale`` pixels per side and, for
            multi-channel input (assumed 3 channels), measure a weighted Y-style
            difference. If False, shave ``scale + 6`` pixels per side and use
            every channel.

    Returns:
        PSNR as a Python float. Returns ``inf`` when the measured region matches
        exactly, and NaN when the shaved region is empty.
    '''
    diff = (sr - hr).detach().div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            # Y weights over 256, e.g. 65.738/256 ~= 65.481/255 as used by rgb2ycbcr.
            convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
            diff = diff.mul(convert).div(256).sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
    # Slice with explicit end indices: `shave:-shave` with shave == 0 would yield
    # the empty slice `0:0` instead of the full image.
    height, width = diff.shape[-2:]
    valid = diff[:, :, shave:height - shave, shave:width - shave]
    mse = valid.pow(2).mean().item()
    if mse == 0:
        # Identical images: PSNR is unbounded, and math.log10(0) would raise.
        return float('inf')
    return -10 * math.log10(mse)