trainNNCV.m
function theta=trainNNCV(net,sys,par,data,w0,numIterationsPerEpoch)
% Batch-mode training routine for the NN CV.
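% Inputs (field layout inferred from the usage below; treat as assumptions):
%   net  - struct with the CasADi NN function net.nn, net.corrP, and the
%          initial parameter vector net.w0
%   sys  - struct with the system model function sys.F (state transition)
%   par  - struct with dimensions par.nx, par.nu and the nominal
%          disturbance/parameter vector par.d0
%   data - struct array with inputs data(i).u and stacked states data(i).y
%          (current state in rows 1:par.nx, successor state below)
%   w0   - optional initial parameter vector (defaults to net.w0)
%   numIterationsPerEpoch - number of Adam iterations to run
% Output:
%   theta - trained parameter vector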
import casadi.*
nn = net.nn;
corrP = net.corrP;
% theta = net.theta;
num_u = size(data(1).u,1);
num_Xi = size(data(1).y,1);
num_theta = nn.numel_in - num_u - num_Xi;
w = MX.sym('theta',num_theta);
uk_s = MX.sym('uk',num_u);
Xi_s = MX.sym('Xi',par.nx);
% L = 1000; MX.sym('L',num_u);
% L2 = MX.sym('L2',num_u)
b_s = MX.sym('b',par.nu);
Xi_f = sys.F(Xi_s,[par.d0;uk_s],0,0,0,0); % successor state from the system model at the nominal disturbance par.d0
Xi_f_s = MX.sym('Xi_f',par.nx);
c_s = (nn(uk_s,[Xi_s;Xi_f_s],w));
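% Total derivative of the CV c with respect to the input u: the direct term
% plus the dependence of the successor state Xi_f on u via the model (chain rule).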
dcdu = jacobian(c_s,uk_s)+jacobian(c_s,Xi_f_s)*jacobian(Xi_f,uk_s);
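% Loss: norm of the Newton-type input step dcdu \ (b - c), i.e. how far the
% input would have to move to bring the CV c to its target b (b is set to
% zero in the training loop below).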
Loss_s = norm(dcdu\(b_s-c_s));%+10*norm(dcdu - 1);
% Loss_s = norm(dcdu\(c_s));
Xi_f_func = Function('Xi_f_func',{uk_s,Xi_s},{Xi_f}); % used only by the commented-out perturbation loss below
LossFunc = Function('loss',{b_s,uk_s,Xi_s,Xi_f_s,w},{Loss_s});
% dcduFunc = Function('dcdu',{uk_s,Xi_s,Xi_f_s,w},{norm(dcdu+1)});
% cFunc = Function('dcdu',{uk_s,Xi_s,Xi_f_s,w},{c_s});
Loss = 0;
g = [];
lbg = [];
ubg = [];
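% Total number of samples across all data sets, used to average the per-sample loss.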
N = 0;
for i = 1:length(data)
    N = N + size(data(i).u,2);
end
for i = 1:length(data)
    for j = 1:size(data(i).u,2)
        u = data(i).u(:,j);
        % Per-sample loss with target b = 0, averaged over all N samples
        Loss = Loss + 1/N*(LossFunc(zeros(par.nu,1),u,data(i).y(1:par.nx,j),data(i).y(par.nx+1:end,j),w));
        % Loss = Loss + 1/N*sum(LossFunc((-1e-3+1e-5:1e-4:1e-3)*L,u+(-1e-3+1e-5:1e-4:1e-3),repmat(data(i).y(1:par.nx,j),1,20),Xi_f_func(u+(-1e-3+1e-5:1e-4:1e-3),repmat(data(i).y(1:par.nx,j),1,20)),w));
    end
end
% g=norm(w(corrP{end}(1:end-num_u)));
% lbg = [ones(num_u,1)];
% ubg = [ones(num_u,1)];
% g=[w(end)];
% lbg = [0];
% ubg = [0];
rng(10086); % fixed seed for reproducibility
if nargin<5 || isempty(w0)
w0 = [net.w0];
end
%% NLP (IPOPT) formulation, currently disabled
opts = struct('ipopt',struct('max_iter',5000));
% nlp_prob = struct('f', Loss, 'x', [L;w], 'g', g);
% nlp_solver = nlpsol('nlp_solver', 'ipopt', nlp_prob,opts); % Solve relaxed problem
% Solve the NLP
% sol = nlp_solver('x0',w0, 'lbg',lbg, 'ubg',ubg);%, 'lbx',lbw, 'ubx',ubw);
% flag = nlp_solver.stats();
% flag.success
% theta = full(sol.x);
% grad = nlp_solver.get_function('nlp_grad_f');
% f = nlp_solver.get_function('nlp_f');
%% Gradient descent (Adam)
vel = [];              % velocity for the (unused) SGDM alternative
learnRate = 0.001;     % Adam hyperparameters
gradDecay = 0.9;
sqGradDecay = 0.999;
averageGrad = 0;
averageSqGrad = 0;
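% Compile the total loss and its gradient as CasADi Functions once, so they
% can be evaluated numerically at each iteration of the loop below.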
grad = Function('grad_f',{[w]},{jacobian(Loss,[w])});
f = Function('loss_f',{[w]},{Loss});
theta = w0;
numEpochs = 1;
% numIterationsPerEpoch=10000;%floor(numObservations./miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
monitor = trainingProgressMonitor(Metrics=["Loss","NormOfGradient"],Info="Epoch",XLabel="Iteration");
iteration = 0;
epoch = 1; % single epoch (numEpochs = 1)
i = 0;
while i < numIterationsPerEpoch && ~monitor.Stop
    i = i + 1;
    iteration = iteration + 1;
    % Evaluate the gradient; full() converts the CasADi DM result to a dense
    % numeric column vector so that adamupdate accepts it.
    grad_f = full(grad(theta))';
    % Update the parameters using the Adam optimizer.
    [theta,averageGrad,averageSqGrad] = adamupdate(theta,grad_f,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);
    % Alternative: update the parameters using the SGDM optimizer.
    % [theta,vel] = sgdmupdate(theta,grad_f,vel);
    % [theta,averageGrad,averageSqGrad]=adamstep_my(theta,grad_f,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay,1e-8);
    % if isempty(averageSqGrad)
    %     averageSqGrad = averageSqGrad1;
    % else
    %     averageSqGrad = max(averageSqGrad1,averageSqGrad);
    % end
    % Log progress every 10 iterations.
    if mod(i,10) == 0
        % temp = loss;
        loss = f(theta);
        % if loss >= 10*temp
        % end
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=log(full(loss)),NormOfGradient=full(log(norm(grad_f))));
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
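
% Example call (illustrative sketch; assumes net, sys, par and data are built
% elsewhere with the field layout described in the header comments):
%   theta = trainNNCV(net, sys, par, data, [], 5000);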