-
Notifications
You must be signed in to change notification settings - Fork 0
/
spyr.py
287 lines (240 loc) · 10.3 KB
/
spyr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import scipy
import os
import math
import numpy as np
from numpy import pi
pi=np.pi
import matplotlib.pyplot as plt
import numpy.linalg
from scipy import ndimage
from scipy import misc
from sklearn.metrics import mean_squared_error
import scipy.stats as st
def log0(x, base=None, zero_val=None):
    """Logarithm that substitutes a fixed value wherever the input is not positive.

    Arguments:
    x -- scalar or array-like input
    base -- base of the logarithm; the natural logarithm is used when None
    zero_val -- value placed wherever x <= 0 (0.0 when None)
    Returns a float array with the same shape as x.
    """
    x = np.asarray(x)
    fill = 0.0 if zero_val is None else zero_val
    # The builtin float is used because the np.float alias no longer exists
    # in current NumPy releases.
    y = np.full(x.shape, fill, dtype=float)
    positive = x > 0
    y[positive] = np.log(x[positive])
    if base is not None:
        # Change of base: log_b(x) = ln(x) / ln(b).
        y[positive] /= np.log(base)
    return y
def log_raised_cos(r, ctrfreq, bandwidth):
    """Raised-cosine window evaluated on a log2 radial-frequency axis.

    Arguments:
    r -- array of radial frequencies
    ctrfreq -- frequency used to normalise r before taking the logarithm
    bandwidth -- scale factor of the window, in log2 units
    Returns an array shaped like r: sqrt(0.5 * (cos(rarg) + 1)) where the
    scaled log-frequency rarg lies strictly between -pi and pi, and 0 elsewhere.
    """
    # Non-positive frequencies are given log value `bandwidth` by log0.
    log_ratio = log0(pi / ctrfreq * r, 2, bandwidth)
    rarg = (pi / bandwidth) * log_ratio
    y = np.sqrt(0.5 * (np.cos(rarg) + 1))
    # Zero the response outside the single cosine lobe.
    y[rarg >= pi] = 0
    y[rarg <= -pi] = 0
    return y
def log_raised_coshi(r, ctrfreq, bandwidth):
    """High-pass counterpart of the log raised-cosine window.

    Arguments:
    r -- array of radial frequencies
    ctrfreq -- reference frequency; it is multiplied by 2**bandwidth before use
    bandwidth -- scale factor of the transition, in log2 units
    Returns an array shaped like r that is 1 where the scaled log-frequency
    rarg >= 0, 0 where rarg <= -pi, and sqrt(0.5 * (cos(rarg) + 1)) in between.
    """
    shifted_ctr = ctrfreq * math.pow(2, bandwidth)
    # Non-positive frequencies are given log value -pi by log0.
    log_ratio = log0(pi / shifted_ctr * r, 2, -pi)
    rarg = (pi / bandwidth) * log_ratio
    y = np.sqrt(0.5 * (np.cos(rarg) + 1))
    y[rarg >= 0] = 1      # pass band
    y[rarg <= -pi] = 0    # stop band
    return y
def log_raised_coslo(r, ctrfreq, bandwidth):
    """Low-pass counterpart of the log raised-cosine window.

    Arguments:
    r -- array of radial frequencies
    ctrfreq -- reference frequency; it is divided by 2**bandwidth before use
    bandwidth -- scale factor of the transition, in log2 units
    Returns an array shaped like r that is 1 where the scaled log-frequency
    rarg <= 0, 0 where rarg >= pi, and sqrt(0.5 * (cos(rarg) + 1)) in between.
    """
    shifted_ctr = ctrfreq / math.pow(2, bandwidth)
    # Non-positive frequencies are given log value 0 by log0, so they end up
    # in the pass band below.
    log_ratio = log0(pi / shifted_ctr * r, 2, 0)
    rarg = (pi / bandwidth) * log_ratio
    y = np.sqrt(0.5 * (np.cos(rarg) + 1))
    y[rarg >= pi] = 0    # stop band
    y[rarg <= 0] = 1     # pass band
    return y
