-
Notifications
You must be signed in to change notification settings - Fork 2
/
split_bregman.py
107 lines (96 loc) · 3.47 KB
/
split_bregman.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# uncompyle6 version 3.0.0
# Python bytecode 2.7 (62211)
# Decompiled from: Python 2.7.13 (default, Dec 1 2017, 09:21:53)
# [GCC 6.4.1 20170727 (Red Hat 6.4.1-1)]
# Embedded file name: /home/melrobin/research/ibraheem/SplitBregmanTVdenoising/split_bregman.py
# Compiled at: 2018-03-04 17:52:02
import numpy as np, pdb, scipy, scipy.sparse as sp, scipy.sparse.linalg as splinalg
def SB_ITV(g, mu):
    """Denoise a square image with isotropic total variation via Split Bregman.

    The updates below are the Split Bregman iteration for

        min_u  mu * sum_i sqrt((B1 u)_i**2 + (B2 u)_i**2) + 1/2 * ||u - g||**2

    where ``B = [B1; B2]`` is the discrete gradient returned by ``DiffOper``.

    Parameters
    ----------
    g : array_like
        Noisy N x N image (anything with N*N elements).  It is flattened in
        column-major ('F') order.
    mu : float
        Regularization weight; larger values give smoother results.

    Returns
    -------
    numpy.ndarray
        1-D float array of length N*N in column-major order; recover the
        image with ``u.reshape((N, N), order='F')``.

    Raises
    ------
    ValueError
        If ``g`` is empty or its element count is not a perfect square.

    Iterates until ``norm(up - u) / norm(u) <= tol`` (tol = 1e-3); there is
    no iteration cap.  Progress is written to stdout.
    """
    import sys

    g = np.asarray(g, dtype=float).flatten('F')
    n = g.size
    N = int(round(np.sqrt(n)))
    if n == 0 or N * N != n:
        raise ValueError('SB_ITV expects a non-empty square image, got %d pixels' % n)
    B, Bt, BtB = DiffOper(N)

    tol = 0.001
    lambda1 = 1.0
    thresh = mu / lambda1
    # The u-subproblem matrix (I + lambda*B^T B) is the same every iteration.
    A = sp.eye(n) + lambda1 * BtB

    def solve(rhs):
        # SciPy >= 1.12 names the relative tolerance ``rtol`` (``tol`` was
        # removed in 1.14); older SciPy only accepts ``tol``.
        try:
            return splinalg.cg(A, rhs, rtol=1e-5, maxiter=100)[0]
        except TypeError:
            return splinalg.cg(A, rhs, tol=1e-5, maxiter=100)[0]

    b = np.zeros(2 * n)  # Bregman variable
    d = np.zeros(2 * n)  # split variable approximating B u
    u = g
    err = 1.0
    k = 1
    while err > tol:
        sys.stdout.write('it. %g ' % k)
        up = u
        u = solve(g - lambda1 * Bt.dot(b - d))
        Bub = B.dot(u) + b
        dx, dy = Bub[:n], Bub[n:]
        s = np.sqrt(dx ** 2 + dy ** 2)
        # Isotropic shrinkage factor max(s - mu/lambda, 0) / s.  Where s == 0
        # the numerator is 0 as well, so divide by 1 there instead of forming
        # 0/0 = NaN, which would poison d, b and then u.
        factor = np.maximum(s - thresh, 0.0) / np.where(s > 0, s, 1.0)
        d = np.concatenate((factor * dx, factor * dy))
        b = Bub - d
        step = np.linalg.norm(up - u)
        unorm = np.linalg.norm(u)
        # Fall back to the absolute change when u is identically zero.
        err = step / unorm if unorm > 0 else step
        sys.stdout.write('err=%g \n' % err)
        k += 1
    sys.stdout.write('Stopped because norm(up-u)/norm(u) <= tol=%.1e\n' % tol)
    return u
