-
Notifications
You must be signed in to change notification settings - Fork 0
/
startSlave.m
155 lines (149 loc) · 4.35 KB
/
startSlave.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
function [] = startSlave(debugMode)
%STARTSLAVE Run this worker's share of a job published in shared memory.
%   startSlave(debugMode) attaches to the shared-memory segments
%   'shared_fhandle', 'shared_data' and 'shared_<rank>', evaluates every
%   parameter set stored in 'shared_<rank>', clones the results into
%   shared memory under 'res_<rank>' and finally sends this process id
%   over UDP to the master port assigned to this rank.
%
%   debugMode - (optional, default false) when true, print progress
%               messages to the command window.
%
%   <rank> is the 1-based position of this process id among the sorted
%   pids returned by getWorkersPids.
%
%   Evaluation modes:
%     single - fHandle and datacell are not both structs: each result is
%              feval(fHandle, datacell{:}, param{p}).
%     double - fHandle and datacell are both structs: the data is split on
%              datacell.fold, every fold is scored with eval_fold and the
%              result for a parameter set is the mean fold accuracy.
%
% created 06-20-2018
% last modification -- -- --
% Okba Bekhelifi, <[email protected]>
if nargin < 1
    debugMode = false;
end
clc;
if(debugMode)
    fprintf('Recovering shared memory.\n');
end
osPid = feature('getPid');
wPids = getWorkersPids();
nWorkers = length(wPids);
% Single-output find returns a linear index, so the rank is correct
% whether getWorkersPids returns a row or a column cell array.
workerRank = find(sort(cellfun(@str2num, wPids)) == osPid, 1);
if isempty(workerRank)
    error('startSlave:unknownWorker', ...
        'Process %d is not among the worker pids.', osPid);
end
pid = sprintf('%d', workerRank);
clear wPids
% Set IPC: one master/slave UDP port pair per rank (indexed by rank).
masterPorts = 9091:9091+nWorkers;
slavePorts = 9191:9191+nWorkers;
if(debugMode)
    fprintf('Worker %d Opening communication channel on port: %d\n', ...
        osPid, ...
        slavePorts(workerRank)...
        );
end
slaveSocket = udp('Localhost', masterPorts(workerRank), ...
    'LocalPort', slavePorts(workerRank)...
    );
% Release the port even if anything below throws; this is a no-op when
% the socket has already been closed and deleted at the end.
socketCleanup = onCleanup(@() releaseSocket(slaveSocket)); %#ok<NASGU>
fopen(slaveSocket);
clear masterPorts slavePorts
% Recover Shared Memory
fHandle = SharedMemory('attach', 'shared_fhandle');
datacell = SharedMemory('attach', 'shared_data');
if(debugMode)
    fprintf('Data recovery succeded\n');
end
param = SharedMemory('attach', ['shared_' pid]);
workerResult = cell(1, length(param));
% Evaluate Functions
if(debugMode)
    fprintf('Worker %s Evaluating job\n', pid);
end
% Struct handle + struct data selects per-fold train & predict;
% anything else is a direct (train only) evaluation.
isDoubleMode = isstruct(fHandle) && isstruct(datacell);
for p = 1:length(param)
    if ~isDoubleMode
        workerResult{p} = feval(fHandle, datacell{:}, param{p});
    else
        % split data and evaluate folds
        nfolds = max(datacell.fold);
        acc_folds = zeros(1, nfolds);
        for f = 1:nfolds
            testIdx = datacell.fold == f;
            acc_folds(f) = eval_fold(fHandle, ...
                datacell.data, ...
                param{p}, ...
                ~testIdx, ...
                testIdx ...
                );
        end
        workerResult{p} = mean(acc_folds);
    end
end
% Detach SharedMemory
if(debugMode)
    fprintf('Worker %s Detaching sharedMemory\n', pid);
end
SharedMemory('detach', 'shared_fhandle', fHandle);
SharedMemory('detach', 'shared_data', datacell);
SharedMemory('detach', ['shared_' pid], param);
% Drop the local variables that referenced the detached segments.
clear fHandle datacell param
%
% Write results in SharedMemory
resKey = ['res_' pid];
if(debugMode)
    fprintf('Worker %s Writing results in sharedMemory\n', pid);
    fprintf('Worker %s shared result key %s\n', pid, resKey);