def freqspace(dim):
    """Equivalent of Matlab freqspace: frequency spacing for a frequency response.

    Arguments:
    dim -- number of samples along the axis
    Returns a list of floats k / dim with k stepping by 2: from -dim up to
    dim - 2 when dim is even, and from 1 - dim up to dim - 1 when dim is odd.
    """
    if dim % 2 == 0:
        start, stop = -dim, dim - 1
    else:
        start, stop = -dim + 1, dim
    return [float(k) / float(dim) for k in range(start, stop, 2)]
def freqspace2(dim):
    """Array-returning variant of the Matlab-style freqspace helper.

    Arguments:
    dim -- number of samples along the axis
    Returns a NumPy array of k / dim with k stepping by 2 over
    [-dim, dim - 1) for even dim and [1 - dim, dim) for odd dim.
    """
    if dim % 2 == 0:
        lower, upper = -dim, dim - 1
    else:
        lower, upper = 1 - dim, dim
    numerators = np.asarray(range(lower, upper, 2))
    return numerators / dim
def make_steer_frs(dims, numlevels, numorientations, bandwidth):
    """Makes the frequency responses of the filters for a multiscale image transform.

    Arguments:
    dims -- image shape (rows, columns)
    numlevels -- number of levels/scales
    numorientations -- number of orientation subbands at each scale
    bandwidth -- spatial frequency bandwidth in octaves
    Returns a list that contains all of the frequency responses.
    list[0] contains the high frequency response,
    list[1] contains the band frequency responses as a complex array of shape
    (numlevels, numorientations, dims[0], dims[1]),
    and list[2] contains the low frequency response
    """
    p = numorientations - 1
    # Scalar gain applied to every oriented band.
    const = math.sqrt(
        float(math.pow(2, (2 * p)) * math.pow(math.factorial(p), 2))
        / float(math.factorial(2 * p) * (p + 1))
    )

    # np.meshgrid(x, y) returns arrays of shape (len(y), len(x)), so the
    # column axis (dims[1]) must be passed first for the grid to have the
    # image shape (dims[0], dims[1]).  Passing the axes the other way round
    # only works for square images.
    wx, wy = np.meshgrid(freqspace(dims[1]), freqspace(dims[0]))
    r = np.sqrt(wx**2 + wy**2)
    theta = np.arctan2(wy, wx)

    # Every band starts as the constant const*1j and is scaled in place by
    # its real-valued angular * radial profile.
    bands = np.full((numlevels, numorientations, dims[0], dims[1]), const * 1j)
    for level in range(numlevels):
        # The radial window depends on the level only, so compute it once.
        ctrfreq = pi / math.pow(2, (level + 1) * bandwidth)
        radial = log_raised_cos(r, ctrfreq, bandwidth)
        for orientation in range(numorientations):
            theta_offset = orientation * np.pi / numorientations
            band = np.cos(theta - theta_offset)**p * radial
            bands[level, orientation, :, :] *= band

    hi = log_raised_coshi(r, pi / math.pow(2, bandwidth), bandwidth)
    lo = log_raised_coslo(r, pi / math.pow(2, bandwidth * numlevels), bandwidth)
    return [hi, bands, lo]
def est_maxlevel(dims, bandwidth):
    """Estimate max level for the steerable pyramid.

    Arguments:
    dims -- image shape; only its smallest entry is used
    bandwidth -- spatial frequency bandwidth in octaves
    Returns floor((log2(min(dims)) - 2) / bandwidth) as an int.
    """
    smallest = min(dims)
    octaves = math.log(smallest) / math.log(2) - 2
    return int(math.floor(octaves / bandwidth))