def SB_ATV(g, mu):
    """Denoise a square image with anisotropic total variation via Split Bregman.

    The updates below are the Split Bregman iteration for

        min_u  mu * ||B u||_1 + 1/2 * ||u - g||**2

    where ``B`` is the discrete gradient returned by ``DiffOper``; the
    d-update is component-wise soft thresholding.

    Parameters
    ----------
    g : array_like
        Noisy N x N image (anything with N*N elements).  It is flattened in
        row-major (C) order.
    mu : float
        Regularization weight; larger values give smoother results.

    Returns
    -------
    numpy.ndarray
        1-D float array of length N*N in row-major order; recover the image
        with ``u.reshape((N, N))``.

    Raises
    ------
    ValueError
        If ``g`` is empty or its element count is not a perfect square.

    Iterates until ``norm(up - u) / norm(u) <= tol`` (tol = 1e-3); there is
    no iteration cap.  Progress is written to stdout.
    """
    import sys

    g = np.asarray(g, dtype=float).flatten()
    n = g.size
    N = int(round(np.sqrt(n)))
    if n == 0 or N * N != n:
        raise ValueError('SB_ATV expects a non-empty square image, got %d pixels' % n)
    B, Bt, BtB = DiffOper(N)

    tol = 0.001
    lambda1 = 1.0
    thresh = mu / lambda1
    # The u-subproblem matrix (I + lambda*B^T B) is the same every iteration.
    A = sp.eye(n) + lambda1 * BtB

    def solve(rhs):
        # SciPy >= 1.12 names the relative tolerance ``rtol`` (``tol`` was
        # removed in 1.14); older SciPy only accepts ``tol``.
        try:
            return splinalg.cg(A, rhs, rtol=1e-5, maxiter=100)[0]
        except TypeError:
            return splinalg.cg(A, rhs, tol=1e-5, maxiter=100)[0]

    b = np.zeros(2 * n)  # Bregman variable
    d = np.zeros(2 * n)  # split variable approximating B u
    u = g
    err = 1.0
    k = 1
    while err > tol:
        sys.stdout.write('it. %d ' % k)
        up = u
        u = solve(g - lambda1 * Bt.dot(b - d))
        Bub = B.dot(u) + b
        # Anisotropic (component-wise) soft thresholding.
        d = np.maximum(np.abs(Bub) - thresh, 0.0) * np.sign(Bub)
        b = Bub - d
        step = np.linalg.norm(up - u)
        unorm = np.linalg.norm(u)
        # Fall back to the absolute change when u is identically zero.
        err = step / unorm if unorm > 0 else step
        sys.stdout.write('err=%g\n' % err)
        k += 1
    sys.stdout.write('Stopped because norm(up-u)/norm(u) <= tol=%.1e\n' % tol)
    return u
def delete_row_lil(mat, i):
    """Remove row ``i`` from a ``scipy.sparse.lil_matrix`` in place.

    Both per-row lists held by the matrix -- the column indices in
    ``mat.rows`` and the values in ``mat.data`` -- lose their entry at
    position ``i`` (``np.delete`` semantics, so a negative ``i`` counts from
    the end), and the stored row count is reduced by one.  Returns None.

    Raises
    ------
    ValueError
        If ``mat`` is not a LIL matrix.
    """
    if isinstance(mat, scipy.sparse.lil_matrix):
        # Drop the row from the index lists first, then from the value lists.
        for attr in ('rows', 'data'):
            setattr(mat, attr, np.delete(getattr(mat, attr), i))
        n_rows, n_cols = mat._shape
        mat._shape = (n_rows - 1, n_cols)
    else:
        raise ValueError('works only for LIL format -- use .tolil() first')
def delete_col_lil(mat, i):
    """Remove column ``i`` from a ``scipy.sparse.lil_matrix`` in place.

    Entries stored in column ``i`` are dropped from every row and each column
    index greater than ``i`` is shifted down by one, so the per-row index
    lists stay sorted and consistent with ``mat.data``.  A negative ``i``
    counts from the last column (like ``np.delete``).  Returns None.

    Raises
    ------
    ValueError
        If ``mat`` is not a LIL matrix.
    IndexError
        If ``i`` is outside the column range.
    """
    if not isinstance(mat, scipy.sparse.lil_matrix):
        raise ValueError('works only for LIL format -- use .tolil() first')
    n_cols = mat._shape[1]
    col = i + n_cols if i < 0 else i
    if not 0 <= col < n_cols:
        raise IndexError('column index %d out of range for %d columns' % (i, n_cols))
    # LIL keeps, per row, a list of column indices (mat.rows[r]) and a parallel
    # list of values (mat.data[r]); edit both lists in place.
    for cols, vals in zip(mat.rows, mat.data):
        kept = [(c - 1 if c > col else c, v)
                for c, v in zip(cols, vals) if c != col]
        cols[:] = [c for c, _ in kept]
        vals[:] = [v for _, v in kept]
    mat._shape = (mat._shape[0], n_cols - 1)
