-
Notifications
You must be signed in to change notification settings - Fork 3
/
lrtc_tnn.m
91 lines (85 loc) · 2.62 KB
/
lrtc_tnn.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
function [X,obj,out] = lrtc_tnn(M,omega,opts)
% LRTC_TNN  Low-Rank Tensor Completion (LRTC) based on the Tensor Nuclear
% Norm (TNN), solved by ADMM with an increasing penalty parameter.
%
%   min_X ||X||_*,  s.t.  P_Omega(X) = P_Omega(M)
%
% The constraint is handled through an auxiliary tensor E with X + E = M and
% E = 0 on the observed entries. Each iteration:
%   1. updates X by a TNN proximal step (prox_tnn) with threshold 1/mu;
%   2. updates E in closed form and zeroes it on omega;
%   3. takes a dual ascent step on Y and grows mu by rho (capped at max_mu).
%
% ---------------------------------------------
% Input:
%   M       - d1*d2*d3 tensor; the entries selected by omega are the observations
%   omega   - index of the observed entries (used as M(omega) / E(omega))
%   opts    - optional struct; any missing field keeps its default
%       opts.tol      - termination tolerance: stop when the max-abs change of
%                       X and of E, and the max-abs residual M-X-E, are all
%                       below tol                                (default 1e-12)
%       opts.max_iter - maximum number of iterations            (default 2000)
%       opts.mu       - initial stepsize for dual variable updating in ADMM
%                                                                (default 1e-4)
%       opts.max_mu   - maximum stepsize                         (default 1e10)
%       opts.rho      - rho>=1, ratio used to increase mu        (default 1.1)
%       opts.DEBUG    - 0 or 1; if 1, print progress at iteration 1 and at
%                       every 5th iteration                      (default 0)
%
% Output:
%   X   - d1*d2*d3 recovered tensor
%   obj - second output of prox_tnn at the last X-update (presumably the
%         TNN of X -- confirm against prox_tnn); empty if max_iter < 1
%   out - k-by-2 matrix, one row [iter, norm(M-X-E)] per iteration run,
%         including the iteration at which the stopping test was met
%
% version 1.0 - 25/06/2016
%
% Written by Canyi Lu ([email protected])
%
% References:
% Canyi Lu, Jiashi Feng, Zhouchen Lin, Shuicheng Yan
% Exact Low Tubal Rank Tensor Recovery from Gaussian Measurements
% International Joint Conference on Artificial Intelligence (IJCAI). 2018

tol = 1e-12;
max_iter = 2000;
rho = 1.1;
mu = 1e-4;
max_mu = 1e10;
DEBUG = 0;
out = [];
if ~exist('opts', 'var')
    opts = [];
end
if isfield(opts, 'tol');      tol = opts.tol;           end
if isfield(opts, 'max_iter'); max_iter = opts.max_iter; end
if isfield(opts, 'rho');      rho = opts.rho;           end
if isfield(opts, 'mu');       mu = opts.mu;             end
if isfield(opts, 'max_mu');   max_mu = opts.max_mu;     end
if isfield(opts, 'DEBUG');    DEBUG = opts.DEBUG;       end

dim = size(M);
% Start X from the observed entries; only used for the first change measure.
X = zeros(dim);
X(omega) = M(omega);
E = zeros(dim);
Y = E;
% Defined before the loop so obj is well-defined even if no iteration runs.
tnnX = [];
for iter = 1 : max_iter
    Xk = X;
    Ek = E;
    % update X: TNN proximal step at M - E + Y/mu with threshold 1/mu
    [X,tnnX] = prox_tnn(-E+M+Y/mu,1/mu);
    % update E: closed-form minimizer, forced to zero on the observed entries
    E = M-X+Y/mu;
    E(omega) = 0;
    % primal residual of the constraint X + E = M
    dY = M-X-E;
    err = norm(dY(:));
    % Log before the stopping test so the final (converged) iterate is recorded.
    out = [out; iter, err]; %#ok<AGROW>
    chgX = max(abs(Xk(:)-X(:)));
    chgE = max(abs(Ek(:)-E(:)));
    chg = max([chgX chgE max(abs(dY(:)))]);
    if DEBUG
        if iter == 1 || mod(iter, 5) == 0
            disp(['iter ' num2str(iter) ', mu=' num2str(mu) ...
                ', obj=' num2str(tnnX) ', err=' num2str(err)]);
        end
    end
    if chg < tol
        break;
    end
    % dual ascent on the multiplier, then grow the penalty (capped)
    Y = Y + mu*dY;
    mu = min(rho*mu,max_mu);
end
obj = tnnX;