-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo_v1_1d.m
121 lines (80 loc) · 2.67 KB
/
demo_v1_1d.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
% demo_v1_1d: trains UVFA, SF&GPI and MF agents on the 1D environment for N
% subjects, evaluates them (together with an MB agent) on the training and
% test weight vectors, saves the workspace to demo_v1_1d.mat and plots the
% results.
%
% NOTE(review): 'clear all' (not plain 'clear') also clears functions from
% memory, which resets any persistent variables held by the helper
% functions. Kept as-is; confirm that reset is unnecessary before relaxing it.
clear all;
%env = init_env_v1_2;
env = init_env_v1_1d;
filename = 'demo_v1_1d.mat';
% Task weight vectors, one cell per task; each vector is passed to the
% train/test helpers and to test_perf together with a policy.
% (Presumably weights over reward features -- confirm in init_env_v1_1d.)
w_train = {[1 -2 0], [-2 1 0], [1 -1 0], [-1 1 0]};
w_test = {[1 1 0], [0 0 1]};
params = init_params();
N = 60; % number of subjects, i.e. independent training runs per model
%
% train models
%
% Preallocate the per-subject model containers so they are not grown on
% every loop iteration; the resulting 1xN cell arrays are identical to
% those built incrementally.
UVFA = cell(1, N);
psi = cell(1, N);
Q = cell(1, N);
for subj = 1:N
    % NOTE(review): the meaning of the literal 100 is defined inside
    % train_UVFA (not visible here) -- confirm before changing it.
    UVFA{subj} = train_UVFA(env, w_train, params.gamma, 100);
    psi{subj} = train_SFGPI(env, w_train, params.gamma);
    Q{subj} = train_MF(env, w_train, params.gamma, params.alpha, params.eps);
end
% Checkpoint: saves the entire workspace, including the trained models.
save(filename);
%
% eval perf on train tasks
%
% Results are indexed (task, model, subject). Model order along dim 2:
% 1 = UVFA, 2 = SF&GPI, 3 = MB, 4 = MF (the same order as model_names used
% for plotting). Every entry is written by the loops below, so preallocating
% only avoids growing the arrays inside the loops.
% NOTE(review): assumes test_perf returns scalar doubles r and s -- confirm.
term_s_train = zeros(length(w_train), 4, N);
tot_r_train = zeros(length(w_train), 4, N);
for subj = 1:N
    % compute test policies
    pi_train_UVFA = test_UVFA(env, w_train, params.gamma, UVFA{subj});
    pi_train_SF = test_SFGPI(env, w_train, params.gamma, psi{subj});
    % NOTE(review): test_MB takes no subject-specific input, yet it is
    % recomputed for every subject. It could be hoisted out of the loop if
    % it is deterministic; left in place to keep any RNG consumption order.
    pi_train_MB = test_MB(env, w_train, params.gamma);
    pi_train_MF = test_MF(env, w_train, Q{subj});
    for t = 1:length(w_train)
        % test UVFA
        [r, s] = test_perf(env, pi_train_UVFA{t}, w_train{t});
        term_s_train(t, 1, subj) = s;
        tot_r_train(t, 1, subj) = r;
        % test SF
        [r, s] = test_perf(env, pi_train_SF{t}, w_train{t});
        term_s_train(t, 2, subj) = s;
        tot_r_train(t, 2, subj) = r;
        % test MB
        [r, s] = test_perf(env, pi_train_MB{t}, w_train{t});
        term_s_train(t, 3, subj) = s;
        tot_r_train(t, 3, subj) = r;
        % test MF
        % NOTE(review): unlike the other models, the MF policy is not
        % indexed by task ({t}); test_MF presumably returns a single
        % task-independent policy -- confirm in test_MF.
        [r, s] = test_perf(env, pi_train_MF, w_train{t});
        term_s_train(t, 4, subj) = s;
        tot_r_train(t, 4, subj) = r;
    end
end
%
% eval perf on test tasks
%
% Results are indexed (task, model, subject). Model order along dim 2:
% 1 = UVFA, 2 = SF&GPI, 3 = MB, 4 = MF (the same order as model_names used
% for plotting). Every entry is written by the loops below, so preallocating
% only avoids growing the arrays inside the loops.
% NOTE(review): assumes test_perf returns scalar doubles r and s -- confirm.
term_s_test = zeros(length(w_test), 4, N);
tot_r_test = zeros(length(w_test), 4, N);
for subj = 1:N
    % compute test policies
    pi_test_UVFA = test_UVFA(env, w_test, params.gamma, UVFA{subj});
    pi_test_SF = test_SFGPI(env, w_test, params.gamma, psi{subj});
    % NOTE(review): test_MB takes no subject-specific input, yet it is
    % recomputed for every subject. It could be hoisted out of the loop if
    % it is deterministic; left in place to keep any RNG consumption order.
    pi_test_MB = test_MB(env, w_test, params.gamma);
    pi_test_MF = test_MF(env, w_test, Q{subj});
    for t = 1:length(w_test)
        % test UVFA
        [r, s] = test_perf(env, pi_test_UVFA{t}, w_test{t});
        term_s_test(t, 1, subj) = s;
        tot_r_test(t, 1, subj) = r;
        % test SF
        [r, s] = test_perf(env, pi_test_SF{t}, w_test{t});
        term_s_test(t, 2, subj) = s;
        tot_r_test(t, 2, subj) = r;
        % test MB
        [r, s] = test_perf(env, pi_test_MB{t}, w_test{t});
        term_s_test(t, 3, subj) = s;
        tot_r_test(t, 3, subj) = r;
        % test MF
        % NOTE(review): unlike the other models, the MF policy is not
        % indexed by task ({t}); test_MF presumably returns a single
        % task-independent policy -- confirm in test_MF.
        [r, s] = test_perf(env, pi_test_MF, w_test{t});
        term_s_test(t, 4, subj) = s;
        tot_r_test(t, 4, subj) = r;
    end
end
% Save the full workspace: trained models plus train/test evaluation results.
% model_names is defined after this save, so it is not stored in the file.
save(filename);
% Presumably intended for re-plotting from a saved run with the training and
% evaluation sections commented out -- confirm before relying on it.
%load(filename);
model_names = {'UVFA', 'SF&GPI', 'MB', 'MF'}; % order matches dim 2 of tot_r_* / term_s_*
%plot_perf(env, w_train, tot_r_train, model_names); % <-- boring
plot_final_states(env, w_train, term_s_train, model_names);
plot_perf(env, w_test, tot_r_test, model_names);
plot_final_states(env, w_test, term_s_test, model_names);