-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_lin_models.m
369 lines (327 loc) · 15.3 KB
/
test_lin_models.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
% Unit tests for the linear model functions: setup, prediction and update
%
clear('variables')
% Make the YAML loader package and the plotting helpers visible
addpath("yaml")
addpath("plot-utils")
% Test fixtures are read from <test_dir>/<test_data_dir>
test_dir = "tests";
test_data_dir = "data";
%% Test initialization with data
% Small hand-made data set of three (Load, Power) observations
Load = [50 100 150]';
Power = [35.05 70.18 104.77]';
data = table(Load, Power);
params = struct();
params.predictorNames = "Load";
params.responseNames = "Power";
params.significance = 0.1;
% Initialize a linear model
[model, vars] = lin_model_setup(data, params);
assert(isa(model, 'LinearModel'))
% vars is expected to hold only the significance level
assert(isequal(fieldnames(vars), {'significance'}'))
assert(isequal(model.CoefficientNames, {'(Intercept)', 'x1'}));
assert(isequal( ...
    round(model.Coefficients.Estimate, 4), ...
    [0.2800 0.6972]' ...
));
% The data are almost exactly linear, so 1 - adjusted R^2 is tiny
assert(round(1 - model.Rsquared.Adjusted, 5, 'significant') == 3.9992e-05);
% Test predictions with a single point. The prediction function is
% called three ways -- directly, through an anonymous function handle,
% and through feval with the function name -- and every route must
% produce the same rounded results. Each failure message names the route.
x = 200;
f_handle = @(model, x, vars, params) lin_model_predict(model, x, vars, params);
f_name = "lin_model_predict";
feval_handle = @(varargin) builtin('feval', f_name, varargin{:});
predict_fcns = {@lin_model_predict, f_handle, feval_handle};
call_labels = {'direct call', 'function handle', 'feval with function name'};
for k = 1:numel(predict_fcns)
    predict_fcn = predict_fcns{k};
    [y_mean, y_sigma, y_int] = predict_fcn(model, x, vars, params);
    assert(round(y_mean, 4) == 139.7200, ...
        'y_mean mismatch (%s)', call_labels{k});
    assert(isequal(round(y_sigma, 4), 1.2926), ...
        'y_sigma mismatch (%s)', call_labels{k});
    assert(isequal(round(y_int, 4), [137.5938 141.8462]), ...
        'y_int mismatch (%s)', call_labels{k});
end
%% Test with config file
% Load configuration file
filepath = fullfile(test_dir, test_data_dir, "test_config_lin.yaml");
config = yaml.loadFile(filepath, "ConvertToArray", true);
% Expected fitted coefficients [intercept slope], one row per machine,
% in the order the machines appear in the config file
coeffs_chk = [
    20.4981 0.5101
    93.6391 0.3797
    64.7529 0.5433
    64.9354 0.5433
    67.3376 0.5391
];
machine_names = string(fieldnames(config.machines))';
% Fail loudly if the config and the reference table have drifted apart
assert(size(coeffs_chk, 1) == numel(machine_names), ...
    'coeffs_chk has %d rows but the config defines %d machines', ...
    size(coeffs_chk, 1), numel(machine_names))
% For each machine: load its training data from file, then create a
% model object by running the setup function named in the config struct
training_data = struct();
models = struct();
model_vars = struct();
for i = 1:numel(machine_names)
    machine = machine_names(i);
    filename = config.machines.(machine).trainingData;
    training_data.(machine) = readtable( ...
        fullfile(test_dir, test_data_dir, filename) ...
    );
    model_name = config.machines.(machine).model;
    model_config = config.models.(model_name);
    % Run model setup script
    [model, vars] = feval( ...
        model_config.setupFcn, ...
        training_data.(machine), ...
        model_config.params ...
    );
    % Check selected model variables and params
    assert(vars.significance == model_config.params.significance)
    % Use this command to regenerate the rows of coeffs_chk:
    % fprintf("%s\n", strjoin(string(model.Coefficients.Estimate), " "))
    assert(isequal( ...
        round(model.Coefficients.Estimate, 4), ...
        coeffs_chk(i, :)' ...
    ), 'Unexpected coefficients for %s', machine)
    % Save for use below
    models.(machine) = model;
    model_vars.(machine) = vars;
end
% Make predictions with one model over its operating range
machine = "machine_1";
op_limits = config.machines.(machine).params.op_limits;
model_name = config.machines.(machine).model;
model_config = config.models.(model_name);
% 101 evenly spaced points from op_limits(1) to op_limits(2)
x = linspace(op_limits(1), op_limits(2), 101)';
% Pass the model params struct (as given to the setup and update
% functions), not the whole model config struct
[y_mean, y_sigma, y_int] = lin_model_predict( ...
    models.(machine), ...
    x, ...
    model_vars.(machine), ...
    model_config.params ...
);
% % Plot predictions and data (uncomment to inspect the fit visually)
% figure(1); clf
% make_statdplot(y_mean, y_int(:, 1), y_int(:, 2), x, ...
%     training_data.(machine){:, "Power"}, ...
%     training_data.(machine){:, "Load"}, ...
%     "Load", "Power")
% p = get(gcf, 'Position');
% set(gcf, 'Position', [p(1:2) 320 210])
%
% Check outputs against reference values rounded to 4 decimal places.
% The reference values were printed with commands of the form
% fprintf("%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f ...\n", y_mean)
% (likewise for y_sigma, y_int(:, 1) and y_int(:, 2)).
y_mean_chk = [ ...
    49.0645 49.9011 50.7377 51.5743 52.4109 53.2475 ...
    54.0840 54.9206 55.7572 56.5938 57.4304 58.2670 ...
    59.1036 59.9402 60.7767 61.6133 62.4499 63.2865 ...
    64.1231 64.9597 65.7963 66.6329 67.4694 68.3060 ...
    69.1426 69.9792 70.8158 71.6524 72.4890 73.3256 ...
    74.1622 74.9987 75.8353 76.6719 77.5085 78.3451 ...
    79.1817 80.0183 80.8549 81.6914 82.5280 83.3646 ...
    84.2012 85.0378 85.8744 86.7110 87.5476 88.3841 ...
    89.2207 90.0573 90.8939 91.7305 92.5671 93.4037 ...
    94.2403 95.0769 95.9134 96.7500 97.5866 98.4232 ...
    99.2598 100.0964 100.9330 101.7696 102.6061 103.4427 ...
    104.2793 105.1159 105.9525 106.7891 107.6257 108.4623 ...
    109.2988 110.1354 110.9720 111.8086 112.6452 113.4818 ...
    114.3184 115.1550 115.9916 116.8281 117.6647 118.5013 ...
    119.3379 120.1745 121.0111 121.8477 122.6843 123.5208 ...
    124.3574 125.1940 126.0306 126.8672 127.7038 128.5404 ...
    129.3770 130.2136 131.0501 131.8867 132.7233 ...
]';
y_sigma_chk = [ ...
    0.1957 0.1806 0.1672 0.1559 0.1472 0.1415 ...
    0.1392 0.1406 0.1454 0.1534 0.1641 0.1770 ...
    0.1916 0.2077 0.2248 0.2429 0.2616 0.2808 ...
    0.3005 0.3205 0.3409 0.3615 0.3823 0.4032 ...
    0.4244 0.4456 0.4670 0.4884 0.5100 0.5316 ...
    0.5533 0.5750 0.5968 0.6186 0.6405 0.6624 ...
    0.6844 0.7064 0.7284 0.7504 0.7725 0.7945 ...
    0.8166 0.8387 0.8609 0.8830 0.9052 0.9273 ...
    0.9495 0.9717 0.9939 1.0161 1.0384 1.0606 ...
    1.0828 1.1051 1.1274 1.1496 1.1719 1.1942 ...
    1.2164 1.2387 1.2610 1.2833 1.3056 1.3279 ...
    1.3502 1.3726 1.3949 1.4172 1.4395 1.4619 ...
    1.4842 1.5065 1.5289 1.5512 1.5735 1.5959 ...
    1.6182 1.6406 1.6629 1.6853 1.7076 1.7300 ...
    1.7524 1.7747 1.7971 1.8195 1.8418 1.8642 ...
    1.8866 1.9089 1.9313 1.9537 1.9761 1.9984 ...
    2.0208 2.0432 2.0656 2.0880 2.1103 ...
]';
% First column of y_int: lies below y_mean in the reference values
y_int_lo_chk = [ ...
    48.7427 49.6040 50.4626 51.3178 52.1688 53.0147 ...
    53.8550 54.6894 55.5180 56.3415 57.1605 57.9759 ...
    58.7884 59.5985 60.4069 61.2139 62.0197 62.8246 ...
    63.6288 64.4325 65.2356 66.0383 66.8407 67.6428 ...
    68.4446 69.2462 70.0477 70.8490 71.6501 72.4512 ...
    73.2521 74.0529 74.8537 75.6543 76.4549 77.2555 ...
    78.0560 78.8564 79.6568 80.4571 81.2575 82.0577 ...
    82.8580 83.6582 84.4584 85.2586 86.0587 86.8588 ...
    87.6589 88.4590 89.2591 90.0591 90.8591 91.6591 ...
    92.4591 93.2591 94.0591 94.8591 95.6590 96.4590 ...
    97.2589 98.0588 98.8588 99.6587 100.4586 101.2585 ...
    102.0584 102.8583 103.6581 104.4580 105.2579 106.0577 ...
    106.8576 107.6574 108.4573 109.2571 110.0570 110.8568 ...
    111.6566 112.4564 113.2563 114.0561 114.8559 115.6557 ...
    116.4555 117.2553 118.0551 118.8549 119.6547 120.4545 ...
    121.2543 122.0541 122.8539 123.6537 124.4535 125.2532 ...
    126.0530 126.8528 127.6526 128.4523 129.2521 ...
]';
% Second column of y_int: lies above y_mean in the reference values
y_int_hi_chk = [ ...
    49.3864 50.1982 51.0127 51.8307 52.6530 53.4802 ...
    54.3131 55.1519 55.9964 56.8461 57.7003 58.5581 ...
    59.4188 60.2818 61.1466 62.0128 62.8801 63.7484 ...
    64.6174 65.4869 66.3570 67.2274 68.0982 68.9693 ...
    69.8407 70.7122 71.5839 72.4558 73.3278 74.2000 ...
    75.0722 75.9446 76.8170 77.6895 78.5621 79.4347 ...
    80.3074 81.1801 82.0529 82.9257 83.7986 84.6715 ...
    85.5444 86.4174 87.2904 88.1634 89.0364 89.9095 ...
    90.7826 91.6557 92.5288 93.4019 94.2751 95.1482 ...
    96.0214 96.8946 97.7678 98.6410 99.5142 100.3874 ...
    101.2607 102.1339 103.0072 103.8804 104.7537 105.6270 ...
    106.5003 107.3736 108.2469 109.1202 109.9935 110.8668 ...
    111.7401 112.6134 113.4868 114.3601 115.2334 116.1068 ...
    116.9801 117.8535 118.7268 119.6002 120.4736 121.3469 ...
    122.2203 123.0937 123.9670 124.8404 125.7138 126.5872 ...
    127.4606 128.3339 129.2073 130.0807 130.9541 131.8275 ...
    132.7009 133.5743 134.4477 135.3211 136.1945 ...
]';
assert(isequal(round(y_mean, 4), y_mean_chk))
assert(isequal(round(y_sigma, 4), y_sigma_chk))
assert(isequal(round(y_int(:, 1), 4), y_int_lo_chk))
assert(isequal(round(y_int(:, 2), 4), y_int_hi_chk))
% More data points
io_data = array2table([
    145.0000 101.0839
    175.0000 122.2633
    140.0000 97.6366
    205.0000 141.9694
    150.0000 104.5735
    210.0000 144.8131
    120.0000 84.4186
    75.0000 58.6758
    95.0000 69.4629
    170.0000 118.7371
], 'VariableNames', {'Load', 'Power'});
% Add one point (row 9 of io_data) to this machine's training data
training_data.(machine) = [
    training_data.(machine);
    io_data(9, :)
];
% Test the model update function with the augmented training data.
% Pass this machine's own model variables (not the `vars` left over from
% the last iteration of the setup loop, which belong to another machine)
% and store the returned variables back so the checks and predictions
% below use the updated values.
[models.(machine), model_vars.(machine)] = builtin("feval", ...
    model_config.updateFcn, ...
    models.(machine), ...
    training_data.(machine), ...
    model_vars.(machine), ...
    model_config.params);
% Check vars updated
assert(model_vars.(machine).significance == model_config.params.significance)
% The coefficients after the update differ from the initial fit
% (20.4981, 0.5101) because of the added data point
assert(isequal( ...
    round(models.(machine).Coefficients.Estimate, 4), ...
    [19.7704 0.5218]' ...
));
% Re-do predictions with the updated model. Pass the model params struct
% (as given to the setup and update functions), not the whole model
% config struct.
[y_mean, y_sigma, y_int] = lin_model_predict( ...
    models.(machine), ...
    x, ...
    model_vars.(machine), ...
    model_config.params ...
);
% % Plot predictions and data (uncomment to inspect the updated fit)
% figure(2); clf
% make_statdplot(y_mean, y_int(:, 1), y_int(:, 2), x, ...
%     training_data.(machine){:, "Power"}, ...
%     training_data.(machine){:, "Load"}, ...
%     "Load", "Power")
% p = get(gcf, 'Position');
% set(gcf, 'Position', [p(1:2) 320 210])
%
% Check outputs against reference values rounded to 4 decimal places.
% The reference values were printed with commands of the form
% fprintf("%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f ...\n", y_mean)
% (likewise for y_sigma, y_int(:, 1) and y_int(:, 2)).
y_mean_chk = [ ...
    48.9891 49.8448 50.7005 51.5562 52.4119 53.2676 ...
    54.1233 54.9789 55.8346 56.6903 57.5460 58.4017 ...
    59.2574 60.1131 60.9688 61.8245 62.6802 63.5359 ...
    64.3915 65.2472 66.1029 66.9586 67.8143 68.6700 ...
    69.5257 70.3814 71.2371 72.0928 72.9485 73.8041 ...
    74.6598 75.5155 76.3712 77.2269 78.0826 78.9383 ...
    79.7940 80.6497 81.5054 82.3610 83.2167 84.0724 ...
    84.9281 85.7838 86.6395 87.4952 88.3509 89.2066 ...
    90.0623 90.9180 91.7736 92.6293 93.4850 94.3407 ...
    95.1964 96.0521 96.9078 97.7635 98.6192 99.4749 ...
    100.3306 101.1862 102.0419 102.8976 103.7533 104.6090 ...
    105.4647 106.3204 107.1761 108.0318 108.8875 109.7432 ...
    110.5988 111.4545 112.3102 113.1659 114.0216 114.8773 ...
    115.7330 116.5887 117.4444 118.3001 119.1557 120.0114 ...
    120.8671 121.7228 122.5785 123.4342 124.2899 125.1456 ...
    126.0013 126.8570 127.7127 128.5683 129.4240 130.2797 ...
    131.1354 131.9911 132.8468 133.7025 134.5582 ...
]';
y_sigma_chk = [ ...
    0.2449 0.2321 0.2200 0.2087 0.1983 0.1888 ...
    0.1806 0.1737 0.1683 0.1646 0.1627 0.1626 ...
    0.1644 0.1679 0.1732 0.1799 0.1881 0.1974 ...
    0.2078 0.2191 0.2311 0.2438 0.2570 0.2707 ...
    0.2848 0.2992 0.3140 0.3290 0.3442 0.3596 ...
    0.3752 0.3909 0.4068 0.4228 0.4389 0.4551 ...
    0.4714 0.4877 0.5041 0.5206 0.5372 0.5538 ...
    0.5704 0.5871 0.6038 0.6206 0.6374 0.6542 ...
    0.6710 0.6879 0.7048 0.7217 0.7387 0.7557 ...
    0.7726 0.7896 0.8067 0.8237 0.8407 0.8578 ...
    0.8749 0.8920 0.9091 0.9262 0.9433 0.9604 ...
    0.9775 0.9947 1.0118 1.0290 1.0462 1.0633 ...
    1.0805 1.0977 1.1149 1.1321 1.1493 1.1665 ...
    1.1837 1.2009 1.2182 1.2354 1.2526 1.2699 ...
    1.2871 1.3044 1.3216 1.3389 1.3561 1.3734 ...
    1.3906 1.4079 1.4252 1.4424 1.4597 1.4770 ...
    1.4943 1.5115 1.5288 1.5461 1.5634 ...
]';
% First column of y_int: lies below y_mean in the reference values
y_int_lo_chk = [ ...
    48.5864 49.4630 50.3385 51.2129 52.0858 52.9570 ...
    53.8262 54.6933 55.5578 56.4196 57.2784 58.1342 ...
    58.9870 59.8369 60.6839 61.5285 62.3708 63.2111 ...
    64.0497 64.8869 65.7228 66.5577 67.3916 68.2248 ...
    69.0573 69.8892 70.7207 71.5517 72.3823 73.2127 ...
    74.0427 74.8725 75.7021 76.5315 77.3607 78.1897 ...
    79.0186 79.8474 80.6761 81.5047 82.3332 83.1616 ...
    83.9899 84.8181 85.6463 86.4744 87.3025 88.1305 ...
    88.9585 89.7864 90.6143 91.4422 92.2700 93.0978 ...
    93.9255 94.7533 95.5810 96.4086 97.2363 98.0639 ...
    98.8915 99.7191 100.5467 101.3742 102.2018 103.0293 ...
    103.8568 104.6843 105.5118 106.3392 107.1667 107.9941 ...
    108.8215 109.6490 110.4764 111.3038 112.1312 112.9585 ...
    113.7859 114.6133 115.4407 116.2680 117.0953 117.9227 ...
    118.7500 119.5773 120.4047 121.2320 122.0593 122.8866 ...
    123.7139 124.5412 125.3685 126.1958 127.0230 127.8503 ...
    128.6776 129.5049 130.3321 131.1594 131.9866 ...
]';
% Second column of y_int: lies above y_mean in the reference values
y_int_hi_chk = [ ...
    49.3919 50.2266 51.0624 51.8995 52.7380 53.5782 ...
    54.4203 55.2646 56.1115 56.9611 57.8136 58.6692 ...
    59.5278 60.3893 61.2536 62.1205 62.9895 63.8606 ...
    64.7333 65.6076 66.4830 67.3596 68.2370 69.1152 ...
    69.9941 70.8735 71.7535 72.6338 73.5146 74.3956 ...
    75.2770 76.1585 77.0403 77.9223 78.8045 79.6868 ...
    80.5693 81.4519 82.3346 83.2174 84.1003 84.9833 ...
    85.8664 86.7495 87.6327 88.5159 89.3992 90.2826 ...
    91.1660 92.0495 92.9330 93.8165 94.7001 95.5837 ...
    96.4673 97.3509 98.2346 99.1183 100.0021 100.8858 ...
    101.7696 102.6534 103.5372 104.4210 105.3049 106.1887 ...
    107.0726 107.9565 108.8404 109.7243 110.6082 111.4922 ...
    112.3761 113.2601 114.1441 115.0281 115.9120 116.7960 ...
    117.6800 118.5641 119.4481 120.3321 121.2162 122.1002 ...
    122.9842 123.8683 124.7524 125.6364 126.5205 127.4046 ...
    128.2887 129.1727 130.0568 130.9409 131.8250 132.7091 ...
    133.5933 134.4774 135.3615 136.2456 137.1297 ...
]';
assert(isequal(round(y_mean, 4), y_mean_chk))
assert(isequal(round(y_sigma, 4), y_sigma_chk))
assert(isequal(round(y_int(:, 1), 4), y_int_lo_chk))
assert(isequal(round(y_int(:, 2), 4), y_int_hi_chk))