def build_steer_bands(im, freq_resps, numlevels, numorientations):
    """Builds the subbands of a multiscale image transform.

    Arguments:
    im -- a grayscale image (2-D array)
    freq_resps -- filter frequency responses returned by make_steer_frs
    numlevels -- number of levels/scales
    numorientations -- number of orientation subbands at each scale
    Returns a list that contains all of the subbands.
    list[0] contains the high band,
    list[1] contains the bands in the form (numlevels, numorientations, dims),
    and list[2] contains the low band
    """
    dims = im.shape
    # Centre the spectrum so that it can be multiplied element-wise with the
    # frequency responses.
    fourier = np.fft.fftshift(np.fft.fft2(im))

    def apply_response(freq_resp):
        # ifftshift is the exact inverse of fftshift.  Applying fftshift a
        # second time only undoes the shift when every dimension is even and
        # misplaces the spectrum by one sample for odd sizes.
        filtered = np.multiply(fourier, freq_resp)
        return np.fft.ifft2(np.fft.ifftshift(filtered)).real

    hi = apply_response(freq_resps[0])
    lo = apply_response(freq_resps[2])

    freq_resp_bands = freq_resps[1]
    bands = [
        apply_response(freq_resp_bands[i][j])
        for i in range(numlevels)
        for j in range(numorientations)
    ]
    bands = np.reshape(bands, [numlevels, numorientations, dims[0], dims[1]])
    return [hi, bands, lo]
def recon_steer_bands(pyr, freq_resps, numlevels, numorientations):
    """Reconstructs an image from the subband transform.

    Each subband is filtered with the complex conjugate of its frequency
    response and the real parts of the results are summed.

    Arguments:
    pyr -- the image transform: [high band, bands, low band]
    freq_resps -- filter frequency responses laid out as [high, bands, low]
    numlevels -- number of levels/scales
    numorientations -- number of orientation subbands at each scale
    Returns the reconstructed image
    """
    def apply_conjugate(subband, freq_resp):
        # Centre the subband spectrum, filter it, and undo the centring.
        # ifftshift (not a second fftshift) is needed so that odd-sized
        # dimensions are restored correctly.
        spectrum = np.fft.fftshift(np.fft.fft2(subband))
        filtered = np.multiply(spectrum, np.conjugate(freq_resp))
        return np.fft.ifft2(np.fft.ifftshift(filtered)).real

    result_hi = apply_conjugate(pyr[0], freq_resps[0])
    result_lo = apply_conjugate(pyr[2], freq_resps[2])

    freq_resp_band = freq_resps[1]
    pyr_band = pyr[1]
    result_bands = np.zeros(pyr[0].shape)
    for i in range(numlevels):
        for j in range(numorientations):
            result_bands = result_bands + apply_conjugate(pyr_band[i][j], freq_resp_band[i][j])

    return result_bands + result_hi + result_lo
def make_quad_frs_imag(dims, numlevels, numorientations, bandwidth):
    """Makes imaginary frequency responses for the quadrature pairs of "make_steer_frs".

    The band responses are those of make_steer_frs; the high and low
    responses are replaced by all-zero arrays.

    Arguments:
    dims -- image shape
    numlevels -- number of levels/scales
    numorientations -- number of orientation subbands at each scale
    bandwidth -- spatial frequency bandwidth in octaves
    Returns a list that contains the imaginary part of the quadrature pair.
    list[0] contains the (zero) high frequency response,
    list[1] contains the band frequency responses in the form (numlevel, numorientations, dims),
    and list[2] contains the (zero) low frequency response
    """
    steer_frs = make_steer_frs(dims, numlevels, numorientations, bandwidth)
    band_resps = steer_frs[1]
    return [np.zeros(dims), band_resps, np.zeros(dims)]
def make_quad_frs_real(dims, numlevels, numorientations, bandwidth):
    """Makes real frequency responses for the quadrature pairs of "make_steer_frs".

    The high and low responses are those of make_steer_frs; the band
    responses are replaced by their element-wise magnitude.

    Arguments:
    dims -- image shape
    numlevels -- number of levels/scales
    numorientations -- number of orientation subbands at each scale
    bandwidth -- spatial frequency bandwidth in octaves
    Returns a list that contains the real part of the quadrature pair.
    list[0] contains the high frequency response,
    list[1] contains the band frequency responses in the form (numlevel, numorientations, dims),
    and list[2] contains the low frequency response
    """
    steer_frs = make_steer_frs(dims, numlevels, numorientations, bandwidth)
    band_magnitudes = abs(steer_frs[1])
    steer_frs[1] = band_magnitudes
    return steer_frs
