-
Notifications
You must be signed in to change notification settings - Fork 0
/
ESRNN_Perturb_TargetFun.m
115 lines (110 loc) · 3.79 KB
/
ESRNN_Perturb_TargetFun.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
function [z, targetFeedforward] = ESRNN_Perturb_TargetFun(t, r, targetFunPassthrough, targetFeedforward)
%ESRNN_PERTURB_TARGETFUN Advance a simple 2-D kinematic model by one step.
%   [z, targetFeedforward] = ESRNN_Perturb_TargetFun(t, r, targetFunPassthrough, targetFeedforward)
%   maps the network output R to forces, integrates velocity and position
%   for one step, appends the step to the state struct TARGETFEEDFORWARD,
%   and returns the new position Z as a column vector.
%
%   Inputs:
%     t                    - current step. t == 0 (re)initialises the state
%                            struct; any other value requires the struct
%                            returned by the previous call.
%     r                    - network output; elements 1..4 are used
%                            (presumably exactly 4 elements -- TODO confirm
%                            against the caller).
%     targetFunPassthrough - parameter struct. Required fields: dt, pos
%                            (start point), kinStart, end (end point),
%                            perturbTrials, perturbDir, perturbMag,
%                            perturbDist, goTime. Optional fields (defaults
%                            reproduce the previous hard-coded behaviour):
%                              outputDelay   (1) steps between an output
%                                            and its effect on velocity;
%                                            0 applies it immediately
%                              feedbackDelay (1) steps of lag in Feedback
%                              outputNoise   (0) std of Gaussian noise
%                                            added to r before actFun
%                              feedbackNoise (0) std of Gaussian noise
%                                            added to Feedback
%     targetFeedforward    - state struct from the previous call (rebuilt
%                            when t == 0).
%
%   Outputs:
%     z                    - new position as a column vector.
%     targetFeedforward    - updated state; histories t, pos, vel, F, FOut
%                            and FeedbackHistory each gain one row, and
%                            Feedback holds the delayed [pos'; vel'].

dt            = targetFunPassthrough.dt;
startPoint    = targetFunPassthrough.pos';
kinStart      = targetFunPassthrough.kinStart;
endPoint      = targetFunPassthrough.end';
perturbTrials = targetFunPassthrough.perturbTrials;
perturbDir    = targetFunPassthrough.perturbDir;  %#ok<NASGU> used by the disabled perturbation code below
perturbMag    = targetFunPassthrough.perturbMag;  %#ok<NASGU> used by the disabled perturbation code below
perturbDist   = targetFunPassthrough.perturbDist;
goTime        = targetFunPassthrough.goTime;

% Optional settings; defaults match the values that used to be hard-coded.
outputDelay   = getOptionalField(targetFunPassthrough, 'outputDelay', 1);
feedbackDelay = getOptionalField(targetFunPassthrough, 'feedbackDelay', 1);
outputNoise   = getOptionalField(targetFunPassthrough, 'outputNoise', 0.0);
feedbackNoise = getOptionalField(targetFunPassthrough, 'feedbackNoise', 0);

if t == 0
    % First step of a trial: start at rest at the start point and reset
    % every history/flag field of the state struct.
    pos1 = startPoint;
    vel1 = [0 0];
    F1 = [0 0];
    targetFeedforward.t = [];
    targetFeedforward.pos = [];
    targetFeedforward.vel = [];
    targetFeedforward.F = [];
    targetFeedforward.FOut = [];
    targetFeedforward.pON = false;
    targetFeedforward.Feedback = [];
    targetFeedforward.FeedbackHistory = [];
    targetFeedforward.perturbDir = [];
    targetFeedforward.perturbOnTime = [];
    targetFeedforward.perturbMag = [];
    targetFeedforward.perturbDist = [];
    targetFeedforward.inTarg = [];
    targetFeedforward.kinStart = kinStart;
    targetFeedforward.lock = false;
    targetFeedforward.lockTime = [];
else
    % Continue from the last recorded state.
    vel1 = targetFeedforward.vel(end,:);
    F1 = targetFeedforward.F(end,:);
    pos1 = targetFeedforward.pos(end,:);
end

% tanh rectified at zero. randn is drawn even when outputNoise is 0 so that
% the global random stream advances exactly as it always has.
actFun = @(x) (x > 0) .* tanh(x);
FOut = actFun(r + randn(size(r)) * outputNoise);
% Elements 1 and 3 are negated, 2 and 4 kept: each pair is summed below into
% one force component.
FOut = [-FOut(1) FOut(2) -FOut(3) FOut(4)]';