def delete_row_csr(mat, i):
    """Remove row ``i`` from a ``scipy.sparse.csr_matrix`` in place.

    The row's stored values and column indices are cut out of ``mat.data``
    and ``mat.indices`` by shifting the tail left, ``mat.indptr`` loses one
    entry and is rebased, and the stored row count drops by one.  A negative
    ``i`` counts from the last row (like ``np.delete``).  Returns None.

    Raises
    ------
    ValueError
        If ``mat`` is not a CSR matrix.
    IndexError
        If ``i`` is outside the row range.
    """
    if not isinstance(mat, scipy.sparse.csr_matrix):
        raise ValueError('works only for CSR format -- use .tocsr() first')
    n_rows = mat._shape[0]
    row = i + n_rows if i < 0 else i
    if not 0 <= row < n_rows:
        raise IndexError('row index %d out of range for %d rows' % (i, n_rows))
    start = mat.indptr[row]
    stop = mat.indptr[row + 1]
    n = stop - start  # number of stored entries in the row being removed
    if n > 0:
        # Shift everything after the row left by n, then trim the tail.
        mat.data[start:-n] = mat.data[stop:]
        mat.data = mat.data[:-n]
        mat.indices[start:-n] = mat.indices[stop:]
        mat.indices = mat.indices[:-n]
    # Drop indptr[row] and rebase the following offsets.
    mat.indptr[row:-1] = mat.indptr[row + 1:]
    mat.indptr[row:] -= n
    mat.indptr = mat.indptr[:-1]
    mat._shape = (n_rows - 1, mat._shape[1])
def _backward_diff_1d(K):
    """Return the K x K CSR backward difference whose first row is zero."""
    data = np.vstack([-np.ones((1, K)), np.ones((1, K))])
    # K x (K+1): -1 on the main diagonal, +1 on the first super-diagonal.
    D = sp.diags(data, [0, 1], (K, K + 1), 'csr')
    # Dropping column 0 leaves row r (r >= 1) computing x[r] - x[r-1] ...
    D = D[:, 1:]
    # ... and row 0 with a lone +1, which is zeroed so its difference is 0.
    D[0, 0] = 0
    return D


def DiffOper(N, M=None):
    """Build the sparse 2-D discrete gradient used by the Split Bregman solvers.

    The operator acts on vectors of length ``N * M`` viewed as ``M``
    consecutive blocks of length ``N`` (e.g. an N-row x M-column image
    flattened in column-major order).  ``B`` stacks two operators:

    * ``kron(I_M, D_N)`` -- differences between neighbours within each block;
    * ``kron(D_M, I_N)`` -- differences between corresponding entries of
      neighbouring blocks;

    where ``D_K`` is the K x K backward difference with a zero first row.
    Every row of ``B`` is either zero or has one -1 and one +1, so ``B``
    maps constant vectors to zero.

    Parameters
    ----------
    N : int
        Block length (image height for column-major flattening).
    M : int, optional
        Number of blocks (image width).  Defaults to ``N`` (square grid).

    Returns
    -------
    B : scipy.sparse.csr_matrix
        Shape ``(2*N*M, N*M)``.
    Bt : scipy.sparse.csr_matrix
        Transpose of ``B``.
    BtB : scipy sparse matrix
        The product ``Bt . B``, shape ``(N*M, N*M)``.
    """
    if M is None:
        M = N
    Dn = _backward_diff_1d(N)
    Dm = Dn if M == N else _backward_diff_1d(M)
    B = sp.vstack([sp.kron(sp.eye(M), Dn), sp.kron(Dm, sp.eye(N))], 'csr')
    Bt = B.transpose().tocsr()
    # .dot is a matrix product for both sparse matrices and sparse arrays
    # (``*`` is element-wise for the latter).
    BtB = Bt.dot(B)
    return B, Bt, BtB