Skip to content

Commit

Permalink
Adding the coefficient of determination and adjusted coefficient of d…
Browse files Browse the repository at this point in the history
…etermination to Basic and Window Regression Performance Evaluators. (#281)
  • Loading branch information
hmgomes authored Sep 7, 2023
1 parent f37dffd commit 73cb667
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*
*/
package moa.evaluation;

Expand All @@ -24,7 +24,6 @@
import moa.core.Measurement;

import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceData;
import com.yahoo.labs.samoa.instances.Prediction;

/**
Expand All @@ -45,11 +44,15 @@ public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject
protected double averageError;

protected double sumTarget;

protected double squareTargetError;

protected double averageTargetError;

protected double totalSumSquares;

protected double numAttributes;

@Override
public void reset() {
this.weightObserved = 0.0;
Expand All @@ -58,47 +61,67 @@ public void reset() {
this.sumTarget = 0.0;
this.averageTargetError = 0.0;
this.squareTargetError = 0.0;

this.totalSumSquares = 0.0;
this.numAttributes = 0.0;
}

@Override
public void addResult(Example<Instance> example, double[] prediction) {
Instance inst = example.getData();
Instance inst = example.getData();
if (inst.weight() > 0.0) {
if (prediction.length > 0) {
double meanTarget = this.weightObserved != 0 ?
this.sumTarget / this.weightObserved : 0.0;
double meanTarget = this.weightObserved != 0 ?
this.sumTarget / this.weightObserved : 0.0;
this.squareError += (inst.classValue() - prediction[0]) * (inst.classValue() - prediction[0]);
this.averageError += Math.abs(inst.classValue() - prediction[0]);
this.squareTargetError += (inst.classValue() - meanTarget) * (inst.classValue() - meanTarget);
this.averageTargetError += Math.abs(inst.classValue() - meanTarget);
this.sumTarget += inst.classValue();
this.weightObserved += inst.weight();
this.numAttributes = inst.numAttributes();
}
//System.out.println(inst.classValue()+", "+prediction[0]);
}
}

@Override
public Measurement[] getPerformanceMeasurements() {
return new Measurement[]{
new Measurement("classified instances",
getTotalWeightObserved()),
new Measurement("mean absolute error",
getMeanError()),
new Measurement("root mean squared error",
getSquareError()),
new Measurement("relative mean absolute error",
getRelativeMeanError()),
new Measurement("relative root mean squared error",
getRelativeSquareError())
new Measurement("classified instances",
getTotalWeightObserved()),
new Measurement("mean absolute error",
getMeanError()),
new Measurement("root mean squared error",
getSquareError()),
new Measurement("relative mean absolute error",
getRelativeMeanError()),
new Measurement("relative root mean squared error",
getRelativeSquareError()),
new Measurement("coefficient of determination",
getCoefficientOfDetermination()),
new Measurement("adjusted coefficient of determination",
getAdjustedCoefficientOfDetermination())
};
}

public double getTotalWeightObserved() {
return this.weightObserved;
}

public double getCoefficientOfDetermination() {
if(weightObserved > 0.0) {
double SSres = squareError;
double SStot = squareTargetError;

return 1 - (SSres / SStot);
}
return 0.0;
}

public double getAdjustedCoefficientOfDetermination() {
return 1 - ((1-getCoefficientOfDetermination())*(getTotalWeightObserved() - 1)) /
(getTotalWeightObserved() - numAttributes - 1);
}

public double getMeanError() {
return this.weightObserved > 0.0 ? this.averageError
/ this.weightObserved : 0.0;
Expand Down Expand Up @@ -130,18 +153,18 @@ private double getRelativeMeanError() {
//return targetMeanError > 0 ? getMeanError()/targetMeanError : 0.0;
return this.averageTargetError> 0 ?
this.averageError/this.averageTargetError : 0.0;
}
}

private double getRelativeSquareError() {
//double targetSquareError = getTargetSquareError();
//return targetSquareError > 0 ? getSquareError()/targetSquareError : 0.0;
return Math.sqrt(this.squareTargetError> 0 ?
return Math.sqrt(this.squareTargetError> 0 ?
this.squareError/this.squareTargetError : 0.0);
}

@Override
public void addResult(Example<Instance> example, Prediction prediction) {
if(prediction!=null)
addResult(example,prediction.getVotes(0));
if(prediction!=null)
addResult(example,prediction.getVotes(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*
*/
package moa.evaluation;

Expand All @@ -29,7 +29,6 @@
import moa.tasks.TaskMonitor;

import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceData;
import com.yahoo.labs.samoa.instances.Prediction;

/**
Expand All @@ -52,6 +51,14 @@ public class WindowRegressionPerformanceEvaluator extends AbstractOptionHandler

protected Estimator squareError;

protected Estimator squareTargetError;

protected Estimator sumTarget;

protected double numAttributes;

protected Estimator averageTargetError;

protected Estimator averageError;

protected int numClasses;
Expand Down Expand Up @@ -103,6 +110,9 @@ public void reset(int numClasses) {
this.weightObserved = new Estimator(this.widthOption.getValue());
this.squareError = new Estimator(this.widthOption.getValue());
this.averageError = new Estimator(this.widthOption.getValue());
this.squareTargetError = new Estimator(this.widthOption.getValue());
this.sumTarget = new Estimator(this.widthOption.getValue());
this.averageTargetError = new Estimator(this.widthOption.getValue());
this.TotalweightObserved = 0;
}

Expand All @@ -118,7 +128,17 @@ public void addResult(Example<Instance> example, double[] prediction) {
this.weightObserved.add(weight);

if (prediction.length > 0) {
double meanTarget = this.weightObserved.total() != 0 ?
this.sumTarget.total() / this.weightObserved.total() : 0.0;

this.squareError.add((inst.classValue() - prediction[0]) * (inst.classValue() - prediction[0]));

this.squareTargetError.add((inst.classValue() - meanTarget) * (inst.classValue() - meanTarget));
this.sumTarget.add(inst.classValue());
this.numAttributes = inst.numAttributes();

this.averageTargetError.add(Math.abs(inst.classValue() - meanTarget));

this.averageError.add(Math.abs(inst.classValue() - prediction[0]));
}
//System.out.println(inst.classValue()+", "+prediction[0]);
Expand All @@ -128,12 +148,52 @@ public void addResult(Example<Instance> example, double[] prediction) {
@Override
public Measurement[] getPerformanceMeasurements() {
return new Measurement[]{
new Measurement("classified instances",
getTotalWeightObserved()),
new Measurement("mean absolute error",
getMeanError()),
new Measurement("root mean squared error",
getSquareError())};
new Measurement("classified instances",
getTotalWeightObserved()),
new Measurement("mean absolute error",
getMeanError()),
new Measurement("root mean squared error",
getSquareError()),
new Measurement("relative mean absolute error",
getRelativeMeanError()),
new Measurement("relative root mean squared error",
getRelativeSquareError()),
new Measurement("coefficient of determination",
getCoefficientOfDetermination()),
new Measurement("adjusted coefficient of determination",
getAdjustedCoefficientOfDetermination())
};
}

public double getCoefficientOfDetermination() {
if(weightObserved.total() > 0.0) {
double SSres = squareError.total();
double SStot = squareTargetError.total();

return 1 - (SSres / SStot);
}
return 0.0;
}

public double getAdjustedCoefficientOfDetermination() {
return 1 - ((1-getCoefficientOfDetermination())*(getTotalWeightObserved() - 1)) /
(getTotalWeightObserved() - numAttributes - 1);
}

private double getRelativeMeanError() {
//double targetMeanError = getTargetMeanError();
//return targetMeanError > 0 ? getMeanError()/targetMeanError : 0.0;
return this.averageTargetError.total() > 0 ?
this.averageError.total() / this.averageTargetError.total() : 0.0;
// //TODO: implement!
// return -1.0;
}

private double getRelativeSquareError() {
//double targetSquareError = getTargetSquareError();
//return targetSquareError > 0 ? getSquareError()/targetSquareError : 0.0;
return Math.sqrt(this.squareTargetError.total() > 0 ?
this.squareError.total() / this.squareTargetError.total() : 0.0);
}

public double getTotalWeightObserved() {
Expand All @@ -158,18 +218,18 @@ public void getDescription(StringBuilder sb, int indent) {

@Override
public void prepareForUseImpl(TaskMonitor monitor,
ObjectRepository repository) {
ObjectRepository repository) {
}


@Override
public void addResult(Example<Instance> testInst, Prediction prediction) {
double votes[];
if(prediction==null)
votes = new double[0];
else
votes=prediction.getVotes();
addResult(testInst, votes);

}


@Override
public void addResult(Example<Instance> testInst, Prediction prediction) {
double votes[];
if(prediction==null)
votes = new double[0];
else
votes=prediction.getVotes();
addResult(testInst, votes);

}
}

0 comments on commit 73cb667

Please sign in to comment.