util.py
import numpy as np
import torch
from typing import Dict

DTYPE_BIT_SIZE: Dict[torch.dtype, int] = {
    torch.float32: 32,
    torch.float: 32,
    torch.float64: 64,
    torch.double: 64,
    torch.float16: 16,
    torch.half: 16,
    torch.bfloat16: 16,
    torch.complex32: 32,
    torch.complex64: 64,
    torch.complex128: 128,
    torch.cdouble: 128,
    torch.uint8: 8,
    torch.int8: 8,
    torch.int16: 16,
    torch.short: 16,
    torch.int32: 32,
    torch.int: 32,
    torch.int64: 64,
    torch.long: 64,
    torch.bool: 1
}


def to_coordinates_and_features(img):
    """Converts an image to a set of coordinates and features.

    Args:
        img (torch.Tensor): Shape (channels, height, width).
    """
    # Coordinates are indices of all non-zero locations of a tensor of ones
    # of the same shape as the spatial dimensions of the image
    coordinates = torch.ones(img.shape[1:]).nonzero(as_tuple=False).float()
    # Normalize coordinates to lie in [-.5, .5]. Note this divides by the
    # height (img.shape[1]) only, so for non-square images the width axis
    # spans a proportionally different range
    coordinates = coordinates / (img.shape[1] - 1) - 0.5
    # Convert to range [-1, 1]
    coordinates *= 2
    # Convert image to a tensor of features of shape (num_points, channels)
    features = img.reshape(img.shape[0], -1).T
    return coordinates, features
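
# Usage sketch (illustrative, not part of the original module): a 3x2x2 RGB
# image yields one (row, col) coordinate in [-1, 1] and one RGB feature
# vector per pixel.
#
#   img = torch.rand(3, 2, 2)
#   coordinates, features = to_coordinates_and_features(img)
#   coordinates.shape  # torch.Size([4, 2])
#   features.shape     # torch.Size([4, 3])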


def model_size_in_bits(model):
    """Calculate total number of bits to store `model` parameters and buffers."""
    return sum(sum(t.nelement() * DTYPE_BIT_SIZE[t.dtype] for t in tensors)
               for tensors in (model.parameters(), model.buffers()))
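
# Usage sketch (illustrative): a Linear(2, 3) layer has 2*3 weights + 3
# biases = 9 float32 parameters, i.e. 9 * 32 = 288 bits.
#
#   linear = torch.nn.Linear(2, 3)
#   model_size_in_bits(linear)  # 288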


def bpp(image, model):
    """Computes size in bits per pixel of model.

    Args:
        image (torch.Tensor): Image to be fitted by model.
        model (torch.nn.Module): Model used to fit image.
    """
    # Divide by 3 because of the RGB channels
    num_pixels = np.prod(image.shape) / 3
    return model_size_in_bits(model=model) / num_pixels
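
# Usage sketch (illustrative): reusing the 288-bit `linear` model above on a
# 3x4x4 image (16 pixels) gives 288 / 16 = 18 bits per pixel.
#
#   image = torch.rand(3, 4, 4)
#   bpp(image, linear)  # 18.0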


def psnr(img1, img2):
    """Calculates PSNR between two images.

    Args:
        img1 (torch.Tensor): First image, with values in [0, 1].
        img2 (torch.Tensor): Second image, with values in [0, 1].
    """
    return 20. * np.log10(1.) - 10. * (img1 - img2).detach().pow(2).mean().log10().to('cpu').item()
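
# Usage sketch (illustrative): a uniform error of 0.1 gives an MSE of 0.01,
# so PSNR = -10 * log10(0.01) = 20 dB (up to float rounding).
#
#   img = torch.rand(3, 4, 4) * 0.5
#   psnr(img, img + 0.1)  # ~20.0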


def clamp_image(img):
    """Clamp image values to lie in [0, 1] and quantize to 8-bit levels.

    Args:
        img (torch.Tensor): Image with values approximately in [0, 1].
    """
    # Values may lie outside [0, 1], so clamp input
    img_ = torch.clamp(img, 0., 1.)
    # Pixel values lie in {0, ..., 255}, so round float tensor
    return torch.round(img_ * 255) / 255.
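
# Usage sketch (illustrative): out-of-range values are clamped and in-range
# values are snapped to the nearest multiple of 1/255.
#
#   clamp_image(torch.tensor([-0.2, 0.5, 1.3]))
#   # tensor([0.0000, 0.5020, 1.0000])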


def get_clamped_psnr(img, img_recon):
    """Get PSNR between true image and reconstructed image. As the
    reconstructed image comes from the output of a neural net, ensure that
    values lie in [0, 1] and are quantized to 8-bit levels.

    Args:
        img (torch.Tensor): Ground truth image.
        img_recon (torch.Tensor): Image reconstructed by model.
    """
    return psnr(img, clamp_image(img_recon))


def mean(list_):
    """Returns the mean of a list of numbers."""
    return np.mean(list_)
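

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative, with a hypothetical untrained
    # coordinate MLP): a real use would first fit the model to the
    # (coordinates, features) pairs of the image.
    import torch.nn as nn

    img = torch.rand(3, 8, 8)
    coordinates, features = to_coordinates_and_features(img)
    model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3))
    # Reshape per-pixel predictions back to (channels, height, width)
    img_recon = model(coordinates).T.reshape(img.shape)
    print(f"model size: {model_size_in_bits(model)} bits")
    print(f"bpp: {bpp(img, model):.2f}")
    print(f"clamped PSNR: {get_clamped_psnr(img, img_recon):.2f} dB")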