-
Notifications
You must be signed in to change notification settings - Fork 0
/
WEKAClassifierTrainSlidingWindow.java
144 lines (112 loc) · 4.17 KB
/
WEKAClassifierTrainSlidingWindow.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
/**
* Author: Leandro L. Minku, University of Leicester, [email protected]
*
* Variant of WEKAClassifier that trains the WEKA classifier on a sliding window of the most recent examples.
* While the window is not yet full, the classifier is retrained at each time step (once at least minInstances examples are available) on all examples seen so far.
*
* The WEKAClassifier will train non-updateable WEKA classifiers as chunk-based learners.
*/
package moa.classifiers.meta;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.SamoaToWekaInstanceConverter;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Regressor;
import moa.core.Measurement;
import moa.options.WEKAClassOption;
import weka.classifiers.Classifier;
/**
 * Wraps a WEKA learner and retrains it from scratch on a sliding window of the most recent
 * training examples.
 *
 * <p>Each training call appends the new example to the window (evicting the oldest ones so the
 * window never holds more than {@code width} examples), creates a fresh WEKA learner from
 * {@link #baseLearnerOption} and builds it on the whole window. Building starts once the window
 * holds at least {@code minInstances} examples (capped at {@code width}); until the first
 * successful build, {@link #getVotesForInstance(Instance)} returns a uniform distribution.
 */
public class WEKAClassifierTrainSlidingWindow
        extends AbstractClassifier
        implements Regressor {

    private static final long serialVersionUID = 1L;

    /** Converts incoming SAMOA instances into WEKA instances. */
    protected SamoaToWekaInstanceConverter instanceConverter;

    /** WEKA learner (class name plus options) that is rebuilt on every training step. */
    public WEKAClassOption baseLearnerOption = new WEKAClassOption("baseLearner", 'l',
            "Classifier to train.", weka.classifiers.Classifier.class, "weka.classifiers.bayes.NaiveBayesUpdateable");

    /** Maximum number of most recent examples kept in the training window. */
    public IntOption widthOption = new IntOption("width",
            'w', "Size of Window for training learner.", 10, 1, Integer.MAX_VALUE);

    /**
     * Number of buffered examples required before the first build. Values larger than the
     * window width are treated as the window width, since the window can never grow past it.
     */
    public IntOption minInstancesForTraining = new IntOption("minInstances",
            'm', "Minimum number of instances to use for allowing to build a classifier before the sliding window is full. Should be smaller or equal to the window width.", 1, 1, Integer.MAX_VALUE);

    /** The current WEKA model; replaced by a fresh instance before every build. */
    protected Classifier classifier;

    /** The sliding window of training examples; created lazily from the first instance's header. */
    protected weka.core.Instances instancesBuffer;

    /** True once a model has been built successfully; false means "predict uniformly". */
    protected boolean isClassificationEnabled;

    @Override
    public String getPurposeString() {
        return "Classifier from Weka trained on a sliding window.";
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    /** Discards the model, the converter state and the training window. */
    @Override
    public void resetLearningImpl() {
        resetWekaClassifier();
        this.instanceConverter = new SamoaToWekaInstanceConverter();
        instancesBuffer = null;
        isClassificationEnabled = false;
    }

    /**
     * Replaces {@link #classifier} with a new, untrained learner built from
     * {@link #baseLearnerOption} and disables classification until the next successful build.
     * Creation failures are reported on stderr and leave the previous {@link #classifier} in place.
     */
    protected void resetWekaClassifier() {
        try {
            String[] options = weka.core.Utils.splitOptions(baseLearnerOption.getValueAsCLIString());
            createWekaClassifier(options);
            isClassificationEnabled = false;
        } catch (Exception e) {
            System.err.println("Creating a new classifier: " + e.getMessage());
        }
    }

    /**
     * Instantiates the WEKA learner named by {@code options[0]} and configures it with the
     * remaining options.
     *
     * @param options WEKA command line: class name followed by that learner's options
     * @throws IllegalArgumentException if {@code options} is null or empty
     * @throws Exception if WEKA cannot create or configure the learner
     */
    public void createWekaClassifier(String[] options) throws Exception {
        if (options == null || options.length == 0) {
            throw new IllegalArgumentException("No WEKA classifier name given in options");
        }
        String classifierName = options[0];
        String[] newoptions = options.clone();
        // Blank out the class name so it is not handed to the learner as one of its own options.
        newoptions[0] = "";
        this.classifier = weka.classifiers.AbstractClassifier.forName(classifierName, newoptions);
    }

    /**
     * Adds the example to the sliding window and, once enough examples are buffered, retrains a
     * fresh learner on the whole window. Training errors are reported on stderr.
     */
    @Override
    public void trainOnInstanceImpl(Instance samoaInstance) {
        weka.core.Instance inst = this.instanceConverter.wekaInstance(samoaInstance);
        try {
            if (instancesBuffer == null) {
                // Copy only the header (attributes, class index), never any instances the
                // converter's dataset might carry; the window is filled explicitly below.
                instancesBuffer = new weka.core.Instances(inst.dataset(), 0);
            }
            instancesBuffer.add(inst);

            // Evict the oldest examples; a loop (rather than a single removal) also keeps the
            // window correct if the width option was lowered after training started.
            int width = this.widthOption.getValue();
            while (instancesBuffer.size() > width) {
                instancesBuffer.remove(0);
            }

            // A threshold above the window width could never be reached and would silently
            // disable training, so cap it at the width.
            int minInstances = Math.min(this.minInstancesForTraining.getValue(), width);
            if (instancesBuffer.size() >= minInstances) {
                // Non-updateable learners must be rebuilt from scratch on the current window.
                resetWekaClassifier();
                classifier.buildClassifier(instancesBuffer);
                this.isClassificationEnabled = true;
            }
        } catch (Exception e) {
            System.err.println("Training: " + e.getMessage());
        }
    }

    /**
     * Returns the WEKA model's distribution for the instance, or a uniform distribution over
     * the classes while no model has been built yet. If the model throws, the error is printed
     * and an all-zero array of length {@code numClasses} is returned.
     */
    @Override
    public double[] getVotesForInstance(Instance samoaInstance) {
        weka.core.Instance inst = this.instanceConverter.wekaInstance(samoaInstance);
        int numClasses = inst.numClasses();
        double[] votes = new double[numClasses];
        if (!isClassificationEnabled) {
            for (int i = 0; i < numClasses; i++) {
                votes[i] = 1.0 / numClasses;
            }
            return votes;
        }
        try {
            votes = this.classifier.distributionForInstance(inst);
        } catch (Exception e) {
            System.err.println(e.getMessage());
        }
        return votes;
    }

    /** This wrapper reports no model measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /** Appends the WEKA model's textual description, if a learner exists; {@code indent} is unused. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        if (classifier != null) {
            out.append(classifier.toString());
        }
    }
}