-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnormalBayes.hpp
28 lines (20 loc) · 949 Bytes
/
normalBayes.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#pragma once
#include <string>
//----------------------------------------------------------------------------
//----------------------------------------------------------------------------
/// @brief Trains a Normal Bayes classifier on @p dataset and saves it to disk.
///
/// The model is trained directly from the TrainData object, so OpenCV uses
/// the dataset's own variable types (a classifier needs a categorical
/// response), its active-variable index and its train/test split (only the
/// training subset is used for fitting).
///
/// @param dataset    Training data; must not be empty.
/// @param model_path File the serialized model is written to. The default,
///                   "NormalBayes.xml", is the path testNormalBayes() reads
///                   by default.
/// @throws cv::Exception if @p dataset is empty or training reports failure.
inline void trainNormalBayes(const cv::Ptr<cv::ml::TrainData> &dataset,
                             const std::string &model_path = "NormalBayes.xml")
{
    CV_Assert(!dataset.empty());

    auto normal_bayes = cv::ml::NormalBayesClassifier::create();

    // Do not re-wrap getTrainSamples()/getTrainResponses() into a new
    // TrainData via train(samples, layout, responses).
    // - If the responses are stored as CV_32F (e.g. data loaded with
    //   TrainData::loadFromCSV), TrainData::create treats them as an ordered
    //   (regression) response, so the categorical labelling is lost.
    // - The compressed-variable samples also drop var_idx, so the model would
    //   not match the full-width samples calcError() later predicts on.
    if (!normal_bayes->train(dataset))
        CV_Error(cv::Error::StsError, "NormalBayesClassifier training failed");

    normal_bayes->save(model_path);
}
//----------------------------------------------------------------------------
//----------------------------------------------------------------------------
/// @brief Loads a saved Normal Bayes model and returns its error on @p dataset.
///
/// The error is computed with StatModel::calcError over the test subset of
/// @p dataset. NOTE(review): if no train/test split was configured,
/// calcError appears to fall back to all samples; confirm for the OpenCV
/// version in use.
///
/// @param dataset    Data to evaluate; must not be empty.
/// @param model_path Serialized model to load. The default, "NormalBayes.xml",
///                   is the path trainNormalBayes() writes by default.
/// @return The value reported by calcError. For classifiers, OpenCV
///         documents this as the percentage of misclassified samples (0-100).
/// @throws cv::Exception if @p dataset is empty or the model cannot be loaded.
inline float testNormalBayes(const cv::Ptr<cv::ml::TrainData> &dataset,
                             const std::string &model_path = "NormalBayes.xml")
{
    CV_Assert(!dataset.empty());

    // load() can hand back an empty Ptr (e.g. no model node in the file,
    // depending on the OpenCV version). Dereferencing that would be undefined
    // behaviour, so fail loudly instead.
    auto normal_bayes = cv::ml::NormalBayesClassifier::load(model_path);
    if (normal_bayes.empty())
        CV_Error(cv::Error::StsObjectNotFound,
                 "cannot load NormalBayes model from '" + model_path + "'");

    // The per-sample predictions are not needed here. calcError() produces
    // them as CV_32F, and a std::vector<int32_t> output is fixed-type CV_32S,
    // which makes _OutputArray::create assert. Pass noArray() instead.
    return normal_bayes->calcError(dataset, true, cv::noArray());
}