-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainImageMethod.m
65 lines (48 loc) · 1.79 KB
/
trainImageMethod.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
function trainImageMethod(classifierName,classifierConstructor,opts)
% trainImageMethod  Train a classifier on the combined training and
% validation data using saved hyperparameters, then save it to disk.
%
%   trainImageMethod(classifierName, classifierConstructor) loads the
%   variable "hyperparams" from the file "<classifierName>Hyperparams" in
%   hyperparameterResultsDir. It builds a classifier from those
%   hyperparameters and fits it on the training data horizontally
%   concatenated with the validation data. The fitted object is saved as
%   variable "classifier" in a -v7.3 MAT-file named classifierName inside
%   finalClassifierDir. That folder is created if it does not exist.
%
%   trainImageMethod(..., 'UseParallel', true) starts a default parallel
%   pool first, if none is already running.
%
%   trainImageMethod(..., 'UseGPU', true) additionally passes
%   'UseGPU', true to classifierConstructor. Classifiers whose
%   constructor has no UseGPU argument will raise an error in that case.
%
%   Inputs:
%     classifierName        - string scalar. Used as the prefix of the
%                             hyperparameter file name and as the name of
%                             the saved classifier file.
%     classifierConstructor - function handle. Called with no arguments,
%                             it must return an object with a
%                             formatOptimizableParams method. Called with
%                             name/value pairs, it returns the classifier
%                             to train.
%
%   The directory variables hyperparameterResultsDir, trainingDataDir,
%   validationDataDir and finalClassifierDir are not defined in this
%   function. NOTE(review): presumably they are created by the
%   beehiveDataSetup script - confirm.

% SPDX-License-Identifier: BSD-3-Clause
arguments
    classifierName (1,1) string
    classifierConstructor (1,1) function_handle
    opts.UseParallel (1,1) logical = false
    opts.UseGPU (1,1) logical = false
end

% Start a parallel pool only if requested and none is running yet.
if opts.UseParallel && isempty(gcp('nocreate'))
    parpool();
end

% Set up data paths.
beehiveDataSetup;

% Load everything into structs instead of creating workspace variables
% implicitly; fullfile also works whether the *Dir values are char
% vectors or strings.
hyperparamFile = load(fullfile(hyperparameterResultsDir, ...
    classifierName + "Hyperparams"), "hyperparams");
trainingSet = load(fullfile(trainingDataDir, "trainingData"), ...
    "trainingData", "trainingImgLabels");
validationSet = load(fullfile(validationDataDir, "validationData"), ...
    "validationData", "validationImgLabels");

% Combine the training and validation data into one set for training.
data = horzcat(trainingSet.trainingData, validationSet.validationData);
labels = horzcat(trainingSet.trainingImgLabels, ...
    validationSet.validationImgLabels);
% Free the per-split copies now that the combined arrays exist.
clear("trainingSet", "validationSet");

% Assemble the classifier's hyperparameter name/value arguments.
template = classifierConstructor();
params = template.formatOptimizableParams(hyperparamFile.hyperparams);
classifierArgs = namedargs2cell(params);
if opts.UseGPU
    % NOTE: not all classifiers support GPU acceleration; the ones that
    % don't support it have no UseGPU argument, so these classifiers will
    % raise an error if UseGPU is passed in.
    classifierArgs = [classifierArgs, {'UseGPU', opts.UseGPU}];
end

% Construct and train the classifier.
classifier = classifierConstructor(classifierArgs{:});
% fit's return value is discarded. NOTE(review): this presumably relies
% on classifier being a handle object that fit updates in place - confirm.
fit(classifier, data, labels);

% Save the classifier, creating the output folder itself if needed.
% isfolder (unlike bare exist) is false for a plain file with that name.
if ~isfolder(finalClassifierDir)
    mkdir(finalClassifierDir);
end
save(fullfile(finalClassifierDir, classifierName), "classifier", "-v7.3");
end