forked from sjgershm/RL-models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fit_models.m
54 lines (43 loc) · 1.5 KB
/
fit_models.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
function [results, bms_results] = fit_models(data,opts)
% Fit RL models using MFIT.
%
% USAGE: [results, bms_results] = fit_models(data,opts)
%
% INPUTS:
%   data - [S x 1] structure array of data for S subjects. Each element
%          must have a field N, which sizes the default vectors below
%          (presumably the number of trials -- confirm). Optional fields:
%            .block - defaults to ones(N,1) when missing or empty
%            .go    - defaults to zeros(N,1) when missing or empty
%   opts - [M x 1] structure of model options (see set_opts.m)
%
% OUTPUTS:
%   results - [1 x M] struct array of model fits. Element m is the struct
%             returned by mfit_optimize for model m, plus:
%               .opts            - options struct returned by set_opts
%               .latents         - per-subject second output of Qlearn,
%                                  filled only when opts.latents is true
%                                  (the field exists only if at least one
%                                  model requested latents; it is empty
%                                  for models that did not)
%               .param_empirical - output of mfit_priorfit on the fitted
%                                  parameters
%             Empty ([]) when opts is empty.
%   bms_results - Bayesian model selection results from mfit_bms; only
%                 computed when a second output is requested.
%
% Sam Gershman, Nov 2015

% Fill in optional per-subject fields. Structs are passed by value, so
% this only changes the local copy handed to the fitting routines.
for s = 1:length(data)
    if ~isfield(data(s),'block') || isempty(data(s).block)
        data(s).block = ones(data(s).N,1);
    end
    if ~isfield(data(s),'go') || isempty(data(s).go)
        data(s).go = zeros(data(s).N,1);
    end
end

nModels = length(opts);
results = [];   % returned as-is when opts is empty

for m = 1:nModels
    disp(['... fitting model ',num2str(m),' out of ',num2str(nModels)]);

    % get parameter structure
    [opts1, param] = set_opts(opts(m));

    % fit model
    tic
    fun = @(x,data) Qlearn(x,data,opts1);
    R = mfit_optimize(fun,param,data);
    toc
    R.opts = opts1;

    % collect latent variables (second output of the likelihood function)
    % for each subject at the fitted parameters
    if opts1.latents
        for s = 1:length(data)
            [~,R.latents(s)] = fun(R.x(s,:),data(s));
        end
    end

    % fit empirical prior
    R.param_empirical = mfit_priorfit(R.x,param);

    % Different models can yield structs with different fields (e.g.
    % 'latents' only exists when requested). Assigning a struct with
    % different fields into a struct array errors ("Subscripted assignment
    % between dissimilar structures"), so pad both sides with empty fields.
    if m == 1
        results = R;
    else
        newFields = setdiff(fieldnames(R), fieldnames(results));
        for k = 1:numel(newFields)
            [results.(newFields{k})] = deal([]);
        end
        absentFields = setdiff(fieldnames(results), fieldnames(R));
        for k = 1:numel(absentFields)
            R.(absentFields{k}) = [];
        end
        results(m) = R;
    end
end

% Bayesian model selection
if nargout > 1
    bms_results = mfit_bms(results);
end