tuneHyperparamsImageMethod.m
function tuneHyperparamsImageMethod(classifierName,classifierType,opts)
% tuneHyperparamsImageMethod Perform hyperparameter tuning for image-based methods
%
% tuneHyperparamsImageMethod(classifierName,classifierType) performs
% hyperparameter tuning for the classifier specified by classifierName, e.g.,
% CNN2d3Layer. classifierType is a handle to the classifier's class, e.g.,
% @CNN2d.
%
% Name-value options:
%   UseParallel - Use the Parallel Computing Toolbox. The hyperparameter
%                 tuning with bayesopt is not run in parallel because that
%                 would not be reproducible; however, other computations
%                 can be performed in parallel. Defaults to false.
%   UseGPU      - Train using a GPU. Defaults to false.
%   NIterations - Number of iterations to use during Bayesian optimization.
%                 Defaults to 15.
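%
% Example (a minimal sketch using the example names mentioned above):
%   tuneHyperparamsImageMethod("CNN2d3Layer",@CNN2d,UseGPU=true)
%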
% SPDX-License-Identifier: BSD-3-Clause
arguments
    classifierName (1,1) string
    classifierType (1,1) function_handle
    opts.UseParallel = false
    opts.UseGPU = false
    opts.NIterations = 15
end
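% Start a parallel pool if one is not already running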
if opts.UseParallel
    if isempty(gcp('nocreate'))
        parpool();
    end
end
% Set up data paths
beehiveDataSetup;
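% (beehiveDataSetup is assumed to define the *Dir path variables used below,
% e.g., trainingDataDir and hyperparameterResultsDir)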
% Load in the best data sampling parameters for the classifier
load(default2dCNNResultsDir + filesep + ...
classifierName + "ManualParamTraining","objective","userdata");
% Keep track of the original objective function value so that we know
% whether the default hyperparameters were best or not
originalObjective = objective;
originalHyperparameters = userdata.Classifier.Hyperparams;
disp("originalObjective = " + originalObjective);
% Load in the optimizable hyperparameters for the classifier
load(trainingDataDir + filesep + classifierName + "HyperparameterSearchValues");
% Load in the training data
load(trainingDataDir + filesep + "trainingData","trainingData",...
"trainingImgLabels");
% Load in the validation data
load(validationDataDir + filesep + "validationData",...
"validationData","validationImgLabels");
% Create the minimization function for bayesopt
minfcn = @(optimizable)validationObjFcn(classifierType,trainingData,...
trainingImgLabels,validationData,validationImgLabels,...
UseParallel=opts.UseParallel,UseGPU=opts.UseGPU,...
Static=originalHyperparameters,Optimizable=optimizable);
% Seed the random number generator for reproducibility
rng(7,'twister');
% Do the parameter search!
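% Note: bayesopt itself runs serially (UseParallel=false) so that the search
% is reproducible; see the UseParallel option description above.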
results = bayesopt(minfcn,optimizableParams,IsObjectiveDeterministic=true,...
UseParallel=false,AcquisitionFunctionName="expected-improvement-plus",...
MaxObjectiveEvaluations=opts.NIterations,Verbose=1,PlotFcn=[]);
% If the minimum objective found by bayesopt is lower than the minimum objective
% found during the data sampling grid search, use the hyperparameters associated
% with that iteration of bayesopt. Otherwise, keep the original hyperparameters.
if results.MinObjective < originalObjective
    hyperparams = table2struct(bestPoint(results));
else
    hyperparams = originalHyperparameters;
end
% Save the best hyperparameter settings
filename = classifierName + "Hyperparams";
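% Create the hyperparameter tuning results directory if it does not exist yet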
if ~isfolder(hyperparameterResultsDir)
    mkdir(trainingResultsDir,"hyperparameter-tuning");
end
save(hyperparameterResultsDir + filesep + filename,...
"hyperparams","results","-v7.3");
writeValidationResultsToTxtFile(classifierName,false,results);
end