-
Notifications
You must be signed in to change notification settings - Fork 0
/
fit_gabor.py
93 lines (83 loc) · 3.63 KB
/
fit_gabor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Fit Gabor wavelets to the atoms of a trained dictionary. Requires the MATLAB Engine API for Python (Python 3.6 or 3.7)
and, in MATLAB, the Curve Fitting Toolbox and the Image Processing Toolbox.
@Filename fit_gabor.py
@Author Kion
@Created 5/31/20
"""
import argparse
import glob
import io
import os
import matlab.engine
import numpy as np
# Command-line interface. Parsed at import time because the rest of the script reads
# `args` and `file_list` as module-level globals.
parser = argparse.ArgumentParser(description='Fit Gabor wavelet to trained dictionaries')
parser.add_argument('-s', '--sweep', action='store_true', help="Sweep all files in ./results")
parser.add_argument('-i', '--input', help="Specify input file")
parser.add_argument('-r', '--runs', default=100, type=int, help="Number of runs for each wavelet fit")
args = parser.parse_args()
if args.sweep:
    if args.input is not None:
        # --sweep takes precedence; say so instead of silently dropping --input
        print("Both --sweep and --input given; ignoring --input {}".format(args.input))
    # glob order is filesystem-dependent; sort so files are processed reproducibly
    file_list = sorted(glob.glob("./results/*.npz"))
    if not file_list:
        print("No .npz files found in ./results")
elif args.input is not None:
    file_list = [args.input]
else:
    raise ValueError("If not sweeping (--sweep) must specify input (--input <FILE>).")
# (output .npz key, field name in the MATLAB fit struct returned by fit2dGabor), in save order.
_FIT_FIELDS = (
    ('a', 'a'),
    ('b', 'b'),
    ('x0', 'x0'),
    ('y0', 'y0'),
    ('sigmax', 'sigmax'),
    ('sigmay', 'sigmay'),
    ('theta', 'theta'),
    ('phi', 'phi'),
    ('Lambda', 'lambda'),
    ('phase', 'phase'),
)


def _output_path(file_path):
    """Return the path the fit for *file_path* is saved to, or None if none can be derived.

    The output name swaps 'traindata' for 'wavelet_fit'. If the input name has no
    'traindata', that swap is a no-op and the output would overwrite the input, so
    None is returned. np.savez_compressed appends '.npz' to names lacking it, so the
    suffix is added here too, so that the "already fitted" check tests the real file.
    """
    save_path = file_path.replace('traindata', 'wavelet_fit')
    if save_path == file_path:
        return None
    if not save_path.endswith('.npz'):
        save_path += '.npz'
    return save_path


def _load_final_dictionary(file_path):
    """Load the 'phi' array from an .npz file and return its last entry (the last-epoch dictionary)."""
    # The context manager closes the NpzFile handle; data_file['phi'] is read into memory first.
    with np.load(file_path) as data_file:
        return data_file['phi'][-1]


def _atoms_from_dictionary(dictionary):
    """Reshape a dictionary with patch_size**2 rows into a (patch_size, patch_size, -1) array.

    Raises:
        ValueError: if the number of rows is not a perfect square. Without this check,
            reshape can silently produce wrongly shaped atoms.
    """
    n_pixels = dictionary.shape[0]
    patch_size = int(round(np.sqrt(n_pixels)))
    if patch_size * patch_size != n_pixels:
        raise ValueError("Dictionary has {} rows, which is not a square patch size".format(n_pixels))
    return dictionary.reshape(patch_size, patch_size, -1)


def _fit_atoms(eng, atoms, options):
    """Fit a Gabor wavelet to each atom atoms[:, :, i] via MATLAB's fit2dGabor.

    Returns:
        dict of numpy arrays: 'wavelet' (the fitted 'patch' of each atom) plus one
        array per key in _FIT_FIELDS, each with one entry per atom.
    """
    n_atoms = atoms.shape[-1]
    wavelets = []
    params = {key: [] for key, _ in _FIT_FIELDS}
    for i in range(n_atoms):
        atom_matlab = matlab.double(atoms[:, :, i].tolist())
        # Route MATLAB's console output into a throwaway buffer instead of this script's output
        wavelet_fit = eng.fit2dGabor(atom_matlab, options, stdout=io.StringIO())
        wavelets.append(np.array(wavelet_fit['patch']))
        fit = wavelet_fit['fit']
        for key, field in _FIT_FIELDS:
            params[key].append(fit[field])
        print("SUCCESS to fit wavelet to dictionary element {} of {}".format(i + 1, n_atoms))
    arrays = {'wavelet': np.array(wavelets)}
    for key, _ in _FIT_FIELDS:
        arrays[key] = np.array(params[key])
    return arrays


if __name__ == "__main__":
    print("Starting Matlab engine...")
    eng = matlab.engine.start_matlab()
    try:
        # fit2dGabor is presumably defined under ./utils/ (TODO confirm)
        eng.addpath(r'./utils/', nargout=0)
        # Options passed through to the MATLAB fitting routine
        options = {'shape': 'elliptical', 'runs': args.runs, 'parallel': False, 'visualize': False}
        for file_path in file_list:
            save_path = _output_path(file_path)
            if save_path is None:
                print("No 'traindata' in file name, cannot derive output path; skipping {}".format(file_path))
                continue
            if os.path.exists(save_path):
                print("Fit already found, skipping {}".format(file_path))
                continue
            print("Loading {}...".format(file_path))
            atoms = _atoms_from_dictionary(_load_final_dictionary(file_path))
            fit_arrays = _fit_atoms(eng, atoms, options)
            np.savez_compressed(save_path, **fit_arrays)
            print("... Saved in {}".format(save_path))
    finally:
        # Shut the engine down even if a fit or file operation fails part-way through
        eng.quit()