def build_quad_bands(im, freq_resps_imag, freq_resps_real, numlevels, numorientations):
    """Builds quadrature pair multiscale subbands.

    The image is transformed once with each set of responses, and the two
    transforms are combined part by part as real + 1j * imag.

    Arguments:
    im -- grayscale image
    freq_resps_imag, freq_resps_real -- filter frequency responses returned by make_quad_frs_imag, make_quad_frs_real
    numlevels -- number of levels/scales
    numorientations -- number of orientation subbands at each scale
    Returns a 1-D object array that contains the quadrature pair multiscale subbands.
    array[0] contains the high band,
    array[1] contains the bands in the form (numlevel, numorientations, dims),
    and array[2] contains the low band
    """
    pyr_imag = build_steer_bands(im, freq_resps_imag, numlevels, numorientations)
    pyr_real = build_steer_bands(im, freq_resps_real, numlevels, numorientations)

    # The high/low bands and the band stack have different shapes, so the
    # parts cannot be turned into one regular array.  Combine them one at a
    # time and store them in an explicit object array instead of relying on
    # NumPy to build a ragged array implicitly, which newer releases reject.
    pyr = np.empty(len(pyr_real), dtype=object)
    for idx, (real_part, imag_part) in enumerate(zip(pyr_real, pyr_imag)):
        pyr[idx] = real_part + 1j * imag_part
    return pyr
def view_abs_spyr_images(to_plot, numlevels, numorientations):
    """Display the magnitude of every subband, each in its own grayscale figure.

    Arguments:
    to_plot -- transform laid out as [high band, bands, low band]
    numlevels -- number of levels/scales to display
    numorientations -- number of orientation subbands to display per level
    Figures are created in the order: high band, each (level, orientation)
    band, low band.
    """
    hi_mag = abs(to_plot[0])
    band_mags = abs(to_plot[1])
    lo_mag = abs(to_plot[2])

    images = [hi_mag]
    images.extend(
        band_mags[level][orientation]
        for level in range(numlevels)
        for orientation in range(numorientations)
    )
    images.append(lo_mag)

    for image in images:
        plt.figure()
        plt.gray()
        plt.imshow(image)
def view_real_imag_spyr_images(to_plot, numlevels, numorientations):
    """Display real and imaginary parts of a complex-valued transform.

    Arguments:
    to_plot -- transform laid out as [high band, bands, low band]
    numlevels -- number of levels/scales to display
    numorientations -- number of orientation subbands to display per level
    The real part of the high band and of the low band are each shown in a
    single figure; every (level, orientation) band is shown as a pair of
    side-by-side axes (real part left, imaginary part right).
    """
    def show_gray(image):
        # One new grayscale figure per image.
        plt.figure()
        plt.gray()
        plt.imshow(image)

    show_gray(to_plot[0].real)

    bands = to_plot[1]
    for level in range(numlevels):
        for orientation in range(numorientations):
            band = bands[level][orientation]
            # NOTE(review): plt.subplots below opens its own figure, so this
            # plt.figure() call appears to leave an extra empty figure per
            # band -- confirm whether that is intended.
            plt.figure()
            plt.gray()
            _, (real_ax, imag_ax) = plt.subplots(1, 2, sharey=True)
            real_ax.imshow(band.real)
            imag_ax.imshow(band.imag)

    show_gray(to_plot[2].real)
def view_spyr_images(to_plot, numlevels, numorientations):
    """Display every subband of a transform, each in its own grayscale figure.

    Arguments:
    to_plot -- transform laid out as [high band, bands, low band]
    numlevels -- number of levels/scales to display
    numorientations -- number of orientation subbands to display per level
    Figures are created in the order: high band, each (level, orientation)
    band, low band.
    """
    def show_gray(image):
        # One new grayscale figure per image.
        plt.figure()
        plt.gray()
        plt.imshow(image)

    hi_band, oriented_bands, lo_band = to_plot[0], to_plot[1], to_plot[2]

    show_gray(hi_band)
    for level in range(numlevels):
        for orientation in range(numorientations):
            show_gray(oriented_bands[level][orientation])
    show_gray(lo_band)