% demo1_autoRidgeRegress.m
%
% Demo script to illustrate empirical Bayes (EB) ridge regression
%
% This is a two-step inference procedure:
% 1. Use evidence optimization (aka "maximum marginal likelihood", aka
%    "type-2 maximum likelihood") to infer the ridge parameter
% 2. Find the MAP estimate for the weights in the linear regression model,
%    given the ridge prior
%
% Model
% -----
% k ~ N(0,1/alpha * I) % prior on weights
% y | x, k ~ N(x^T k, nsevar) % linear-Gaussian observations
%
% Empirical Bayes inference:
%
% 1. [alpha_hat, nsevar_hat] = arg max P(Y | X, alpha, nsevar)
% 2. k_hat = arg max_k P(k | Y, X, alpha_hat, nsevar_hat)
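%
% For reference, the evidence maximized in step 1 and the MAP estimate in
% step 2 have standard closed forms for this linear-Gaussian model
% (notation follows the model above):
%
%    Y | X, alpha, nsevar ~ N(0, X X^T/alpha + nsevar*I)
%    k_hat = (X^T X + nsevar_hat*alpha_hat*I)^{-1} X^T Y
%
% (the sanity check in section 2 below verifies the second expression)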
% set path
addpath tools
addpath inference/

%% 1. Make a simulated dataset

nk = 100;      % number of regression coefficients
nsamps = 200;  % number of samples
signse = 3;    % stdev of added noise

% make filter
tt = (1:nk)';               % coefficient indices
k = gsmooth(randn(nk,1),3); % generate smooth weight vector

% make design matrix
Xdsgn = randn(nsamps,nk);

% simulate outputs
y = Xdsgn*k + randn(nsamps,1)*signse;

%% 2. Compute ML and ridge regression estimates
% Compute sufficient statistics (the estimates depend on the data only
% through these quantities)
dd.xx = Xdsgn'*Xdsgn;  % X'X
dd.xy = Xdsgn'*y;      % X'y
dd.yy = y'*y;          % y'y
dd.nx = nk;            % number of coefficients
dd.ny = nsamps;        % number of samples

% Compute ML estimate (ordinary least squares: (X'X)^{-1} X'y)
kml = dd.xx\dd.xy;

% Compute EB ridge regression estimate
alpha0 = 1;  % initial guess at alpha
[kridge,hprs_hat] = autoRidgeRegress_fixedpoint(dd,alpha0);
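
% Sanity check: the EB estimate should match the closed-form ridge
% solution at the inferred hyperparameters (assuming, per the header,
% that the returned weights are the MAP estimate)
kcheck = (dd.xx + hprs_hat.nsevar*hprs_hat.alpha*eye(nk))\dd.xy;
fprintf('max|kridge - kcheck| = %.2g\n', max(abs(kridge-kcheck)));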
%% 3. Compare performance & make plots
clf;
plot(tt, k,'k',tt, kml, tt, kridge);
legend('true k', 'ML', 'ridge');
fprintf('\nInferred hyperparameters:\n');
fprintf('-----------------------\n');
fprintf('alpha = %.2f\n',hprs_hat.alpha);
fprintf('nsevar = %.2f (true = %.2f)\n\n',hprs_hat.nsevar, signse^2);
% Compare errors
r2fun = @(kest)(1-sum((k-kest).^2)/sum(k.^2));
fprintf('Performance comparison:\n');
fprintf('-----------------------\n');
fprintf(' ML: R2 = %.3f\n', r2fun(kml));
fprintf('ridge: R2 = %.3f\n', r2fun(kridge));