From dbfee07cb322e3416c92dc497d74fe3e328ea3fa Mon Sep 17 00:00:00 2001 From: Anton Lee Date: Thu, 21 Sep 2023 14:59:22 +1200 Subject: [PATCH 1/2] Add PredictionLoggerEvaluator --- .../LearningPerformanceEvaluator.java | 6 +- .../evaluation/PredictionLoggerEvaluator.java | 160 ++++++++++++++++++ .../moa/tasks/EvaluateInterleavedChunks.java | 5 + .../EvaluateInterleavedTestThenTrain.java | 5 + .../main/java/moa/tasks/EvaluateModel.java | 31 +--- .../tasks/EvaluatePeriodicHeldOutTest.java | 5 + .../java/moa/tasks/EvaluatePrequential.java | 44 ++--- .../java/moa/tasks/EvaluatePrequentialCV.java | 6 + .../moa/tasks/EvaluatePrequentialDelayed.java | 32 +--- .../tasks/EvaluatePrequentialDelayedCV.java | 6 + .../WriteConfigurationToJupyterNotebook.java | 2 - 11 files changed, 211 insertions(+), 91 deletions(-) create mode 100644 moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java diff --git a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java index a7c655be8..3337b9cf7 100644 --- a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java +++ b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java @@ -35,7 +35,7 @@ * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 7 $ */ -public interface LearningPerformanceEvaluator extends MOAObject, CapabilitiesHandler { +public interface LearningPerformanceEvaluator extends MOAObject, CapabilitiesHandler, AutoCloseable { /** * Resets this evaluator. It must be similar to @@ -66,4 +66,8 @@ default ImmutableCapabilities defineImmutableCapabilities() { return new ImmutableCapabilities(Capability.VIEW_STANDARD); } + @Override + default void close() throws Exception { + // By default an evaluator does nothing when closed. 
+    }
 }
diff --git a/moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java b/moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java
new file mode 100644
index 000000000..a3145cda6
--- /dev/null
+++ b/moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java
@@ -0,0 +1,233 @@
+package moa.evaluation;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.UncheckedIOException;
+import java.io.Writer;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Locale;
+import java.util.zip.GZIPOutputStream;
+
+import com.github.javacliparser.FileOption;
+import com.github.javacliparser.FlagOption;
+import com.yahoo.labs.samoa.instances.Instance;
+import com.yahoo.labs.samoa.instances.Prediction;
+
+import moa.capabilities.Capability;
+import moa.capabilities.ImmutableCapabilities;
+import moa.core.Example;
+import moa.core.Measurement;
+import moa.core.ObjectRepository;
+import moa.core.Utils;
+import moa.options.AbstractOptionHandler;
+import moa.options.ClassOption;
+import moa.tasks.TaskMonitor;
+
+/**
+ * Evaluator that logs every prediction to a CSV file and delegates the actual
+ * performance measurement to a wrapped {@link ClassificationPerformanceEvaluator}.
+ *
+ * <p>Each row holds the true class ({@code ?} when it is missing), the predicted
+ * class and, when the probabilities flag is set, one normalised vote per class.
+ * The output is gzip-compressed unless the uncompressed flag is set, and is only
+ * complete once {@link #close()} has been called.
+ */
+public class PredictionLoggerEvaluator extends AbstractOptionHandler
+        implements ClassificationPerformanceEvaluator {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Open CSV sink; transient because a stream handle cannot be serialized. */
+    private transient Writer writer;
+
+    /** Whether the header row has already been written to the output file. */
+    private boolean headerWritten = false;
+
+    public FileOption outputPredictionFileOption = new FileOption("output", 'o',
+            "A file to write comma separated predictions to.", null, "csv.gzip", true);
+
+    public FlagOption overwrite = new FlagOption("overwrite", 'f', "Overwrite existing file.");
+
+    public ClassOption wrappedEvaluatorOption = new ClassOption("evaluator", 'e',
+            "Classification performance evaluation method.", ClassificationPerformanceEvaluator.class,
+            "BasicClassificationPerformanceEvaluator");
+
+    public FlagOption probabilities = new FlagOption("probabilities", 'p',
+            "Log probabilities instead of raw predictions.");
+
+    public FlagOption uncompressed = new FlagOption("uncompressed", 'u',
+            "The output file should be saved uncompressed.");
+
+    /** Evaluator that receives every result and supplies the measurements. */
+    private ClassificationPerformanceEvaluator wrappedEvaluator;
+
+    @Override
+    public String getPurposeString() {
+        return "Log raw predictions and probabilities to a CSV file, and evaluate using a wrapped evaluator.";
+    }
+
+    /**
+     * Appends one CSV row for the example, then forwards the result to the
+     * wrapped evaluator.
+     *
+     * @param example    the instance that was tested
+     * @param classVotes the learner's votes, indexed by class
+     * @throws IllegalStateException if no output file is open
+     * @throws UncheckedIOException  if the row cannot be written
+     */
+    @Override
+    public void addResult(Example<Instance> example, double[] classVotes) {
+        if (writer == null) {
+            throw new IllegalStateException(
+                    "PredictionLoggerEvaluator has no open output file; was it prepared for use?");
+        }
+        Instance instance = example.getData();
+        int numClasses = instance.numClasses();
+
+        StringBuilder row = new StringBuilder();
+        // Plain appends keep the digits ASCII regardless of the default locale.
+        if (instance.classIsMissing()) {
+            row.append('?');
+        } else {
+            row.append((int) instance.classValue());
+        }
+        row.append(',').append(Utils.maxIndex(classVotes)).append(',');
+
+        if (probabilities.isSet()) {
+            double normalizingFactor = Arrays.stream(classVotes).sum();
+            if (normalizingFactor == 0) {
+                normalizingFactor = 1; // all-zero votes: avoid dividing by zero
+            }
+            for (int i = 0; i < numClasses; i++) {
+                // Classes beyond the end of the vote array get probability 0.
+                double probability = i < classVotes.length ? classVotes[i] / normalizingFactor : 0.0;
+                // Locale.ROOT: a ',' decimal separator would corrupt the CSV.
+                row.append(String.format(Locale.ROOT, "%.2f,", probability));
+            }
+        }
+        row.append('\n');
+
+        try {
+            // The header needs the class count, so it is written with the first row.
+            if (!headerWritten) {
+                writeHeader(numClasses);
+                headerWritten = true;
+            }
+            writer.write(row.toString());
+        } catch (IOException e) {
+            throw new UncheckedIOException("Unable to write to the prediction log", e);
+        }
+
+        // Pass result to wrapped evaluator
+        wrappedEvaluator.addResult(example, classVotes);
+    }
+
+    @Override
+    public void addResult(Example<Instance> testInst, Prediction prediction) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    /**
+     * Resolves the wrapped evaluator and opens the output file.
+     *
+     * @throws IllegalArgumentException if no output file was given
+     * @throws RuntimeException         if the file exists and overwriting is not enabled
+     * @throws UncheckedIOException     if the file cannot be opened
+     */
+    @Override
+    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+        wrappedEvaluator = (ClassificationPerformanceEvaluator) getPreparedClassOption(wrappedEvaluatorOption);
+        File file = outputPredictionFileOption.getFile();
+        if (file == null) {
+            throw new IllegalArgumentException("No output file given for the prediction log.");
+        }
+        if (file.exists() && !overwrite.isSet()) {
+            throw new RuntimeException(
+                    "File already exists: " + file.getAbsolutePath()
+                    + ". MOA doesn't want to overwrite it.");
+        }
+        try {
+            writer = openWriter(file);
+        } catch (IOException e) {
+            throw new UncheckedIOException("Unable to open prediction log: " + file.getAbsolutePath(), e);
+        }
+        headerWritten = false;
+    }
+
+    /**
+     * Opens the output file as a UTF-8 writer, gzip-compressed unless the
+     * uncompressed flag is set.
+     */
+    private Writer openWriter(File file) throws IOException {
+        FileOutputStream fileStream = new FileOutputStream(file);
+        try {
+            OutputStream sink = uncompressed.isSet() ? fileStream : new GZIPOutputStream(fileStream);
+            return new BufferedWriter(new OutputStreamWriter(sink, StandardCharsets.UTF_8));
+        } catch (IOException | RuntimeException e) {
+            // Do not leak the file handle if the gzip stream cannot be set up.
+            try {
+                fileStream.close();
+            } catch (IOException closeFailure) {
+                e.addSuppressed(closeFailure);
+            }
+            throw e;
+        }
+    }
+
+    /** Writes the column names; probability columns only when they are logged. */
+    private void writeHeader(int numClasses) throws IOException {
+        StringBuilder header = new StringBuilder("true_class,class_prediction,");
+        if (probabilities.isSet()) {
+            for (int i = 0; i < numClasses; i++) {
+                header.append("class_probability_").append(i).append(',');
+            }
+        }
+        header.append('\n');
+        writer.write(header.toString());
+    }
+
+    /**
+     * Flushes and closes the output file, then closes the wrapped evaluator.
+     * The file is closed at most once, however often this method is called.
+     */
+    @Override
+    public void close() throws Exception {
+        Writer openedWriter = writer;
+        writer = null;
+        try {
+            if (openedWriter != null) {
+                openedWriter.close();
+            }
+        } finally {
+            // The wrapped evaluator may hold resources of its own.
+            if (wrappedEvaluator != null) {
+                wrappedEvaluator.close();
+            }
+        }
+    }
+
+    @Override
+    public void reset() {
+        wrappedEvaluator.reset();
+    }
+
+    @Override
+    public Measurement[] getPerformanceMeasurements() {
+        return wrappedEvaluator.getPerformanceMeasurements();
+    }
+
+    @Override
+    public void getDescription(StringBuilder sb, int indent) {
+        sb.append(getPurposeString());
+    }
+
+    @Override
+    public ImmutableCapabilities defineImmutableCapabilities() {
+        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
+    }
+}
diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java
index 0b01516d8..3bbd41d57 100644
--- a/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java
+++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java
@@ -287,6 +287,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
         if (immediateResultStream != null) {
             immediateResultStream.close();
         }
+        try {
+            evaluator.close();
+        } catch (Exception ex) {
+            throw new RuntimeException("Exception closing evaluator", ex);
+        }
         return learningCurve;
     }
 
diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
index e0c69a5bd..5a4fd1e67 100644
--- a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
+++
b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java @@ -217,6 +217,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluateModel.java b/moa/src/main/java/moa/tasks/EvaluateModel.java index 334e59e78..be9a55689 100644 --- a/moa/src/main/java/moa/tasks/EvaluateModel.java +++ b/moa/src/main/java/moa/tasks/EvaluateModel.java @@ -107,35 +107,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { long instancesProcessed = 0; monitor.setCurrentActivity("Evaluating model...", -1.0); - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } while (stream.hasMoreInstances() && ((maxInstances < 0) || (instancesProcessed < maxInstances))) { Example testInst = (Example) stream.nextInstance();//.copy(); - int trueClass = (int) ((Instance) testInst.getData()).classValue(); - //testInst.setClassMissing(); double[] prediction = model.getVotesForInstance(testInst); - //evaluator.addClassificationAttempt(trueClass, prediction, testInst - // .weight()); - if (outputPredictionFile != null) { - outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," +( - ((Instance) 
testInst.getData()).classIsMissing() == true ? " ? " : trueClass)); - } evaluator.addResult(testInst, prediction); instancesProcessed++; @@ -169,8 +144,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { } } } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java b/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java index 3e8a511ae..22a7f8bbb 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java +++ b/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java @@ -285,6 +285,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequential.java b/moa/src/main/java/moa/tasks/EvaluatePrequential.java index 80034890e..d9a3180ea 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequential.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequential.java @@ -59,7 +59,9 @@ public class EvaluatePrequential extends ClassificationMainTask implements Capab @Override public String getPurposeString() { - return "Evaluates a classifier on a stream by testing then training with each example in sequence."; + return + "Evaluates a classifier on a stream by testing then training with each example in sequence." 
+ + "\n`outputPredictionFile` has been replaced with the `PredictionLoggerEvaluator`"; } private static final long serialVersionUID = 1L; @@ -97,9 +99,6 @@ public String getPurposeString() { public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true); - public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o', - "File to append output predictions to.", null, "pred", true); - //New for prequential method DEPRECATED public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000); @@ -168,23 +167,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { "Unable to open immediate result file: " + dumpFile, ex); } } - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } boolean firstDump = true; boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -194,20 +176,14 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { && ((maxInstances < 0) || (instancesProcessed < maxInstances)) && ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) { Example trainInst = stream.nextInstance(); - Example testInst = (Example) trainInst; //.copy(); - //testInst.setClassMissing(); - double[] prediction = learner.getVotesForInstance(testInst); - // Output prediction - if (outputPredictionFile 
!= null) { - int trueClass = (int) ((Instance) trainInst.getData()).classValue(); - outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," + ( - ((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass)); - } + Example testInst = (Example) trainInst; - //evaluator.addClassificationAttempt(trueClass, prediction, testInst.weight()); + double[] prediction = learner.getVotesForInstance(testInst); evaluator.addResult(testInst, prediction); + learner.trainOnInstance(trainInst); instancesProcessed++; + if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 || stream.hasMoreInstances() == false) { long evaluateTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -267,8 +243,10 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java index 75ed8c46b..567fbc5ff 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java @@ -257,6 +257,12 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + for (LearningPerformanceEvaluator evaluator : evaluators) + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java index 40768892f..a30da8e02 100644 --- 
a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java @@ -120,9 +120,6 @@ public String getPurposeString() { public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true); - public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o', - "File to append output predictions to.", null, "pred", true); - //New for prequential method DEPRECATED public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000); @@ -194,23 +191,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { "Unable to open immediate result file: " + dumpFile, ex); } } - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } boolean firstDump = true; boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -261,12 +241,6 @@ else if((this.initialWindowSizeOption.getValue() - instancesProcessed) < this.de testInstance = ((Instance) currentInst.getData()).copy(); testInst = new InstanceExample(testInstance); - // Output prediction - if (outputPredictionFile != null) { - int trueClass = (int) ((Instance) currentInst.getData()).classValue(); - outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," + ( - ((Instance) testInst.getData()).classIsMissing() == 
true ? " ? " : trueClass)); - } evaluator.addResult(testInst, prediction); if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 @@ -328,8 +302,10 @@ else if((this.initialWindowSizeOption.getValue() - instancesProcessed) < this.de if (immediateResultStream != null) { immediateResultStream.close(); } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java index 58d79d91f..2f7cee1a6 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java @@ -282,6 +282,12 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + for (LearningPerformanceEvaluator evaluator : evaluators) + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java b/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java index 5aa14e8ae..85abbe5e8 100644 --- a/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java +++ b/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java @@ -104,7 +104,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { learnerString = ((EvaluatePrequential) currentTask).learnerOption.getValueAsCLIString().replace('\\', '/'); evaluatorString = ((EvaluatePrequential) currentTask).evaluatorOption.getValueAsCLIString().replace('\\', '/'); dumpFile = ((EvaluatePrequential) currentTask).dumpFileOption.getFile(); - outputPredictionFile = 
((EvaluatePrequential) currentTask).outputPredictionFileOption.getFile(); sampleFrequency = ((EvaluatePrequential) currentTask).sampleFrequencyOption.getValue(); instanceLimit = ((EvaluatePrequential) currentTask).instanceLimitOption.getValue(); @@ -154,7 +153,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { learnerString = ((EvaluatePrequentialDelayed) currentTask).learnerOption.getValueAsCLIString().replace('\\', '/'); evaluatorString = ((EvaluatePrequentialDelayed) currentTask).evaluatorOption.getValueAsCLIString().replace('\\', '/'); dumpFile = ((EvaluatePrequentialDelayed) currentTask).dumpFileOption.getFile(); - outputPredictionFile = ((EvaluatePrequentialDelayed) currentTask).outputPredictionFileOption.getFile(); sampleFrequency = ((EvaluatePrequentialDelayed) currentTask).sampleFrequencyOption.getValue(); instanceLimit = ((EvaluatePrequentialDelayed) currentTask).instanceLimitOption.getValue(); trainOnInitialWindow = ((EvaluatePrequentialDelayed) currentTask).trainOnInitialWindowOption.isSet(); From 8b83ecc5e43ee36d0ee52b20dd8e317f5a379a17 Mon Sep 17 00:00:00 2001 From: Anton Lee Date: Thu, 21 Sep 2023 15:19:12 +1200 Subject: [PATCH 2/2] Close evaluator after every task --- .../java/moa/tasks/EvaluateConceptDrift.java | 8 +- .../moa/tasks/EvaluateModelMultiLabel.java | 5 + .../moa/tasks/EvaluateModelMultiTarget.java | 5 + .../moa/tasks/EvaluateModelRegression.java | 5 + .../tasks/EvaluatePrequentialMultiLabel.java | 5 + .../tasks/EvaluatePrequentialMultiTarget.java | 5 + ...aluatePrequentialMultiTargetSemiSuper.java | 6 +- .../tasks/EvaluatePrequentialRegression.java | 5 + .../meta/ALPrequentialEvaluationTask.java | 322 +++++++++--------- 9 files changed, 203 insertions(+), 163 deletions(-) diff --git a/moa/src/main/java/moa/tasks/EvaluateConceptDrift.java b/moa/src/main/java/moa/tasks/EvaluateConceptDrift.java index 322eb0b24..1086367cd 100644 --- a/moa/src/main/java/moa/tasks/EvaluateConceptDrift.java +++ 
b/moa/src/main/java/moa/tasks/EvaluateConceptDrift.java @@ -231,9 +231,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } - /* if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); - }*/ + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } } \ No newline at end of file diff --git a/moa/src/main/java/moa/tasks/EvaluateModelMultiLabel.java b/moa/src/main/java/moa/tasks/EvaluateModelMultiLabel.java index b1e118db7..010e9d25a 100644 --- a/moa/src/main/java/moa/tasks/EvaluateModelMultiLabel.java +++ b/moa/src/main/java/moa/tasks/EvaluateModelMultiLabel.java @@ -156,6 +156,11 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (outputPredictionResultStream != null) { outputPredictionResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return new LearningEvaluation(evaluator, model); } } diff --git a/moa/src/main/java/moa/tasks/EvaluateModelMultiTarget.java b/moa/src/main/java/moa/tasks/EvaluateModelMultiTarget.java index df5d18631..ded4f6637 100644 --- a/moa/src/main/java/moa/tasks/EvaluateModelMultiTarget.java +++ b/moa/src/main/java/moa/tasks/EvaluateModelMultiTarget.java @@ -152,6 +152,11 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (outputPredictionResultStream != null) { outputPredictionResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return new LearningEvaluation(evaluator, model); } } diff --git a/moa/src/main/java/moa/tasks/EvaluateModelRegression.java b/moa/src/main/java/moa/tasks/EvaluateModelRegression.java index 34b57232c..69dee8c18 100644 --- 
a/moa/src/main/java/moa/tasks/EvaluateModelRegression.java +++ b/moa/src/main/java/moa/tasks/EvaluateModelRegression.java @@ -152,6 +152,11 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (outputPredictionResultStream != null) { outputPredictionResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return new LearningEvaluation(evaluator, model); } } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java index 8a5a198a9..21427a7e0 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java @@ -289,6 +289,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (outputPredictionResultStream != null) { outputPredictionResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTarget.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTarget.java index 28f782665..083bc879c 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTarget.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTarget.java @@ -277,6 +277,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (outputPredictionResultStream != null) { outputPredictionResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java index 52b24b4dd..1f973bd60 100644 --- 
a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java @@ -405,7 +405,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (outputPredictionResultStream != null) { outputPredictionResultStream.close(); } - + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java index 26a3c844c..25fb1b5fd 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java @@ -272,6 +272,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (outputPredictionResultStream != null) { outputPredictionResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } } diff --git a/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java b/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java index a2813a1b6..21df6ccc7 100644 --- a/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java +++ b/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java @@ -52,90 +52,90 @@ * @version $Revision: 1 $ */ public class ALPrequentialEvaluationTask extends ALMainTask { - - private static final long serialVersionUID = 1L; - - @Override - public String getPurposeString() { - return "Perform prequential evaluation (testing, then training with" - + " each example in sequence) for an active learning" - + " classifier."; - } - - public ClassOption learnerOption = new ClassOption("learner", 'l', + + private static final long serialVersionUID = 1L; + + @Override + public String 
getPurposeString() { + return "Perform prequential evaluation (testing, then training with" + + " each example in sequence) for an active learning" + + " classifier."; + } + + public ClassOption learnerOption = new ClassOption("learner", 'l', "Learner to train.", ALClassifier.class, "moa.classifiers.active.ALRandom"); - - public ClassOption streamOption = new ClassOption("stream", 's', + + public ClassOption streamOption = new ClassOption("stream", 's', "Stream to learn from.", ExampleStream.class, "generators.RandomTreeGenerator"); - - public ClassOption evaluatorOption = new ClassOption( - "evaluator", 'e', + + public ClassOption evaluatorOption = new ClassOption( + "evaluator", 'e', "Active Learning classification performance evaluation method.", ALClassificationPerformanceEvaluator.class, "ALWindowClassificationPerformanceEvaluator"); - - public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i', + + public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i', "Maximum number of instances to test/train on (-1 = no limit).", 100000000, -1, Integer.MAX_VALUE); - - public IntOption timeLimitOption = new IntOption("timeLimit", 't', + + public IntOption timeLimitOption = new IntOption("timeLimit", 't', "Maximum number of seconds to test/train for (-1 = no limit).", -1, -1, Integer.MAX_VALUE); - - public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", + + public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f', "How many instances between samples of the learning performance.", 100000, 0, Integer.MAX_VALUE); - - public FileOption dumpFileOption = new FileOption("dumpFile", 'd', + + public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true); - - - /** - * Constructor which sets the color coding to black. 
- */ - public ALPrequentialEvaluationTask() { - this(Color.BLACK); - } - - /** - * Constructor with which a color coding can be set. - * @param colorCoding the color used by the task - */ - public ALPrequentialEvaluationTask(Color colorCoding) { - this.colorCoding = colorCoding; - } - - @Override - public Class getTaskResultType() { - return LearningCurve.class; - } - - @SuppressWarnings("unchecked") - @Override - protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { - // get stream - ExampleStream> stream = - (ExampleStream>) - getPreparedClassOption(this.streamOption); - - // initialize learner - ALClassifier learner = - (ALClassifier) getPreparedClassOption(this.learnerOption); - learner.resetLearning(); - learner.setModelContext(stream.getHeader()); - - // get evaluator + + + /** + * Constructor which sets the color coding to black. + */ + public ALPrequentialEvaluationTask() { + this(Color.BLACK); + } + + /** + * Constructor with which a color coding can be set. 
+ * @param colorCoding the color used by the task + */ + public ALPrequentialEvaluationTask(Color colorCoding) { + this.colorCoding = colorCoding; + } + + @Override + public Class getTaskResultType() { + return LearningCurve.class; + } + + @SuppressWarnings("unchecked") + @Override + protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { + // get stream + ExampleStream> stream = + (ExampleStream>) + getPreparedClassOption(this.streamOption); + + // initialize learner + ALClassifier learner = + (ALClassifier) getPreparedClassOption(this.learnerOption); + learner.resetLearning(); + learner.setModelContext(stream.getHeader()); + + // get evaluator ALClassificationPerformanceEvaluator evaluator = (ALClassificationPerformanceEvaluator) - getPreparedClassOption(this.evaluatorOption); + getPreparedClassOption(this.evaluatorOption); // initialize learning curve LearningCurve learningCurve = new LearningCurve( - "learning evaluation instances"); + "learning evaluation instances"); - // perform training and testing + // perform training and testing int maxInstances = this.instanceLimitOption.getValue(); int instancesProcessed = 0; int maxSeconds = this.timeLimitOption.getValue(); @@ -149,75 +149,75 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { File dumpFile = this.dumpFileOption.getFile(); PrintStream immediateResultStream = null; if (dumpFile != null) { - try { - if (dumpFile.exists()) { - immediateResultStream = new PrintStream( + try { + if (dumpFile.exists()) { + immediateResultStream = new PrintStream( new FileOutputStream(dumpFile, true), true); } else { immediateResultStream = new PrintStream( new FileOutputStream(dumpFile), true); } - } catch (Exception ex) { + } catch (Exception ex) { throw new RuntimeException( - "Unable to open immediate result file: " + dumpFile, ex); + "Unable to open immediate result file: " + dumpFile, ex); } } monitor.setCurrentActivity("Evaluating learner...", -1.0); while 
(stream.hasMoreInstances() - && ((maxInstances < 0) - || (instancesProcessed < maxInstances)) - && ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) + && ((maxInstances < 0) + || (instancesProcessed < maxInstances)) + && ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) { - Example trainInst = stream.nextInstance(); - Example testInst = trainInst; - - // predict class for instance - double[] prediction = learner.getVotesForInstance(testInst); - evaluator.addResult(testInst, prediction); - - // train on instance - learner.trainOnInstance(trainInst); - - // check if label was acquired - int labelAcquired = learner.getLastLabelAcqReport(); - evaluator.doLabelAcqReport(trainInst, labelAcquired); - - instancesProcessed++; - - // update learning curve - if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 - || !stream.hasMoreInstances()) - { - long evaluateTime = - TimingUtils.getNanoCPUTimeOfCurrentThread(); + Example trainInst = stream.nextInstance(); + Example testInst = trainInst; + + // predict class for instance + double[] prediction = learner.getVotesForInstance(testInst); + evaluator.addResult(testInst, prediction); + + // train on instance + learner.trainOnInstance(trainInst); + + // check if label was acquired + int labelAcquired = learner.getLastLabelAcqReport(); + evaluator.doLabelAcqReport(trainInst, labelAcquired); + + instancesProcessed++; + + // update learning curve + if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 + || !stream.hasMoreInstances()) + { + long evaluateTime = + TimingUtils.getNanoCPUTimeOfCurrentThread(); double time = TimingUtils.nanoTimeToSeconds( - evaluateTime - evaluateStartTime); + evaluateTime - evaluateStartTime); double timeIncrement = TimingUtils.nanoTimeToSeconds( - evaluateTime - lastEvaluateStartTime); + evaluateTime - lastEvaluateStartTime); double RAMHoursIncrement = - learner.measureByteSize() / (1024.0 * 1024.0 * 1024.0); //GBs + learner.measureByteSize() / (1024.0 * 1024.0 * 
1024.0); //GBs RAMHoursIncrement *= (timeIncrement / 3600.0); //Hours RAMHours += RAMHoursIncrement; lastEvaluateStartTime = evaluateTime; - - learningCurve.insertEntry(new LearningEvaluation( - new Measurement[]{ - new Measurement( - "learning evaluation instances", - instancesProcessed), - new Measurement( - "evaluation time (" - + (preciseCPUTiming ? "cpu " - : "") + "seconds)", - time), - new Measurement( - "model cost (RAM-Hours)", - RAMHours), - }, - evaluator, learner)); - - if (immediateResultStream != null) { + + learningCurve.insertEntry(new LearningEvaluation( + new Measurement[]{ + new Measurement( + "learning evaluation instances", + instancesProcessed), + new Measurement( + "evaluation time (" + + (preciseCPUTiming ? "cpu " + : "") + "seconds)", + time), + new Measurement( + "model cost (RAM-Hours)", + RAMHours), + }, + evaluator, learner)); + + if (immediateResultStream != null) { if (firstDump) { immediateResultStream.println(learningCurve.headerToString()); firstDump = false; @@ -225,57 +225,61 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1)); immediateResultStream.flush(); } - } - - // update monitor - if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0 && learningCurve.numEntries() > 0) { - if (monitor.taskShouldAbort()) { + } + + // update monitor + if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0 && learningCurve.numEntries() > 0) { + if (monitor.taskShouldAbort()) { return null; } - - long estimatedRemainingInstances = - stream.estimatedRemainingInstances(); - - if (maxInstances > 0) { - long maxRemaining = maxInstances - instancesProcessed; - if ((estimatedRemainingInstances < 0 || estimatedRemainingInstances == 0) - || (maxRemaining < estimatedRemainingInstances)) - { - estimatedRemainingInstances = maxRemaining; - } - } - - - // calculate completion fraction - double fractionComplete = 
(double) instancesProcessed / - (instancesProcessed + estimatedRemainingInstances); - monitor.setCurrentActivityFractionComplete( - estimatedRemainingInstances < 0 ? - -1.0 : fractionComplete); - - - // TODO currently the preview is sent after each instance - // should be changed later on - if (monitor.resultPreviewRequested() || isSubtask()) { - monitor.setLatestResultPreview(new PreviewCollectionLearningCurveWrapper((LearningCurve)learningCurve.copy(), this.getClass())); + + long estimatedRemainingInstances = + stream.estimatedRemainingInstances(); + + if (maxInstances > 0) { + long maxRemaining = maxInstances - instancesProcessed; + if ((estimatedRemainingInstances < 0 || estimatedRemainingInstances == 0) + || (maxRemaining < estimatedRemainingInstances)) + { + estimatedRemainingInstances = maxRemaining; + } + } + + + // calculate completion fraction + double fractionComplete = (double) instancesProcessed / + (instancesProcessed + estimatedRemainingInstances); + monitor.setCurrentActivityFractionComplete( + estimatedRemainingInstances < 0 ? 
+ -1.0 : fractionComplete); + + + // TODO currently the preview is sent after each instance + // should be changed later on + if (monitor.resultPreviewRequested() || isSubtask()) { + monitor.setLatestResultPreview(new PreviewCollectionLearningCurveWrapper((LearningCurve)learningCurve.copy(), this.getClass())); } - - // update time measurement - secondsElapsed = (int) TimingUtils.nanoTimeToSeconds( - TimingUtils.getNanoCPUTimeOfCurrentThread() + + // update time measurement + secondsElapsed = (int) TimingUtils.nanoTimeToSeconds( + TimingUtils.getNanoCPUTimeOfCurrentThread() - evaluateStartTime); - } + } } if (immediateResultStream != null) { immediateResultStream.close(); } - - return new PreviewCollectionLearningCurveWrapper(learningCurve, this.getClass()); - } - - @Override - public List getSubtaskThreads() { - return new ArrayList(); - } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } + return new PreviewCollectionLearningCurveWrapper(learningCurve, this.getClass()); + } + + @Override + public List getSubtaskThreads() { + return new ArrayList(); + } }