end
SharedMemory('clone', resKey, workerResult);
if(debugMode)
    fprintf('Opening slave socket\n');
    fprintf('writing data to socket \n');
end
% Send this process id to the master port of this rank.
fprintf(slaveSocket, '%d', osPid);
if(debugMode)
    fprintf('Data sent : %d to %d\n',...
        slaveSocket.ValuesSent, ...
        slaveSocket.RemotePort...
        );
end
fclose(slaveSocket);
delete(slaveSocket);
end
function releaseSocket(sock)
%RELEASESOCKET Close and delete a UDP object unless it is already deleted.
if isvalid(sock)
    if strcmp(sock.Status, 'open')
        fclose(sock);
    end
    delete(sock);
end
end
function af = eval_fold(fhandle, data, param, trainIdx, predictIdx)
%EVAL_FOLD Train on one fold's training split and score its held-out split.
%   af = eval_fold(fhandle, data, param, trainIdx, predictIdx)
%   fhandle    - struct whose field tr is called on the training split and
%                whose field pr is called on the held-out split together
%                with the value returned by tr.
%   data       - dataset (struct or cell) that getSplit knows how to split.
%   param      - parameters for fhandle.tr: expanded with param{:} when
%                data is a struct, passed as a single argument otherwise.
%   trainIdx   - selector of the training samples.
%   predictIdx - selector of the held-out samples.
%   af         - value returned by getAccuracy for the held-out split.
foldTrain = getSplit(data, trainIdx);
foldTest = getSplit(data, predictIdx);
if ~isstruct(data)
    % Cell dataset: its elements are passed as positional arguments.
    model = feval(fhandle.tr, foldTrain{:}, param);
    prediction = feval(fhandle.pr, foldTest{:}, model);
else
    model = feval(fhandle.tr, foldTrain, param{:});
    prediction = feval(fhandle.pr, foldTest, model);
end
af = getAccuracy(prediction, foldTest);
end
function d = getSplit(d, id)
%GETSPLIT Keep only the samples selected by id in a dataset.
%   d = getSplit(d, id)
%   d  - a cell whose first two elements are indexed by rows, or a struct.
%        A struct with exactly two fields is assumed to have fields x and
%        y (both indexed by rows). For any other struct each field is
%        handled as follows:
%          * 3-D arrays are indexed along the third dimension;
%          * 2-D arrays whose length equals numel(id) are indexed with id;
%          * every other field is left unchanged.
%   id - sample selector; callers in this file pass a logical mask with
%        one entry per sample.
if(isstruct(d))
    fields = fieldnames(d);
    if(numel(fields)==2)
        d.x = d.x(id, :);
        d.y = d.y(id, :);
    else
        nSamples = numel(id);
        for fd = 1:length(fields)
            value = d.(fields{fd});
            if(ndims(value)==3)
                d.(fields{fd}) = value(:, :, id);
            elseif(ismatrix(value) && length(value) == nSamples)
                % Only fields with one entry per sample are indexed by the
                % mask; fields of any other length (e.g. metadata vectors)
                % are not per-sample and must be kept intact.
                d.(fields{fd}) = value(id);
            end
        end
    end
else
    d{1} = d{1}(id, :);
    d{2} = d{2}(id, :);
end
end
function acc = getAccuracy(predFold, data)
%GETACCURACY Percentage of predictions equal to the true labels.
%   acc = getAccuracy(predFold, data)
%   predFold - predictions: a vector when data is a cell, or a struct whose
%              field y holds the predictions when data is a struct.
%   data     - evaluated split: a cell {a, b} where the labels are in b if
%              a has more columns than b and in a otherwise, or a struct
%              whose field y holds the labels.
%   acc      - scalar accuracy in percent (NaN when there are no labels).
if(iscell(data))
    if(size(data{1}, 2) > size(data{2}, 2))
        % Label data in second cell
        labels = data{2};
    else
        labels = data{1};
    end
    predicted = predFold;
else
    labels = data.y;
    predicted = predFold.y;
end
% Compare as column vectors so a row/column orientation mismatch cannot
% broadcast into a matrix of comparisons (implicit expansion).
acc = (sum(labels(:) == predicted(:)) / numel(labels)) * 100;
end