-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_multi_task.m
executable file
·67 lines (54 loc) · 2 KB
/
demo_multi_task.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
% Copyright 2014 Yuan Shi & Aurelien Bellet
%
% This file is part of SCML.
%
% SCML is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% SCML is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with SCML. If not, see <http://www.gnu.org/licenses/>.
% Demo comparing single-task SCML (st-SCML) against multi-task SCML (mt-SCML).
setpaths;   % project path setup; must run before any project helper below

%% Seed the random number generator (init_rng) so runs are repeatable
init_rng;

%% Load the multi-task dataset into the workspace
% The rest of this script reads the per-task cell arrays xTr, yTr, xVa, xTe
% and yTe. NOTE(review): these are assumed to come from this .mat file;
% confirm against its contents.
load dataset/sentiment_mt.mat

%% Preprocess each task's train / validation / test data jointly
n_task = length(xTr);   % number of tasks = number of cells in xTr
for t = 1:n_task
    [xTr{t}, xVa{t}, xTe{t}] = ...
        pre_process_data(xTr{t}, xVa{t}, xTe{t});
end
%% single-task learning
% Learn one metric per task independently (SCML_global), then evaluate each
% one with knnclassifytree on that task's train/test data.
%
% Hyper-parameters are named here instead of repeated inline; their values
% are unchanged from the original demo. NOTE(review): their exact meaning is
% defined by SCML_global / mt_SCML / knnclassifytree, which are not visible
% here. The comments below describe how they are passed; confirm the semantics
% in those functions.
n_basis_st = 400;           % 3rd argument to SCML_global (one call per task)
n_basis_mt_per_task = 100;  % mt_SCML gets n_basis_mt_per_task * n_task
reg_lambda = 1e-4;          % last argument to both learners (presumably a regularization weight)
k_nn = 3;                   % last argument to knnclassifytree (presumably k of k-NN)

% Preallocate the result cells. Scripts share the base workspace, so a stale
% variable with the same name (a non-cell, or a longer cell from an earlier
% run) would otherwise break the {t} assignments or leave extra entries.
st_L_set = cell(1, n_task);
err_st = cell(1, n_task);

% Train every task first, then evaluate. This keeps the original call order,
% which matters if the learners draw from the RNG seeded earlier.
for t = 1:n_task
    st_L_set{t} = SCML_global(xTr{t}, yTr{t}, n_basis_st, reg_lambda);
end
for t = 1:n_task
    % Data matrices and labels are passed transposed, as in the original demo.
    err_st{t} = knnclassifytree(st_L_set{t}, xTr{t}', yTr{t}', xTe{t}', yTe{t}', k_nn);
end

%% multi-task learning
% Learn all task metrics jointly. mt_L_set is indexed per task below, one
% learned metric per task.
mt_L_set = mt_SCML(xTr, yTr, n_basis_mt_per_task * n_task, reg_lambda);
err_mt = cell(1, n_task);
for t = 1:n_task
    err_mt{t} = knnclassifytree(mt_L_set{t}, xTr{t}', yTr{t}', xTe{t}', yTe{t}', k_nn);
end
%% Report errors (as percentages) per task and averaged over all tasks
% For each task, element 1 of err_st{t} / err_mt{t} is reported as the train
% error and element 2 as the test error. Both are scaled by 100 for display.

% --- single-task results ---
fprintf('Single-task learning\n');
res_train = zeros(n_task,1);
res_test = zeros(n_task,1);
for t = 1:n_task
    res_train(t) = 100 * err_st{t}(1);
    res_test(t) = 100 * err_st{t}(2);
    fprintf('%g task: train %.2f\t test %.2f\n', t, res_train(t), res_test(t));
end
fprintf('Average: train %.2f\t test %.2f\n\n', mean(res_train), mean(res_test));

% --- multi-task results (same n_task-sized buffers, fully overwritten) ---
fprintf('Multi-task learning\n');
for t = 1:n_task
    res_train(t) = 100 * err_mt{t}(1);
    res_test(t) = 100 * err_mt{t}(2);
    fprintf('%g task: train %.2f\t test %.2f\n', t, res_train(t), res_test(t));
end
fprintf('Average: train %.2f\t test %.2f\n', mean(res_train), mean(res_test));