-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEnsemble.h
56 lines (46 loc) · 1.68 KB
/
Ensemble.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/*
* This defines and trains an ensemble of neural networks.
*/
#ifndef ENSEMBLE_H
#define ENSEMBLE_H
#include <QList>
#include <QMap>
#include "Network.h"
#include "ProblemInfo.h"
/*
 * Holds a list of networks and exposes the entry points used to train them
 * and to evaluate them on a set of samples.
 *
 * NOTE(review): the class stores raw Network pointers and declares its own
 * destructor, but copy construction/assignment are left implicit. Whether
 * copying an ensemble is safe depends on the destructor definition, which is
 * not visible from this header -- confirm.
 */
class NetworkEnsemble
{
public:
    /*
     * NOTE(review): the meaning of the integer argument cannot be told from
     * this header (presumably the number of networks in the ensemble) --
     * confirm against the constructor definition.
     */
    explicit NetworkEnsemble(int);

    virtual ~NetworkEnsemble();

    /*
     * Trains the ensemble.
     *
     * trainingSamples: the samples the networks are trained on.
     * testSubset:      a subset of the test samples, used to measure a
     *                  network's performance between two epochs.
     */
    void training(QList<InputSample *> &trainingSamples, QList<InputSample *> &testSubset);

    /*
     * Evaluates performance on the given samples and prints some results on
     * the command line.
     *
     * Returns the percentage of right answers given on the set.
     */
    double test(QList<InputSample *> &samples);

private:
    QList<Network *> m_networks;
    int m_nextId; /* next available ID for a network */

    /*
     * Measures how the given network is performing on the given samples.
     *
     * Returns the percentage of wrong answers over the whole set.
     */
    double computeAverageError(Network *network, const QList<InputSample *> &samples);

    /*
     * Helpers for NSGA-II.
     *
     * NOTE(review): the exact contract of each helper (meaning of the map key,
     * which argument of paretoDominates() is tested against which, whether the
     * lists passed by value are meant to be copies) is defined in the
     * implementation file -- confirm there before relying on it.
     */
    QMap<int, QList<Network *> > computeParetoFrontRank(QList<Network *> networks);
    QList<Network *> paretoFront(QList<Network *> &networks);
    bool paretoDominates(Network *first, Network *second);
    QList<Network *> breed(QList<Network *> networks);
    QList<Network *> sortBySparsity(QList<Network *> networks);

    /*
     * Ordering predicates used to sort the network list when computing the
     * sparsity of each network.
     */
    static bool lessThanError(const Network *lhs, const Network *rhs);
    static bool lessThanComplexity(const Network *lhs, const Network *rhs);
    static bool lessThanSparsity(const Network *lhs, const Network *rhs);
};
#endif