-
Notifications
You must be signed in to change notification settings - Fork 0
/
BasicRegressionPerformanceEvaluator.java
169 lines (140 loc) · 6.1 KB
/
BasicRegressionPerformanceEvaluator.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// <---le08/02/18 Modification: additionally report the prediction error in log scale.
// MeanErrorLogScale = sum(|log10(actual_value) - log10(prediction)|) / total_observed_weight,
// where any value below 1 is treated as 1 (its log10 is taken as 0).
/*
* BasicRegressionPerformanceEvaluator.java
* Copyright (C) 2011 University of Waikato, Hamilton, New Zealand
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package moa.evaluation;
import moa.AbstractMOAObject;
import moa.core.Example;
import moa.core.Measurement;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceData;
import com.yahoo.labs.samoa.instances.Prediction;
/**
 * Regression evaluator that performs basic incremental evaluation.
 *
 * <p>For every instance with positive weight and a non-empty prediction, the
 * evaluator accumulates the following for the first prediction value:
 * <ul>
 * <li>the absolute error;</li>
 * <li>the squared error;</li>
 * <li>the log-scale (base 10) absolute error.</li>
 * </ul>
 * It also accumulates the absolute and squared errors of a baseline that
 * always predicts the mean of the targets seen before the current instance.
 * The reported means divide these sums by the total observed instance weight.
 *
 * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
 * @version $Revision: 7 $
 */
public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject
        implements RegressionPerformanceEvaluator {

    private static final long serialVersionUID = 1L;

    /** Sum of the weights of all instances that contributed a result. */
    protected double weightObserved;

    /** Sum of squared errors of the model's predictions. */
    protected double squareError;

    /** Sum of absolute errors of the model's predictions. */
    protected double averageError;

    /** Sum of absolute differences between log10(target) and log10(prediction). */
    protected double averageErrorLogScale; // <---le08/02/18

    /** Sum of the observed target values, used for the running-mean baseline. */
    protected double sumTarget;

    /** Sum of squared errors of the running-mean baseline. */
    protected double squareTargetError;

    /** Sum of absolute errors of the running-mean baseline. */
    protected double averageTargetError;

    /**
     * Clears every accumulated statistic so evaluation can start afresh.
     */
    @Override
    public void reset() {
        this.weightObserved = 0.0;
        this.squareError = 0.0;
        this.averageError = 0.0;
        // Must be cleared too: otherwise the log-scale sum would survive a reset
        // while weightObserved does not, inflating getMeanErrorLogScale().
        this.averageErrorLogScale = 0.0;
        this.sumTarget = 0.0;
        this.averageTargetError = 0.0;
        this.squareTargetError = 0.0;
    }

    /**
     * Adds the result of predicting one instance to the running statistics.
     *
     * <p>Instances with non-positive weight are ignored. Null or empty
     * prediction arrays are also ignored. Only {@code prediction[0]} is used.
     *
     * @param example    wrapper around the instance whose class value is the target
     * @param prediction predicted values; element 0 is compared with the target
     */
    @Override
    public void addResult(Example<Instance> example, double[] prediction) {
        Instance inst = example.getData();
        if (inst.weight() > 0.0 && prediction != null && prediction.length > 0) {
            double actual = inst.classValue();
            double predicted = prediction[0];
            // The baseline predicts the mean of the targets seen *before* this one.
            double meanTarget = this.weightObserved != 0
                    ? this.sumTarget / this.weightObserved : 0.0;

            // NOTE(review): errors and targets are summed unweighted, while the
            // means divide by the summed instance weights. This is only
            // consistent when every weight is 1. The behavior is kept as-is.
            double error = actual - predicted;
            this.squareError += error * error;
            this.averageError += Math.abs(error);
            this.averageErrorLogScale +=
                    Math.abs(clampedLog10(actual) - clampedLog10(predicted)); // <---le08/02/18

            double targetError = actual - meanTarget;
            this.squareTargetError += targetError * targetError;
            this.averageTargetError += Math.abs(targetError);

            this.sumTarget += actual;
            this.weightObserved += inst.weight();
        }
    }

    /**
     * Returns log10 of {@code value} when it is at least 1, and 0 otherwise.
     * Values below 1 (including zero, negatives and NaN) are therefore treated
     * as 1 in the log-scale error.
     */
    private static double clampedLog10(double value) {
        return value >= 1 ? Math.log10(value) : 0.0;
    }

    /**
     * Returns the current evaluation measurements.
     *
     * @return the measurements, in this order: classified instances, MAE,
     *         log-scale MAE, RMSE, relative MAE, relative RMSE
     */
    @Override
    public Measurement[] getPerformanceMeasurements() {
        return new Measurement[]{
            new Measurement("classified instances",
                    getTotalWeightObserved()),
            new Measurement("mean absolute error",
                    getMeanError()),
            new Measurement("mean absolute error log scale",
                    getMeanErrorLogScale()), // <---le08/02/18
            new Measurement("root mean squared error",
                    getSquareError()),
            new Measurement("relative mean absolute error",
                    getRelativeMeanError()),
            new Measurement("relative root mean squared error",
                    getRelativeSquareError())
        };
    }

    /** @return the sum of the weights of all instances that contributed a result */
    public double getTotalWeightObserved() {
        return this.weightObserved;
    }

    /** @return the mean absolute error, or 0 if nothing has been observed */
    public double getMeanError() {
        return this.weightObserved > 0.0 ? this.averageError
                / this.weightObserved : 0.0;
    }

    /** @return the mean absolute log10-scale error, or 0 if nothing has been observed */
    // <---le08/02/18
    public double getMeanErrorLogScale() {
        return this.weightObserved > 0.0 ? this.averageErrorLogScale
                / this.weightObserved : 0.0;
    }

    /** @return the root mean squared error (despite the name), or 0 if nothing has been observed */
    public double getSquareError() {
        return Math.sqrt(this.weightObserved > 0.0 ? this.squareError
                / this.weightObserved : 0.0);
    }

    /** @return the mean absolute error of the running-mean baseline, or 0 if nothing has been observed */
    public double getTargetMeanError() {
        return this.weightObserved > 0.0 ? this.averageTargetError
                / this.weightObserved : 0.0;
    }

    /** @return the root mean squared error of the running-mean baseline, or 0 if nothing has been observed */
    public double getTargetSquareError() {
        return Math.sqrt(this.weightObserved > 0.0 ? this.squareTargetError
                / this.weightObserved : 0.0);
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
        Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
                sb, indent);
    }

    /**
     * Returns the model's MAE divided by the baseline's MAE.
     *
     * <p>Both means share the same denominator, so the ratio of the raw sums is
     * used. Returns 0 when the baseline error is 0.
     */
    private double getRelativeMeanError() {
        return this.averageTargetError > 0
                ? this.averageError / this.averageTargetError : 0.0;
    }

    /**
     * Returns the model's RMSE divided by the baseline's RMSE, computed from
     * the raw squared-error sums. Returns 0 when the baseline error is 0.
     */
    private double getRelativeSquareError() {
        return Math.sqrt(this.squareTargetError > 0
                ? this.squareError / this.squareTargetError : 0.0);
    }

    /**
     * Adds a result given as a {@link Prediction}. Its votes for output 0 are
     * forwarded to {@link #addResult(Example, double[])}; a null prediction is
     * ignored.
     */
    @Override
    public void addResult(Example<Instance> example, Prediction prediction) {
        if (prediction != null) {
            addResult(example, prediction.getVotes(0));
        }
    }
}