if t >= kinStart && ~targetFeedforward.lock
    %% Calculate forces
    dToStart = sqrt(sum((startPoint - pos1).^2));
    dToEnd = sqrt(sum((endPoint - pos1).^2));
    % Perturbation trigger distance: a fraction (1 - perturbDist) of the
    % start-to-end distance, measured from the end point.
    dTrigger = sqrt(sum((endPoint - startPoint).^2)) * (1 - perturbDist);

    % Record t whenever the position is within 0.1 of the end point. Twenty
    % recorded t values that each differ by exactly 1 would trigger the
    % (currently disabled) lock.
    if dToEnd < 0.1
        targetFeedforward.inTarg(end+1) = t;
        if length(targetFeedforward.inTarg) >= 20
            conseq = diff(targetFeedforward.inTarg(end-19:end));
            if sum(conseq == 1) == 19
                % targetFeedforward.lock = true;
                % targetFeedforward.lockTime = t;
            end
        end
    end

    % Perturbation onset (currently disabled).
    if (perturbTrials == 1) && (dToStart > 0.1) && ...
            (dToEnd < dTrigger) && isempty(targetFeedforward.perturbDir) && t >= goTime
        %targetFeedforward.pON = true;
        %targetFeedforward.perturbDir = perturbDir;
        %targetFeedforward.perturbMag = perturbMag;
        %targetFeedforward.perturbDist = perturbDist;
        %targetFeedforward.perturbOnTime = t;
    end

    % External perturbation force (its computation is currently disabled).
    F = [0 0];
    if targetFeedforward.pON
        % F = [sin(targetFeedforward.perturbDir), cos(targetFeedforward.perturbDir)] * perturbMag;
    end

    if ~targetFeedforward.lock
        %% Update current velocity using the (delayed) network output
        thisFOut = delayedOutput(targetFeedforward.FOut, FOut, outputDelay);
        tempF = [sum(thisFOut(1:2)) sum(thisFOut(3:4))];
        vel = vel1 + (tempF + F) * (dt/1000) * 10;
        pos = pos1 + vel;
    else
        % Reached only if lock is set earlier in this same step (by the
        % currently disabled lock code above): freeze the position.
        pos = pos1;
        vel = 0;
        FOut = 0;
        F = 0;
    end
else
    % Before kinStart (or while locked): hold the previous state.
    pos = pos1;
    vel = vel1;
    F = F1;
end

% Append this step to the histories (scalars expand to fill a row).
targetFeedforward.vel(end+1,:) = vel;
targetFeedforward.F(end+1,:) = F;
targetFeedforward.FOut(end+1,:) = FOut;
targetFeedforward.t(end+1) = t;
targetFeedforward.pos(end+1,:) = pos;

% Feedback is [pos'; vel'] from feedbackDelay steps ago, or from the oldest
% recorded step while the history is still shorter than the delay.
nPos = size(targetFeedforward.pos, 1);
FInd = max(0, min(nPos - 1, feedbackDelay));
feedback = [targetFeedforward.pos(end-FInd,:)'; ...
    targetFeedforward.vel(end-FInd,:)'];
if feedbackNoise > 0
    % Only draw noise when requested, so the default leaves the random
    % stream untouched (as the original code did).
    feedback = feedback + randn(size(feedback)) * feedbackNoise;
end
targetFeedforward.Feedback = feedback;
targetFeedforward.FeedbackHistory(end+1,:) = feedback;

z = pos';
end

function thisFOut = delayedOutput(history, current, delay)
%DELAYEDOUTPUT Output that is DELAY steps old, falling back to CURRENT.
%   Uses row (end - min(n, delay) + 1) of HISTORY, where n is the number of
%   stored rows. Returns CURRENT when nothing is stored yet or delay <= 0.
nStored = size(history, 1);
if nStored == 0 || delay <= 0
    thisFOut = current;
else
    thisFOut = history(end - min(nStored, delay) + 1, :);
end
end

function value = getOptionalField(s, name, default)
%GETOPTIONALFIELD Return s.(name) when present and non-empty, else DEFAULT.
if isfield(s, name) && ~isempty(s.(name))
    value = s.(name);
else
    value = default;
end
end