diff --git a/src/core/network/NeuralNetwork.java b/src/core/network/NeuralNetwork.java
index a7d41617..04d83503 100644
--- a/src/core/network/NeuralNetwork.java
+++ b/src/core/network/NeuralNetwork.java
@@ -5,7 +5,10 @@
package core.network;
-import core.layer.*;
+import core.layer.AbstractLayer;
+import core.layer.InputLayer;
+import core.layer.NeuralNetworkLayer;
+import core.layer.OutputLayer;
import core.metrics.ClassificationMetric;
import core.metrics.Metric;
import core.metrics.RegressionMetric;
@@ -24,7 +27,7 @@
import java.util.TreeMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@@ -35,7 +38,7 @@
* Can support multiple layer of different types including regularization, normalization and optimization methods.
*
*/
-public class NeuralNetwork implements Runnable, Serializable {
+public class NeuralNetwork implements Serializable {
@Serial
private static final long serialVersionUID = -1075977720550636471L;
@@ -69,6 +72,12 @@ private enum ExecutionState {
*/
private transient Condition executeLockCondition;
+ /**
+ * Lock for synchronizing neural network thread operations.
+ *
+ */
+ private transient Lock completeLock;
+
/**
* Lock-condition for synchronizing completion of procedure execution and shift to idle state.
*
@@ -94,29 +103,17 @@ private enum ExecutionState {
private transient boolean stopExecution;
/**
- * Thread pool for neural network.
+ * Network thread pool.
*
*/
private transient ExecutorService networkThreadPool;
/**
- * Thread pool for neural network layers.
+ * Layer thread pool.
*
*/
private transient ExecutorService layerThreadPool;
- /**
- * Neural network execution thread.
- *
- */
- private transient Thread neuralNetworkThread;
-
- /**
- * Future of the neural network thread.
- *
- */
- private transient Future> future;
-
/**
* Name of neural network instance.
*
@@ -187,7 +184,7 @@ private enum ExecutionState {
* Structure containing prediction input sequence.
*
*/
- private final TreeMap predictInputs = new TreeMap<>();
+ private transient TreeMap predictInputs;
/**
* Flag is neural network and it's layers are to be reset prior training phase.
@@ -199,7 +196,7 @@ private enum ExecutionState {
* Count of total neural network training iterations.
*
*/
- private int totalIterations = 0;
+ private int totalTrainingIterations = 0;
/**
* Total training time of neural network in nanoseconds.
@@ -207,6 +204,12 @@ private enum ExecutionState {
*/
private long trainingTime = 0;
+ /**
+ * Total validation time of neural network in nanoseconds.
+ *
+ */
+ private long validationTime = 0;
+
/**
* Length of automatic validation cycle in iterations.
*
@@ -266,18 +269,6 @@ public NeuralNetwork(NeuralNetworkConfiguration neuralNetworkConfiguration) thro
build(neuralNetworkConfiguration);
}
- /**
- * Constructor for neural network.
- *
- * @param neuralNetworkConfiguration neural network configuration.
- * @param neuralNetworkName name for neural network instance.
- * @throws NeuralNetworkException thrown if initialization of layer fails or neural network is already built.
- */
- public NeuralNetwork(NeuralNetworkConfiguration neuralNetworkConfiguration, String neuralNetworkName) throws NeuralNetworkException {
- this.neuralNetworkName = neuralNetworkName;
- build(neuralNetworkConfiguration);
- }
-
/**
* Sets name for neural network instance.
*
@@ -451,6 +442,8 @@ public void removeLastHiddenLayer() throws NeuralNetworkException {
hiddenLayers.remove(lastLayerIndex);
neuralNetworkLayers.remove(lastNeuralNetworkLayerIndex);
+
+ if (persistence != null) persistence = persistence.reference(this);
}
/**
@@ -502,15 +495,13 @@ public Persistence getPersistence() {
* @throws DynamicParamException throws exception if parameter (params) setting fails.
*/
public void start() throws NeuralNetworkException, MatrixException, DynamicParamException {
- if (neuralNetworkLayers.isEmpty()) throw new NeuralNetworkException("Neural network is not built.");
checkStarted();
+ if (neuralNetworkLayers.isEmpty()) throw new NeuralNetworkException("Neural network is not built.");
- trainingMetrics.clear();
- for (Integer outputLayerIndex : getOutputLayers().keySet()) trainingMetrics.put(outputLayerIndex, new SingleRegressionMetric(showTrainingMetrics));
- for (Integer outputLayerIndex : getOutputLayers().keySet()) if (validationMetrics.get(outputLayerIndex) != null) validationMetrics.put(outputLayerIndex, validationMetrics.get(outputLayerIndex).reference());
-
- if (!earlyStoppingMap.isEmpty()) {
- for (Integer outputLayerIndex : getOutputLayers().keySet()) {
+ for (Integer outputLayerIndex : getOutputLayers().keySet()) {
+ trainingMetrics.put(outputLayerIndex, new SingleRegressionMetric(showTrainingMetrics));
+ if (validationMetrics.get(outputLayerIndex) != null) validationMetrics.put(outputLayerIndex, validationMetrics.get(outputLayerIndex).reference());
+ if (!earlyStoppingMap.isEmpty()) {
earlyStoppingMap.get(outputLayerIndex).setTrainingMetric(trainingMetrics.get(outputLayerIndex));
earlyStoppingMap.get(outputLayerIndex).setValidationMetric(validationMetrics.get(outputLayerIndex));
}
@@ -518,22 +509,17 @@ public void start() throws NeuralNetworkException, MatrixException, DynamicParam
executeLock = new ReentrantLock();
executeLockCondition = executeLock.newCondition();
- completeLockCondition = executeLock.newCondition();
+ completeLock = new ReentrantLock();
+ completeLockCondition = completeLock.newCondition();
executionState = ExecutionState.IDLE;
-
stopLock = new ReentrantLock();
stopExecution = false;
- neuralNetworkThread = new Thread(this);
- neuralNetworkThread.setName("NeuralNetwork" + (neuralNetworkName != null ? " (" + neuralNetworkName + ")" : ""));
-
- layerThreadPool = Executors.newWorkStealingPool();
-
- if (networkThreadPool != null) future = networkThreadPool.submit(neuralNetworkThread);
- else neuralNetworkThread.start();
-
- for (InputLayer inputLayer : inputLayers.values()) inputLayer.start(layerThreadPool);
+ networkThreadPool = Executors.newSingleThreadExecutor();
+ executeLayer(networkThreadPool);
+ layerThreadPool = Executors.newCachedThreadPool();
+ for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.start(layerThreadPool);
}
/**
@@ -543,32 +529,72 @@ public void start() throws NeuralNetworkException, MatrixException, DynamicParam
public void stop() {
if (!isStarted()) return;
waitToComplete();
- executeLock.lock();
- for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.stop();
- executionState = ExecutionState.TERMINATED;
- executeLockCondition.signal();
- executeLock.unlock();
+ nextState(ExecutionState.TERMINATED);
+ for (NeuralNetworkLayer neuralNetworkLayer : inputLayers.values()) neuralNetworkLayer.stop();
- if (layerThreadPool == null) {
- try {
- neuralNetworkThread.join();
- } catch (InterruptedException e) {
- e.printStackTrace();
+ try {
+ layerThreadPool.shutdownNow();
+ networkThreadPool.shutdownNow();
+ if (!layerThreadPool.awaitTermination(10, TimeUnit.SECONDS) || !networkThreadPool.awaitTermination(10, TimeUnit.SECONDS)) {
+ System.out.println("Failed to shut down neural network.");
}
}
- else {
- layerThreadPool.shutdownNow();
- if (future != null) future.cancel(true);
- if (networkThreadPool != null) networkThreadPool.shutdownNow();
+ catch (InterruptedException ignored) {
}
+ }
+
+ /**
+ * Executes layer.
+ *
+ * @param executorService executor service
+ * @throws RuntimeException throws runtime exception in case any exception happens.
+ */
+ private void executeLayer(ExecutorService executorService) throws RuntimeException {
+ executorService.execute(() -> {
+ try {
+ while (!executeLayerOperation()) {}
+ } catch (Exception exception) {
+ throw new RuntimeException(exception);
+ }
+ });
+ }
- executeLock = null;
- executeLockCondition = null;
- completeLockCondition = null;
- stopLock = null;
- neuralNetworkThread = null;
- layerThreadPool = null;
- networkThreadPool = null;
+ /**
+ * Thread run function.
+ * Executes given neural network procedures and synchronizes their execution via neural network thread execution lock.
+ *
+ * @return return true if layer has been terminated otherwise returns true.
+ * @throws MatrixException throws exception if matrix operation fails.
+ * @throws NeuralNetworkException throws exception if neural network operation fails.
+ * @throws IOException throws exception if neural network persistence operation fails.
+ * @throws DynamicParamException throws exception if parameter (params) setting fails.
+ */
+ private boolean executeLayerOperation() throws MatrixException, NeuralNetworkException, DynamicParamException, IOException {
+ try {
+ executeLock.lock();
+ switch (executionState) {
+ case TRAIN -> {
+ trainIterations();
+ complete();
+ }
+ case VALIDATE -> {
+ validateInput();
+ complete();
+ }
+ case PREDICT -> {
+ predictInput();
+ complete();
+ }
+ case TERMINATED -> {
+ complete();
+ return true;
+ }
+ }
+ }
+ finally {
+ executeLock.unlock();
+ }
+ return false;
}
/**
@@ -577,7 +603,7 @@ public void stop() {
* @return returns true if neural network is started otherwise false.
*/
public boolean isStarted() {
- return networkThreadPool != null ? future != null && (!future.isDone()) : neuralNetworkThread != null && (neuralNetworkThread.getState() != Thread.State.NEW);
+ return networkThreadPool != null && !networkThreadPool.isTerminated();
}
/**
@@ -598,53 +624,50 @@ private void checkNotStarted() throws NeuralNetworkException {
if (!isStarted()) throw new NeuralNetworkException("Neural network is not started");
}
- /**
- * Aborts execution of neural network.
- * Execution is aborted after execution of last single operation is completed.
- * Useful when neural network is executing multiple training iterations.
- *
- */
- public void abortExecution() {
- if (!isStarted()) return;
- stopLock.lock();
- stopExecution = true;
- stopLock.unlock();
- }
-
/**
* Sets neural network into completed state and makes state to idle state.
*
*/
- private void stateCompleted() {
- executeLock.lock();
- executionState = ExecutionState.IDLE;
- completeLockCondition.signal();
- executeLock.unlock();
+ private void complete() {
+ try {
+ completeLock.lock();
+ executionState = ExecutionState.IDLE;
+ completeLockCondition.signal();
+ }
+ finally {
+ completeLock.unlock();
+ }
}
/**
- * Checks if neural network is executing (processing).
+ * Waits for neural network to finalize it execution (processing).
*
- * @return true if neural network is executing otherwise false.
*/
- public boolean isProcessing() {
- if (!isStarted()) return false;
- executeLock.lock();
- boolean isProcessing;
- isProcessing = !(executionState == ExecutionState.IDLE || executionState == ExecutionState.TERMINATED);
- executeLock.unlock();
- return isProcessing;
+ public void waitToComplete() {
+ if (!isStarted()) return;
+ try {
+ completeLock.lock();
+ while (executionState != ExecutionState.IDLE) completeLockCondition.await();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ completeLock.unlock();
+ }
}
/**
- * Waits for neural network to finalize it execution (processing).
+ * Stops execution.
*
*/
- public void waitToComplete() {
- if (!isStarted()) return;
- executeLock.lock();
- if (isProcessing()) completeLockCondition.awaitUninterruptibly();
- executeLock.unlock();
+ public void stopExecution() {
+ try {
+ stopLock.lock();
+ stopExecution = true;
+ }
+ finally {
+ stopLock.unlock();
+ }
}
/**
@@ -653,14 +676,14 @@ public void waitToComplete() {
* @return true if execution is stopped otherwise false.
*/
private boolean stoppedExecution() {
- stopLock.lock();
- boolean stopped = false;
- if (stopExecution) {
- stopped = true;
+ try {
+ stopLock.lock();
+ return stopExecution;
+ }
+ finally {
stopExecution = false;
+ stopLock.unlock();
}
- stopLock.unlock();
- return stopped;
}
/**
@@ -669,9 +692,14 @@ private boolean stoppedExecution() {
* @param executionState next state for neural network.
*/
private void nextState(ExecutionState executionState) {
- this.executionState = executionState;
- executeLockCondition.signal();
- executeLock.unlock();
+ try {
+ executeLock.lock();
+ this.executionState = executionState;
+ executeLockCondition.signal();
+ }
+ finally {
+ executeLock.unlock();
+ }
}
/**
@@ -682,55 +710,30 @@ private void nextState(ExecutionState executionState) {
* @throws NeuralNetworkException throws exception if setting of training sample sets fail.
*/
public void setTrainingData(Sampler trainingSampler) throws NeuralNetworkException {
- checkNotStarted();
waitToComplete();
- executeLock.lock();
- if (trainingSampler == null) {
- executeLock.unlock();
- throw new NeuralNetworkException("Training sampler is not set.");
- }
+ if (trainingSampler == null) throw new NeuralNetworkException("Training sampler is not set.");
this.trainingSampler = trainingSampler;
- executeLock.unlock();
}
/**
* Sets early stopping conditions.
*
- * @param earlyStoppingMap early stopping instances.
+ * @param newEarlyStoppingMap early stopping instances.
* @throws NeuralNetworkException throws exception if early stopping is not defined.
*/
- public void setTrainingEarlyStopping(TreeMap earlyStoppingMap) throws NeuralNetworkException {
+ public void setTrainingEarlyStopping(TreeMap newEarlyStoppingMap) throws NeuralNetworkException {
waitToComplete();
- if (earlyStoppingMap == null) throw new NeuralNetworkException("Early stopping is not defined.");
- this.earlyStoppingMap.clear();
- this.earlyStoppingMap.putAll(earlyStoppingMap);
- for (Integer earlyStoppingIndex : this.earlyStoppingMap.keySet()) {
+ if (newEarlyStoppingMap == null) throw new NeuralNetworkException("Early stopping is not defined.");
+ earlyStoppingMap.clear();
+ earlyStoppingMap.putAll(newEarlyStoppingMap);
+ for (Integer earlyStoppingIndex : newEarlyStoppingMap.keySet()) {
earlyStoppingMap.get(earlyStoppingIndex).setTrainingMetric(trainingMetrics.get(earlyStoppingIndex));
+ }
+ for (Integer earlyStoppingIndex : newEarlyStoppingMap.keySet()) {
earlyStoppingMap.get(earlyStoppingIndex).setValidationMetric(validationMetrics.get(earlyStoppingIndex));
}
}
- /**
- * Sets auto validation on.
- *
- * @param autoValidationCycle validation cycle in iterations.
- * @throws NeuralNetworkException throws exception if number of auto validation cycles are below 1.
- */
- public void setAutoValidate(int autoValidationCycle) throws NeuralNetworkException {
- waitToComplete();
- if (autoValidationCycle < 1) throw new NeuralNetworkException("Auto validation cycle size must be at least 1.");
- this.autoValidationCycle = autoValidationCycle;
- }
-
- /**
- * Unsets auto validation.
- *
- */
- public void unsetAutoValidate() {
- waitToComplete();
- autoValidationCycle = 0;
- }
-
/**
* Trains neural network.
*
@@ -797,7 +800,6 @@ public void train(boolean reset, boolean waitToComplete) throws NeuralNetworkExc
public void train(Sampler trainingSampler, boolean reset, boolean waitToComplete) throws NeuralNetworkException {
checkNotStarted();
waitToComplete();
- executeLock.lock();
if (trainingSampler != null) this.trainingSampler = trainingSampler;
if (this.trainingSampler == null) throw new NeuralNetworkException("Training sampler is not set.");
this.reset = reset;
@@ -805,6 +807,130 @@ public void train(Sampler trainingSampler, boolean reset, boolean waitToComplete
if (waitToComplete) waitToComplete();
}
+ /**
+ * Trains neural network with defined number of iterations.
+ *
+ * @throws MatrixException throws exception if matrix operation fails.
+ * @throws IOException throws exception if neural network persistence operation fails.
+ * @throws NeuralNetworkException throws exception if neural network training fails.
+ * @throws DynamicParamException throws exception if parameter (params) setting fails.
+ */
+ private void trainIterations() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException {
+ trainingSampler.reset();
+ int numberOfIterations = trainingSampler.getNumberOfIterations();
+ for (int iteration = 0; iteration < numberOfIterations; iteration++) {
+ trainIteration();
+ if (stoppedExecution()) return;
+ if (!earlyStoppingMap.isEmpty()) {
+ boolean stopTraining = true;
+ for (EarlyStopping earlyStopping : earlyStoppingMap.values()) if (!earlyStopping.stopTraining()) stopTraining = false;
+ if (stopTraining) return;
+ }
+ }
+ }
+
+ /**
+ * Trains single neural network iteration.
+ *
+ * @throws MatrixException throws exception if matrix operation fails.
+ * @throws IOException throws exception if neural network persistence operation fails.
+ * @throws NeuralNetworkException throws exception if neural network training fails.
+ * @throws DynamicParamException throws exception if parameter (params) setting fails.
+ */
+ private void trainIteration() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException {
+ totalTrainingIterations++;
+ long trainingStartTime = System.nanoTime();
+ for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) if (reset) neuralNetworkLayer.resetOptimizer();
+ TreeMap inputSequences = new TreeMap<>();
+ TreeMap outputSequences = new TreeMap<>();
+ trainingSampler.getSamples(inputSequences, outputSequences);
+ for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setTargets(outputSequences.get(entry.getKey()));
+ for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().train(inputSequences.get(entry.getKey()));
+ for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().backward();
+ for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().update();
+ long trainingEndTime = System.nanoTime();
+ trainingTime += trainingEndTime - trainingStartTime;
+ for (Map.Entry entry : trainingMetrics.entrySet()) entry.getValue().report(getOutputLayers().get(entry.getKey()).getTotalError());
+ if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateTrainingCondition(totalTrainingIterations);
+ if (autoValidationCycle > 0) {
+ autoValidationCount++;
+ if (autoValidationCount >= autoValidationCycle) {
+ long validationStartTime = System.nanoTime();
+ validateInput();
+ long validationEndTime = System.nanoTime();
+ validationTime += validationEndTime - validationStartTime;
+ if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateValidationCondition(totalTrainingIterations);
+ autoValidationCount = 0;
+ }
+ }
+ if (verboseTraining) verboseTrainingStatus();
+ if (persistence != null) persistence.cycle();
+ }
+
+ /**
+ * Verboses (prints to console) neural network training status.
+ * Prints number of iteration, neural network training time and training error.
+ *
+ */
+ private void verboseTrainingStatus() {
+ StringBuilder meanSquaredError = new StringBuilder("[ ");
+ for (SingleRegressionMetric trainingMetric : trainingMetrics.values()) {
+ meanSquaredError.append(String.format("%.4f", trainingMetric.getLastAbsoluteError())).append(" ");
+ }
+ meanSquaredError.append("]");
+ if (totalTrainingIterations % verboseCycle == 0) System.out.println((neuralNetworkName != null ? neuralNetworkName + ": " : "") + "Training error (iteration #" + totalTrainingIterations +"): " + meanSquaredError + ", Training time: " + String.format("%.3f", trainingTime / Math.pow(10, 9)) + "s" + (autoValidationCycle > 0 ? ", Validation time: " + String.format("%.3f", validationTime / Math.pow(10, 9)) + "s" : ""));
+ }
+
+ /**
+ * Returns output of neural network (output layer).
+ *
+ * @return output of neural network.
+ */
+ public TreeMap getOutput() {
+ waitToComplete();
+ TreeMap outputs = new TreeMap<>();
+ for (Map.Entry entry : getOutputLayers().entrySet()) outputs.put(entry.getKey(), entry.getValue().getLayerOutputs());
+ return outputs;
+ }
+
+ /**
+ * Returns neural network output error.
+ *
+ * @throws DynamicParamException throws exception if parameter (params) setting fails.
+ * @throws MatrixException throws exception if matrix operation fails.
+ * @return neural network output error.
+ */
+ public TreeMap getOutputError() throws MatrixException, DynamicParamException {
+ waitToComplete();
+ TreeMap outputErrors = new TreeMap<>();
+ for (Map.Entry entry : getOutputLayers().entrySet()) outputErrors.put(entry.getKey(), entry.getValue().getTotalError());
+ return outputErrors;
+ }
+
+ /**
+ * Sets importance sampling weights to output layer.
+ *
+ * @param importanceSamplingWeights importance sampling weights
+ * @throws NeuralNetworkException throws exception if neural network operation fails.
+ */
+ public void setImportanceSamplingWeights(TreeMap> importanceSamplingWeights) throws NeuralNetworkException {
+ checkNotStarted();
+ waitToComplete();
+ executeLock.lock();
+ for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setImportanceSamplingWeights(importanceSamplingWeights.get(entry.getKey()));
+ executeLock.unlock();
+ }
+
+ /**
+ * Sets reset flag for procedure expression dependencies.
+ *
+ * @param resetDependencies if true procedure expression dependencies are reset otherwise false.
+ */
+ public void resetDependencies(boolean resetDependencies) {
+ waitToComplete();
+ for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.resetDependencies(resetDependencies);
+ }
+
/**
* Verboses (prints to console) neural network training progress.
* Print information of neural network training iteration count, training time and training error.
@@ -847,100 +973,46 @@ public long getTrainingTimeInSeconds() {
}
/**
- * Sets reset flag for procedure expression dependencies.
+ * Returns training metrics instance.
*
- * @param resetDependencies if true procedure expression dependencies are reset otherwise false.
+ * @return training metrics instance.
*/
- public void resetDependencies(boolean resetDependencies) {
+ public TreeMap getTrainingMetrics() {
waitToComplete();
- for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.resetDependencies(resetDependencies);
+ return trainingMetrics;
}
/**
- * Predicts values based on current test set inputs.
+ * Sets if training metrics is shown.
*
- * @return predicted values (neural network output).
- * @throws NeuralNetworkException throws exception if prediction fails.
+ * @param showTrainingMetrics if true training metrics is shown otherwise not.
+ * @throws NeuralNetworkException throws exception if parameter is attempted to be set when neural network is already started.
*/
- public TreeMap predict() throws NeuralNetworkException {
- return predict(null, true);
+ public void setShowTrainingMetrics(boolean showTrainingMetrics) throws NeuralNetworkException {
+ if(isStarted()) throw new NeuralNetworkException("Training metrics can be only enabled / disabled when neural network is not started.");
+ this.showTrainingMetrics = showTrainingMetrics;
}
/**
- * Predicts values based on given input.
- *
- * @param inputs inputs for prediction.
- * @return predicted values (neural network outputs).
- * @throws NeuralNetworkException throws exception if prediction fails.
- */
- public TreeMap predictMatrix(TreeMap inputs) throws NeuralNetworkException {
- if (inputs.isEmpty()) throw new NeuralNetworkException("No prediction inputs set");
- TreeMap outputs = new TreeMap<>();
- for (Map.Entry entry : predict(Sequence.getSequencesFromMatrices(inputs), true).entrySet()) outputs.put(entry.getKey(), entry.getValue().get(0));
- return outputs;
- }
-
- /**
- * Predicts values based on current test set inputs.
- * Sets specific inputs for prediction.
- *
- * @param inputs test input set for prediction.
- * @return predicted values (neural network output).
- * @throws NeuralNetworkException throws exception if prediction fails.
- */
- public TreeMap predict(TreeMap inputs) throws NeuralNetworkException {
- return predict(inputs, true);
- }
-
- /**
- * Predicts values based on current test set inputs.
- * Optionally waits neural network prediction procedure to complete.
- *
- * @param waitToComplete if true waits for neural network execution complete otherwise returns function prior prediction completion.
- * @return predicted values (neural network output).
- * @throws NeuralNetworkException throws exception if prediction fails.
- */
- public TreeMap predict(boolean waitToComplete) throws NeuralNetworkException {
- return predict(null, waitToComplete);
- }
-
- /**
- * Predicts values based on given inputs.
- * Optionally waits neural network prediction procedure to complete.
- *
- * @param inputs test input set for prediction.
- * @param waitToComplete if true waits for neural network execution complete otherwise returns function prior prediction completion.
- * @return predicted values (neural network output).
- * @throws NeuralNetworkException throws exception if prediction fails.
- */
- public TreeMap predict(TreeMap inputs, boolean waitToComplete) throws NeuralNetworkException {
- checkNotStarted();
- waitToComplete();
- if (inputs == null) throw new NeuralNetworkException("No prediction inputs set");
- executeLock.lock();
- predictInputs.clear();
- predictInputs.putAll(inputs);
- nextState(ExecutionState.PREDICT);
- return waitToComplete ? getOutput() : null;
- }
-
- /**
- * Sets verbosing for validation phase.
- * Follows training verbosing cycle.
+ * Returns total neural network training iterations count.
*
+ * @return total neural network training iterations count.
*/
- public void verboseValidation() {
+ public int getTotalTrainingIterations() {
waitToComplete();
- verboseValidation = true;
+ return totalTrainingIterations;
}
/**
- * Unsets verbosing for validation phase.
+ * Sets validation data.
*
+ * @param validationSampler validation sampler containing validation data set.
+ * @throws NeuralNetworkException throws exception if setting of validation data fails.
*/
- public void unverboseValidation() {
+ public void setValidationData(Sampler validationSampler) throws NeuralNetworkException {
waitToComplete();
- verboseValidation = false;
+ if (validationSampler == null) throw new NeuralNetworkException("Validation sampler is not set.");
+ this.validationSampler = validationSampler;
}
/**
@@ -974,24 +1046,6 @@ public void validate(Sampler validationSampler) throws NeuralNetworkException {
validate(validationSampler, true);
}
- /**
- * Sets validation data.
- *
- * @param validationSampler validation sampler containing validation data set.
- * @throws NeuralNetworkException throws exception if setting of validation data fails.
- */
- public void setValidationData(Sampler validationSampler) throws NeuralNetworkException {
- checkNotStarted();
- waitToComplete();
- executeLock.lock();
- if (validationSampler == null) {
- executeLock.unlock();
- throw new NeuralNetworkException("Validation sampler is not set.");
- }
- this.validationSampler = validationSampler;
- executeLock.unlock();
- }
-
/**
* Validates neural network.
* Sets specific input and output (actual true values) test samples.
@@ -1004,166 +1058,20 @@ public void setValidationData(Sampler validationSampler) throws NeuralNetworkExc
public void validate(Sampler validationSampler, boolean waitToComplete) throws NeuralNetworkException {
checkNotStarted();
waitToComplete();
- executeLock.lock();
if (validationSampler != null) this.validationSampler = validationSampler;
if (this.validationSampler == null) throw new NeuralNetworkException("Validation sampler is not set.");
nextState(ExecutionState.VALIDATE);
if (waitToComplete) waitToComplete();
}
- /**
- * Returns output of neural network (output layer).
- *
- * @return output of neural network.
- */
- public TreeMap getOutput() {
- waitToComplete();
- TreeMap outputs = new TreeMap<>();
- for (Map.Entry entry : getOutputLayers().entrySet()) outputs.put(entry.getKey(), entry.getValue().getLayerOutputs());
- return outputs;
- }
-
- /**
- * Returns total training time.
- *
- * @return total training time.
- */
- public double getTotalTrainingTime() {
- double totalTrainingTime = 0;
- for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) {
- double layerTrainingTime = neuralNetworkLayer.getTrainingExecutionTime();
- totalTrainingTime += layerTrainingTime;
- }
- return totalTrainingTime;
- }
-
- /**
- * Returns total prediction time.
- *
- * @return total prediction time.
- */
- public double getTotalPredictionTime() {
- double totalPredictionTime = 0;
- for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) totalPredictionTime += neuralNetworkLayer.getPredictExecutionTime();
- return totalPredictionTime;
- }
-
- /**
- * Thread run function.
- * Executes given neural network procedures and synchronizes their execution via neural network thread execution lock.
- *
- */
- public void run() {
- while (true) {
- executeLock.lock();
- try {
- switch (executionState) {
- case TRAIN -> trainIterations();
- case PREDICT -> predictInput();
- case VALIDATE -> validateInput(true);
- case TERMINATED -> {
- return;
- }
- }
- }
- catch (Exception exception) {
- exception.printStackTrace();
- System.exit(-1);
- }
- executeLock.unlock();
- }
- }
-
- /**
- * Trains neural network with defined number of iterations.
- *
- * @throws MatrixException throws exception if matrix operation fails.
- * @throws IOException throws exception if neural network persistence operation fails.
- * @throws NeuralNetworkException throws exception if neural network training fails.
- * @throws DynamicParamException throws exception if parameter (params) setting fails.
- */
- private void trainIterations() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException {
- trainingSampler.reset();
- int numberOfIterations = trainingSampler.getNumberOfIterations();
- for (int iteration = 0; iteration < numberOfIterations; iteration++) {
- trainIteration();
- if (stoppedExecution()) break;
- if (!earlyStoppingMap.isEmpty()) {
- boolean stopTraining = true;
- for (EarlyStopping earlyStopping : earlyStoppingMap.values()) if (!earlyStopping.stopTraining()) stopTraining = false;
- if (stopTraining) break;
- }
- }
- stateCompleted();
- }
-
- /**
- * Trains single neural network iteration.
- *
- * @throws MatrixException throws exception if matrix operation fails.
- * @throws IOException throws exception if neural network persistence operation fails.
- * @throws NeuralNetworkException throws exception if neural network training fails.
- * @throws DynamicParamException throws exception if parameter (params) setting fails.
- */
- private void trainIteration() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException {
- long startTime = System.nanoTime();
- for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) if (reset) neuralNetworkLayer.resetOptimizer();
- TreeMap inputSequences = new TreeMap<>();
- TreeMap outputSequences = new TreeMap<>();
- trainingSampler.getSamples(inputSequences, outputSequences);
- for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setTargets(outputSequences.get(entry.getKey()));
- for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().train(inputSequences.get(entry.getKey()));
- for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().backward();
- for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().update();
- long endTime = System.nanoTime();
- trainingTime += endTime - startTime;
- for (Map.Entry entry : trainingMetrics.entrySet()) entry.getValue().report(getOutputLayers().get(entry.getKey()).getTotalError());
- totalIterations++;
- if (autoValidationCycle > 0) {
- autoValidationCount++;
- if (autoValidationCount >= autoValidationCycle) {
- validateInput(false);
- if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateValidationCondition(totalIterations);
- autoValidationCount = 0;
- }
- }
- if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateTrainingCondition(totalIterations);
- if (verboseTraining) verboseTrainingStatus();
- if (persistence != null) persistence.cycle();
- }
-
- /**
- * Verboses (prints to console) neural network training status.
- * Prints number of iteration, neural network training time and training error.
- *
- */
- private void verboseTrainingStatus() {
- StringBuilder meanSquaredError = new StringBuilder("[ ");
- for (SingleRegressionMetric trainingMetric : trainingMetrics.values()) {
- meanSquaredError.append(String.format("%.4f", trainingMetric.getLastAbsoluteError())).append(" ");
- }
- meanSquaredError.append("]");
- if (totalIterations % verboseCycle == 0) System.out.println((neuralNetworkName != null ? neuralNetworkName + ": " : "") + "Training error (iteration #" + totalIterations +"): " + meanSquaredError + ", Training time in seconds: " + (int)(getTotalTrainingTime() / Math.pow(10, 9)));
- }
-
- /**
- * Predicts using given test set inputs.
- *
- */
- private void predictInput() {
- for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().predict(predictInputs.get(entry.getKey()));
- stateCompleted();
- }
-
/**
* Validates with given test set inputs and outputs.
*
- * @param stateCompleted if flag is sets calls stateCompleted function.
* @throws MatrixException throws exception if matrix operation fails.
* @throws NeuralNetworkException throws exception if validation fails.
* @throws DynamicParamException throws exception if parameter (params) setting fails.
*/
- private void validateInput(boolean stateCompleted) throws MatrixException, NeuralNetworkException, DynamicParamException {
+ private void validateInput() throws MatrixException, NeuralNetworkException, DynamicParamException {
for (Metric validationMetric : validationMetrics.values()) validationMetric.reset();
validationSampler.reset();
int numberOfIterations = validationSampler.getNumberOfIterations();
@@ -1174,53 +1082,7 @@ private void validateInput(boolean stateCompleted) throws MatrixException, Neura
for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().predict(inputSequences.get(entry.getKey()));
for (Map.Entry entry : getOutputLayers().entrySet()) validationMetrics.get(entry.getKey()).report(entry.getValue().getLayerOutputs(), outputSequences.get(entry.getKey()));
}
- if (verboseValidation && (totalIterations % verboseCycle == 0)) verboseValidationStatus();
- if (stateCompleted) stateCompleted();
- }
-
- /**
- * Returns training metrics instance.
- *
- * @return training metrics instance.
- */
- public TreeMap getTrainingMetrics() {
- waitToComplete();
- return trainingMetrics;
- }
-
- /**
- * Sets if training metrics is shown.
- *
- * @param showTrainingMetrics if true training metrics is shown otherwise not.
- * @throws NeuralNetworkException throws exception if parameter is attempted to be set when neural network is already started.
- */
- public void setShowTrainingMetrics(boolean showTrainingMetrics) throws NeuralNetworkException {
- if(isStarted()) throw new NeuralNetworkException("Training metrics can be only enabled / disabled when neural network is not started.");
- this.showTrainingMetrics = showTrainingMetrics;
- }
-
- /**
- * Returns total neural network training iterations count.
- *
- * @return total neural network training iterations count.
- */
- public int getTotalIterations() {
- waitToComplete();
- return totalIterations;
- }
-
- /**
- * Returns neural network output error.
- *
- * @throws DynamicParamException throws exception if parameter (params) setting fails.
- * @throws MatrixException throws exception if matrix operation fails.
- * @return neural network output error.
- */
- public TreeMap getOutputError() throws MatrixException, DynamicParamException {
- waitToComplete();
- TreeMap outputErrors = new TreeMap<>();
- for (Map.Entry entry : getOutputLayers().entrySet()) outputErrors.put(entry.getKey(), entry.getValue().getTotalError());
- return outputErrors;
+ if (verboseValidation && (totalTrainingIterations % verboseCycle == 0)) verboseValidationStatus();
}
/**
@@ -1229,9 +1091,7 @@ public TreeMap getOutputError() throws MatrixException, Dynamic
* @param showMetric if true shows metric otherwise not.
*/
public void setAsRegression(boolean showMetric) {
- waitToComplete();
- validationMetrics.clear();
- for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new RegressionMetric(showMetric));
+ setAsRegression(false, showMetric);
}
/**
@@ -1241,8 +1101,7 @@ public void setAsRegression(boolean showMetric) {
* @param showMetric if true shows metric otherwise not.
*/
public void setAsRegression(int outputLayerIndex, boolean showMetric) {
- waitToComplete();
- validationMetrics.put(outputLayerIndex, new RegressionMetric(showMetric));
+ setAsRegression(outputLayerIndex, false, showMetric);
}
/**
@@ -1254,7 +1113,7 @@ public void setAsRegression(int outputLayerIndex, boolean showMetric) {
public void setAsRegression(boolean useR2AsLastError, boolean showMetric) {
waitToComplete();
validationMetrics.clear();
- for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new RegressionMetric(useR2AsLastError, showMetric));
+ for (Integer outputLayerIndex : getOutputLayers().keySet()) setAsRegression(outputLayerIndex, useR2AsLastError, showMetric);
}
/**
@@ -1266,7 +1125,7 @@ public void setAsRegression(boolean useR2AsLastError, boolean showMetric) {
*/
public void setAsRegression(int outputLayerIndex, boolean useR2AsLastError, boolean showMetric) {
waitToComplete();
- validationMetrics.put(outputLayerIndex, new RegressionMetric(useR2AsLastError, showMetric));
+ setValidationMetric(outputLayerIndex, new RegressionMetric(useR2AsLastError, showMetric));
}
/**
@@ -1275,9 +1134,7 @@ public void setAsRegression(int outputLayerIndex, boolean useR2AsLastError, bool
* @param showMetric if true shows metric otherwise not.
*/
public void setAsClassification(boolean showMetric) {
- waitToComplete();
- validationMetrics.clear();
- for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new ClassificationMetric(showMetric));
+ setAsClassification(false, showMetric);
}
/**
@@ -1287,8 +1144,7 @@ public void setAsClassification(boolean showMetric) {
* @param showMetric if true shows metric otherwise not.
*/
public void setAsClassification(int outputLayerIndex, boolean showMetric) {
- waitToComplete();
- validationMetrics.put(outputLayerIndex, new ClassificationMetric(showMetric));
+ setAsClassification(outputLayerIndex, false, showMetric);
}
/**
@@ -1300,7 +1156,7 @@ public void setAsClassification(int outputLayerIndex, boolean showMetric) {
public void setAsClassification(boolean multiClass, boolean showMetric) {
waitToComplete();
validationMetrics.clear();
- for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new ClassificationMetric(multiClass, showMetric));
+ for (Integer outputLayerIndex : getOutputLayers().keySet()) setAsClassification(outputLayerIndex, multiClass, showMetric);
}
/**
@@ -1312,7 +1168,58 @@ public void setAsClassification(boolean multiClass, boolean showMetric) {
*/
public void setAsClassification(int outputLayerIndex, boolean multiClass, boolean showMetric) {
waitToComplete();
- validationMetrics.put(outputLayerIndex, new ClassificationMetric(multiClass, showMetric));
+ setValidationMetric(outputLayerIndex, new ClassificationMetric(multiClass, showMetric));
+ }
+
+ /**
+ * Sets validation metric.
+ *
+ * @param outputLayerIndex output layer index.
+ * @param metric metric.
+ */
+ private void setValidationMetric(int outputLayerIndex, Metric metric) {
+ validationMetrics.put(outputLayerIndex, metric);
+ if (earlyStoppingMap.get(outputLayerIndex) != null) earlyStoppingMap.get(outputLayerIndex).setValidationMetric(metric);
+ }
+
+ /**
+ * Sets auto validation on.
+ *
+ * @param autoValidationCycle validation cycle in iterations.
+ * @throws NeuralNetworkException throws exception if number of auto validation cycles are below 1.
+ */
+ public void setAutoValidate(int autoValidationCycle) throws NeuralNetworkException {
+ waitToComplete();
+ if (autoValidationCycle < 1) throw new NeuralNetworkException("Auto validation cycle size must be at least 1.");
+ this.autoValidationCycle = autoValidationCycle;
+ }
+
+ /**
+ * Unsets auto validation.
+ *
+ */
+ public void unsetAutoValidate() {
+ waitToComplete();
+ autoValidationCycle = 0;
+ }
+
+ /**
+ * Sets verbosing for validation phase.
+ * Follows training verbosing cycle.
+ *
+ */
+ public void verboseValidation() {
+ waitToComplete();
+ verboseValidation = true;
+ }
+
+ /**
+ * Unsets verbosing for validation phase.
+ *
+ */
+ public void unverboseValidation() {
+ waitToComplete();
+ verboseValidation = false;
}
/**
@@ -1360,6 +1267,58 @@ private void verboseValidationStatus() throws MatrixException, DynamicParamExcep
for (Metric validationMetric : validationMetrics.values()) validationMetric.printReport();
}
+ /**
+ * Predicts values based on given input.
+ *
+ * @param inputs inputs for prediction.
+ * @return predicted values (neural network outputs).
+ * @throws NeuralNetworkException throws exception if prediction fails.
+ */
+ public TreeMap predictMatrix(TreeMap inputs) throws NeuralNetworkException {
+ if (inputs.isEmpty()) throw new NeuralNetworkException("No prediction inputs set");
+ TreeMap outputs = new TreeMap<>();
+ for (Map.Entry entry : predict(Sequence.getSequencesFromMatrices(inputs), true).entrySet()) outputs.put(entry.getKey(), entry.getValue().get(0));
+ return outputs;
+ }
+
+ /**
+ * Predicts values based on current test set inputs.
+ * Sets specific inputs for prediction.
+ *
+ * @param inputs test input set for prediction.
+ * @return predicted values (neural network output).
+ * @throws NeuralNetworkException throws exception if prediction fails.
+ */
+ public TreeMap predict(TreeMap inputs) throws NeuralNetworkException {
+ return predict(inputs, true);
+ }
+
+ /**
+ * Predicts values based on given inputs.
+ * Optionally waits neural network prediction procedure to complete.
+ *
+ * @param inputs test input set for prediction.
+ * @param waitToComplete if true waits for neural network execution complete otherwise returns function prior prediction completion.
+ * @return predicted values (neural network output).
+ * @throws NeuralNetworkException throws exception if prediction fails.
+ */
+ public TreeMap predict(TreeMap inputs, boolean waitToComplete) throws NeuralNetworkException {
+ checkNotStarted();
+ waitToComplete();
+ if (inputs == null) throw new NeuralNetworkException("No prediction inputs set");
+ predictInputs = inputs;
+ nextState(ExecutionState.PREDICT);
+ return waitToComplete ? getOutput() : null;
+ }
+
+ /**
+ * Predicts using given test set inputs.
+ *
+ */
+ private void predictInput() {
+ for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().predict(predictInputs.get(entry.getKey()));
+ }
+
/**
* Makes deep copy of neural network by using object serialization.
*
@@ -1370,7 +1329,6 @@ private void verboseValidationStatus() throws MatrixException, DynamicParamExcep
*/
public NeuralNetwork copy() throws IOException, ClassNotFoundException, MatrixException {
waitToComplete();
- predictInputs.clear();
for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.reset();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream);
@@ -1407,11 +1365,13 @@ public NeuralNetwork reference() throws IOException, ClassNotFoundException, Mat
*/
public void reinitialize() throws MatrixException, DynamicParamException {
waitToComplete();
- predictInputs.clear();
+ predictInputs = new TreeMap<>();
+ trainingSampler = null;
+ validationSampler = null;
for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.reinitialize();
for (Map.Entry entry : earlyStoppingMap.entrySet()) earlyStoppingMap.put(entry.getKey(), entry.getValue().reference());
for (Metric validationMetric : validationMetrics.values()) validationMetric.reinitialize();
- totalIterations = 0;
+ totalTrainingIterations = 0;
trainingTime = 0;
}
@@ -1429,20 +1389,6 @@ public void append(NeuralNetwork otherNeuralNetwork, double tau) throws MatrixEx
}
}
- /**
- * Sets importance sampling weights to output layer.
- *
- * @param importanceSamplingWeights importance sampling weights
- * @throws NeuralNetworkException throws exception if neural network operation fails.
- */
- public void setImportanceSamplingWeights(TreeMap> importanceSamplingWeights) throws NeuralNetworkException {
- checkNotStarted();
- waitToComplete();
- executeLock.lock();
- for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setImportanceSamplingWeights(importanceSamplingWeights.get(entry.getKey()));
- executeLock.unlock();
- }
-
/**
* Prints structure and metadata of neural network.