From 438201836d2800713610348fe40017d7244bda2d Mon Sep 17 00:00:00 2001 From: Simo Aaltonen Date: Sun, 19 Nov 2023 21:40:07 +0200 Subject: [PATCH] Refactoring of NeuralNetwork class. --- src/core/network/NeuralNetwork.java | 876 +++++++++++++--------------- 1 file changed, 411 insertions(+), 465 deletions(-) diff --git a/src/core/network/NeuralNetwork.java b/src/core/network/NeuralNetwork.java index a7d41617..04d83503 100644 --- a/src/core/network/NeuralNetwork.java +++ b/src/core/network/NeuralNetwork.java @@ -5,7 +5,10 @@ package core.network; -import core.layer.*; +import core.layer.AbstractLayer; +import core.layer.InputLayer; +import core.layer.NeuralNetworkLayer; +import core.layer.OutputLayer; import core.metrics.ClassificationMetric; import core.metrics.Metric; import core.metrics.RegressionMetric; @@ -24,7 +27,7 @@ import java.util.TreeMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -35,7 +38,7 @@ * Can support multiple layer of different types including regularization, normalization and optimization methods.
* */ -public class NeuralNetwork implements Runnable, Serializable { +public class NeuralNetwork implements Serializable { @Serial private static final long serialVersionUID = -1075977720550636471L; @@ -69,6 +72,12 @@ private enum ExecutionState { */ private transient Condition executeLockCondition; + /** + * Lock for synchronizing neural network thread operations. + * + */ + private transient Lock completeLock; + /** * Lock-condition for synchronizing completion of procedure execution and shift to idle state. * @@ -94,29 +103,17 @@ private enum ExecutionState { private transient boolean stopExecution; /** - * Thread pool for neural network. + * Network thread pool. * */ private transient ExecutorService networkThreadPool; /** - * Thread pool for neural network layers. + * Layer thread pool. * */ private transient ExecutorService layerThreadPool; - /** - * Neural network execution thread. - * - */ - private transient Thread neuralNetworkThread; - - /** - * Future of the neural network thread. - * - */ - private transient Future future; - /** * Name of neural network instance. * @@ -187,7 +184,7 @@ private enum ExecutionState { * Structure containing prediction input sequence. * */ - private final TreeMap predictInputs = new TreeMap<>(); + private transient TreeMap predictInputs; /** * Flag is neural network and it's layers are to be reset prior training phase. @@ -199,7 +196,7 @@ private enum ExecutionState { * Count of total neural network training iterations. * */ - private int totalIterations = 0; + private int totalTrainingIterations = 0; /** * Total training time of neural network in nanoseconds. @@ -207,6 +204,12 @@ private enum ExecutionState { */ private long trainingTime = 0; + /** + * Total validation time of neural network in nanoseconds. + * + */ + private long validationTime = 0; + /** * Length of automatic validation cycle in iterations. * @@ -266,18 +269,6 @@ public NeuralNetwork(NeuralNetworkConfiguration neuralNetworkConfiguration) thro build(neuralNetworkConfiguration); } - /** - * Constructor for neural network. - * - * @param neuralNetworkConfiguration neural network configuration. - * @param neuralNetworkName name for neural network instance. - * @throws NeuralNetworkException thrown if initialization of layer fails or neural network is already built. - */ - public NeuralNetwork(NeuralNetworkConfiguration neuralNetworkConfiguration, String neuralNetworkName) throws NeuralNetworkException { - this.neuralNetworkName = neuralNetworkName; - build(neuralNetworkConfiguration); - } - /** * Sets name for neural network instance. * @@ -451,6 +442,8 @@ public void removeLastHiddenLayer() throws NeuralNetworkException { hiddenLayers.remove(lastLayerIndex); neuralNetworkLayers.remove(lastNeuralNetworkLayerIndex); + + if (persistence != null) persistence = persistence.reference(this); } /** @@ -502,15 +495,13 @@ public Persistence getPersistence() { * @throws DynamicParamException throws exception if parameter (params) setting fails. */ public void start() throws NeuralNetworkException, MatrixException, DynamicParamException { - if (neuralNetworkLayers.isEmpty()) throw new NeuralNetworkException("Neural network is not built."); checkStarted(); + if (neuralNetworkLayers.isEmpty()) throw new NeuralNetworkException("Neural network is not built."); - trainingMetrics.clear(); - for (Integer outputLayerIndex : getOutputLayers().keySet()) trainingMetrics.put(outputLayerIndex, new SingleRegressionMetric(showTrainingMetrics)); - for (Integer outputLayerIndex : getOutputLayers().keySet()) if (validationMetrics.get(outputLayerIndex) != null) validationMetrics.put(outputLayerIndex, validationMetrics.get(outputLayerIndex).reference()); - - if (!earlyStoppingMap.isEmpty()) { - for (Integer outputLayerIndex : getOutputLayers().keySet()) { + for (Integer outputLayerIndex : getOutputLayers().keySet()) { + trainingMetrics.put(outputLayerIndex, new SingleRegressionMetric(showTrainingMetrics)); + if (validationMetrics.get(outputLayerIndex) != null) validationMetrics.put(outputLayerIndex, validationMetrics.get(outputLayerIndex).reference()); + if (!earlyStoppingMap.isEmpty()) { earlyStoppingMap.get(outputLayerIndex).setTrainingMetric(trainingMetrics.get(outputLayerIndex)); earlyStoppingMap.get(outputLayerIndex).setValidationMetric(validationMetrics.get(outputLayerIndex)); } @@ -518,22 +509,17 @@ public void start() throws NeuralNetworkException, MatrixException, DynamicParam executeLock = new ReentrantLock(); executeLockCondition = executeLock.newCondition(); - completeLockCondition = executeLock.newCondition(); + completeLock = new ReentrantLock(); + completeLockCondition = completeLock.newCondition(); executionState = ExecutionState.IDLE; - stopLock = new ReentrantLock(); stopExecution = false; - neuralNetworkThread = new Thread(this); - neuralNetworkThread.setName("NeuralNetwork" + (neuralNetworkName != null ? " (" + neuralNetworkName + ")" : "")); - - layerThreadPool = Executors.newWorkStealingPool(); - - if (networkThreadPool != null) future = networkThreadPool.submit(neuralNetworkThread); - else neuralNetworkThread.start(); - - for (InputLayer inputLayer : inputLayers.values()) inputLayer.start(layerThreadPool); + networkThreadPool = Executors.newSingleThreadExecutor(); + executeLayer(networkThreadPool); + layerThreadPool = Executors.newCachedThreadPool(); + for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.start(layerThreadPool); } /** @@ -543,32 +529,72 @@ public void start() throws NeuralNetworkException, MatrixException, DynamicParam public void stop() { if (!isStarted()) return; waitToComplete(); - executeLock.lock(); - for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.stop(); - executionState = ExecutionState.TERMINATED; - executeLockCondition.signal(); - executeLock.unlock(); + nextState(ExecutionState.TERMINATED); + for (NeuralNetworkLayer neuralNetworkLayer : inputLayers.values()) neuralNetworkLayer.stop(); - if (layerThreadPool == null) { - try { - neuralNetworkThread.join(); - } catch (InterruptedException e) { - e.printStackTrace(); + try { + layerThreadPool.shutdownNow(); + networkThreadPool.shutdownNow(); + if (!layerThreadPool.awaitTermination(10, TimeUnit.SECONDS) || !networkThreadPool.awaitTermination(10, TimeUnit.SECONDS)) { + System.out.println("Failed to shut down neural network."); } } - else { - layerThreadPool.shutdownNow(); - if (future != null) future.cancel(true); - if (networkThreadPool != null) networkThreadPool.shutdownNow(); + catch (InterruptedException ignored) { } + } + + /** + * Executes layer. + * + * @param executorService executor service + * @throws RuntimeException throws runtime exception in case any exception happens. + */ + private void executeLayer(ExecutorService executorService) throws RuntimeException { + executorService.execute(() -> { + try { + while (!executeLayerOperation()) {} + } catch (Exception exception) { + throw new RuntimeException(exception); + } + }); + } - executeLock = null; - executeLockCondition = null; - completeLockCondition = null; - stopLock = null; - neuralNetworkThread = null; - layerThreadPool = null; - networkThreadPool = null; + /** + * Thread run function.
+ * Executes given neural network procedures and synchronizes their execution via neural network thread execution lock.
+ * + * @return return true if layer has been terminated otherwise returns true. + * @throws MatrixException throws exception if matrix operation fails. + * @throws NeuralNetworkException throws exception if neural network operation fails. + * @throws IOException throws exception if neural network persistence operation fails. + * @throws DynamicParamException throws exception if parameter (params) setting fails. + */ + private boolean executeLayerOperation() throws MatrixException, NeuralNetworkException, DynamicParamException, IOException { + try { + executeLock.lock(); + switch (executionState) { + case TRAIN -> { + trainIterations(); + complete(); + } + case VALIDATE -> { + validateInput(); + complete(); + } + case PREDICT -> { + predictInput(); + complete(); + } + case TERMINATED -> { + complete(); + return true; + } + } + } + finally { + executeLock.unlock(); + } + return false; } /** @@ -577,7 +603,7 @@ public void stop() { * @return returns true if neural network is started otherwise false. */ public boolean isStarted() { - return networkThreadPool != null ? future != null && (!future.isDone()) : neuralNetworkThread != null && (neuralNetworkThread.getState() != Thread.State.NEW); + return networkThreadPool != null && !networkThreadPool.isTerminated(); } /** @@ -598,53 +624,50 @@ private void checkNotStarted() throws NeuralNetworkException { if (!isStarted()) throw new NeuralNetworkException("Neural network is not started"); } - /** - * Aborts execution of neural network.
- * Execution is aborted after execution of last single operation is completed.
- * Useful when neural network is executing multiple training iterations.
- * - */ - public void abortExecution() { - if (!isStarted()) return; - stopLock.lock(); - stopExecution = true; - stopLock.unlock(); - } - /** * Sets neural network into completed state and makes state to idle state. * */ - private void stateCompleted() { - executeLock.lock(); - executionState = ExecutionState.IDLE; - completeLockCondition.signal(); - executeLock.unlock(); + private void complete() { + try { + completeLock.lock(); + executionState = ExecutionState.IDLE; + completeLockCondition.signal(); + } + finally { + completeLock.unlock(); + } } /** - * Checks if neural network is executing (processing). + * Waits for neural network to finalize it execution (processing). * - * @return true if neural network is executing otherwise false. */ - public boolean isProcessing() { - if (!isStarted()) return false; - executeLock.lock(); - boolean isProcessing; - isProcessing = !(executionState == ExecutionState.IDLE || executionState == ExecutionState.TERMINATED); - executeLock.unlock(); - return isProcessing; + public void waitToComplete() { + if (!isStarted()) return; + try { + completeLock.lock(); + while (executionState != ExecutionState.IDLE) completeLockCondition.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + finally { + completeLock.unlock(); + } } /** - * Waits for neural network to finalize it execution (processing). + * Stops execution. * */ - public void waitToComplete() { - if (!isStarted()) return; - executeLock.lock(); - if (isProcessing()) completeLockCondition.awaitUninterruptibly(); - executeLock.unlock(); + public void stopExecution() { + try { + stopLock.lock(); + stopExecution = true; + } + finally { + stopLock.unlock(); + } } /** @@ -653,14 +676,14 @@ public void waitToComplete() { * @return true if execution is stopped otherwise false. */ private boolean stoppedExecution() { - stopLock.lock(); - boolean stopped = false; - if (stopExecution) { - stopped = true; + try { + stopLock.lock(); + return stopExecution; + } + finally { stopExecution = false; + stopLock.unlock(); } - stopLock.unlock(); - return stopped; } /** @@ -669,9 +692,14 @@ private boolean stoppedExecution() { * @param executionState next state for neural network. */ private void nextState(ExecutionState executionState) { - this.executionState = executionState; - executeLockCondition.signal(); - executeLock.unlock(); + try { + executeLock.lock(); + this.executionState = executionState; + executeLockCondition.signal(); + } + finally { + executeLock.unlock(); + } } /** @@ -682,55 +710,30 @@ private void nextState(ExecutionState executionState) { * @throws NeuralNetworkException throws exception if setting of training sample sets fail. */ public void setTrainingData(Sampler trainingSampler) throws NeuralNetworkException { - checkNotStarted(); waitToComplete(); - executeLock.lock(); - if (trainingSampler == null) { - executeLock.unlock(); - throw new NeuralNetworkException("Training sampler is not set."); - } + if (trainingSampler == null) throw new NeuralNetworkException("Training sampler is not set."); this.trainingSampler = trainingSampler; - executeLock.unlock(); } /** * Sets early stopping conditions. * - * @param earlyStoppingMap early stopping instances. + * @param newEarlyStoppingMap early stopping instances. * @throws NeuralNetworkException throws exception if early stopping is not defined. */ - public void setTrainingEarlyStopping(TreeMap earlyStoppingMap) throws NeuralNetworkException { + public void setTrainingEarlyStopping(TreeMap newEarlyStoppingMap) throws NeuralNetworkException { waitToComplete(); - if (earlyStoppingMap == null) throw new NeuralNetworkException("Early stopping is not defined."); - this.earlyStoppingMap.clear(); - this.earlyStoppingMap.putAll(earlyStoppingMap); - for (Integer earlyStoppingIndex : this.earlyStoppingMap.keySet()) { + if (newEarlyStoppingMap == null) throw new NeuralNetworkException("Early stopping is not defined."); + earlyStoppingMap.clear(); + earlyStoppingMap.putAll(newEarlyStoppingMap); + for (Integer earlyStoppingIndex : newEarlyStoppingMap.keySet()) { earlyStoppingMap.get(earlyStoppingIndex).setTrainingMetric(trainingMetrics.get(earlyStoppingIndex)); + } + for (Integer earlyStoppingIndex : newEarlyStoppingMap.keySet()) { earlyStoppingMap.get(earlyStoppingIndex).setValidationMetric(validationMetrics.get(earlyStoppingIndex)); } } - /** - * Sets auto validation on. - * - * @param autoValidationCycle validation cycle in iterations. - * @throws NeuralNetworkException throws exception if number of auto validation cycles are below 1. - */ - public void setAutoValidate(int autoValidationCycle) throws NeuralNetworkException { - waitToComplete(); - if (autoValidationCycle < 1) throw new NeuralNetworkException("Auto validation cycle size must be at least 1."); - this.autoValidationCycle = autoValidationCycle; - } - - /** - * Unsets auto validation. - * - */ - public void unsetAutoValidate() { - waitToComplete(); - autoValidationCycle = 0; - } - /** * Trains neural network. * @@ -797,7 +800,6 @@ public void train(boolean reset, boolean waitToComplete) throws NeuralNetworkExc public void train(Sampler trainingSampler, boolean reset, boolean waitToComplete) throws NeuralNetworkException { checkNotStarted(); waitToComplete(); - executeLock.lock(); if (trainingSampler != null) this.trainingSampler = trainingSampler; if (this.trainingSampler == null) throw new NeuralNetworkException("Training sampler is not set."); this.reset = reset; @@ -805,6 +807,130 @@ public void train(Sampler trainingSampler, boolean reset, boolean waitToComplete if (waitToComplete) waitToComplete(); } + /** + * Trains neural network with defined number of iterations. + * + * @throws MatrixException throws exception if matrix operation fails. + * @throws IOException throws exception if neural network persistence operation fails. + * @throws NeuralNetworkException throws exception if neural network training fails. + * @throws DynamicParamException throws exception if parameter (params) setting fails. + */ + private void trainIterations() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException { + trainingSampler.reset(); + int numberOfIterations = trainingSampler.getNumberOfIterations(); + for (int iteration = 0; iteration < numberOfIterations; iteration++) { + trainIteration(); + if (stoppedExecution()) return; + if (!earlyStoppingMap.isEmpty()) { + boolean stopTraining = true; + for (EarlyStopping earlyStopping : earlyStoppingMap.values()) if (!earlyStopping.stopTraining()) stopTraining = false; + if (stopTraining) return; + } + } + } + + /** + * Trains single neural network iteration. + * + * @throws MatrixException throws exception if matrix operation fails. + * @throws IOException throws exception if neural network persistence operation fails. + * @throws NeuralNetworkException throws exception if neural network training fails. + * @throws DynamicParamException throws exception if parameter (params) setting fails. + */ + private void trainIteration() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException { + totalTrainingIterations++; + long trainingStartTime = System.nanoTime(); + for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) if (reset) neuralNetworkLayer.resetOptimizer(); + TreeMap inputSequences = new TreeMap<>(); + TreeMap outputSequences = new TreeMap<>(); + trainingSampler.getSamples(inputSequences, outputSequences); + for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setTargets(outputSequences.get(entry.getKey())); + for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().train(inputSequences.get(entry.getKey())); + for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().backward(); + for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().update(); + long trainingEndTime = System.nanoTime(); + trainingTime += trainingEndTime - trainingStartTime; + for (Map.Entry entry : trainingMetrics.entrySet()) entry.getValue().report(getOutputLayers().get(entry.getKey()).getTotalError()); + if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateTrainingCondition(totalTrainingIterations); + if (autoValidationCycle > 0) { + autoValidationCount++; + if (autoValidationCount >= autoValidationCycle) { + long validationStartTime = System.nanoTime(); + validateInput(); + long validationEndTime = System.nanoTime(); + validationTime += validationEndTime - validationStartTime; + if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateValidationCondition(totalTrainingIterations); + autoValidationCount = 0; + } + } + if (verboseTraining) verboseTrainingStatus(); + if (persistence != null) persistence.cycle(); + } + + /** + * Verboses (prints to console) neural network training status.
+ * Prints number of iteration, neural network training time and training error.
+ * + */ + private void verboseTrainingStatus() { + StringBuilder meanSquaredError = new StringBuilder("[ "); + for (SingleRegressionMetric trainingMetric : trainingMetrics.values()) { + meanSquaredError.append(String.format("%.4f", trainingMetric.getLastAbsoluteError())).append(" "); + } + meanSquaredError.append("]"); + if (totalTrainingIterations % verboseCycle == 0) System.out.println((neuralNetworkName != null ? neuralNetworkName + ": " : "") + "Training error (iteration #" + totalTrainingIterations +"): " + meanSquaredError + ", Training time: " + String.format("%.3f", trainingTime / Math.pow(10, 9)) + "s" + (autoValidationCycle > 0 ? ", Validation time: " + String.format("%.3f", validationTime / Math.pow(10, 9)) + "s" : "")); + } + + /** + * Returns output of neural network (output layer). + * + * @return output of neural network. + */ + public TreeMap getOutput() { + waitToComplete(); + TreeMap outputs = new TreeMap<>(); + for (Map.Entry entry : getOutputLayers().entrySet()) outputs.put(entry.getKey(), entry.getValue().getLayerOutputs()); + return outputs; + } + + /** + * Returns neural network output error. + * + * @throws DynamicParamException throws exception if parameter (params) setting fails. + * @throws MatrixException throws exception if matrix operation fails. + * @return neural network output error. + */ + public TreeMap getOutputError() throws MatrixException, DynamicParamException { + waitToComplete(); + TreeMap outputErrors = new TreeMap<>(); + for (Map.Entry entry : getOutputLayers().entrySet()) outputErrors.put(entry.getKey(), entry.getValue().getTotalError()); + return outputErrors; + } + + /** + * Sets importance sampling weights to output layer. + * + * @param importanceSamplingWeights importance sampling weights + * @throws NeuralNetworkException throws exception if neural network operation fails. + */ + public void setImportanceSamplingWeights(TreeMap> importanceSamplingWeights) throws NeuralNetworkException { + checkNotStarted(); + waitToComplete(); + executeLock.lock(); + for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setImportanceSamplingWeights(importanceSamplingWeights.get(entry.getKey())); + executeLock.unlock(); + } + + /** + * Sets reset flag for procedure expression dependencies. + * + * @param resetDependencies if true procedure expression dependencies are reset otherwise false. + */ + public void resetDependencies(boolean resetDependencies) { + waitToComplete(); + for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.resetDependencies(resetDependencies); + } + /** * Verboses (prints to console) neural network training progress.
* Print information of neural network training iteration count, training time and training error.
@@ -847,100 +973,46 @@ public long getTrainingTimeInSeconds() { } /** - * Sets reset flag for procedure expression dependencies. + * Returns training metrics instance. * - * @param resetDependencies if true procedure expression dependencies are reset otherwise false. + * @return training metrics instance. */ - public void resetDependencies(boolean resetDependencies) { + public TreeMap getTrainingMetrics() { waitToComplete(); - for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.resetDependencies(resetDependencies); + return trainingMetrics; } /** - * Predicts values based on current test set inputs. + * Sets if training metrics is shown. * - * @return predicted values (neural network output). - * @throws NeuralNetworkException throws exception if prediction fails. + * @param showTrainingMetrics if true training metrics is shown otherwise not. + * @throws NeuralNetworkException throws exception if parameter is attempted to be set when neural network is already started. */ - public TreeMap predict() throws NeuralNetworkException { - return predict(null, true); + public void setShowTrainingMetrics(boolean showTrainingMetrics) throws NeuralNetworkException { + if(isStarted()) throw new NeuralNetworkException("Training metrics can be only enabled / disabled when neural network is not started."); + this.showTrainingMetrics = showTrainingMetrics; } /** - * Predicts values based on given input. - * - * @param inputs inputs for prediction. - * @return predicted values (neural network outputs). - * @throws NeuralNetworkException throws exception if prediction fails. - */ - public TreeMap predictMatrix(TreeMap inputs) throws NeuralNetworkException { - if (inputs.isEmpty()) throw new NeuralNetworkException("No prediction inputs set"); - TreeMap outputs = new TreeMap<>(); - for (Map.Entry entry : predict(Sequence.getSequencesFromMatrices(inputs), true).entrySet()) outputs.put(entry.getKey(), entry.getValue().get(0)); - return outputs; - } - - /** - * Predicts values based on current test set inputs.
- * Sets specific inputs for prediction.
- * - * @param inputs test input set for prediction. - * @return predicted values (neural network output). - * @throws NeuralNetworkException throws exception if prediction fails. - */ - public TreeMap predict(TreeMap inputs) throws NeuralNetworkException { - return predict(inputs, true); - } - - /** - * Predicts values based on current test set inputs.
- * Optionally waits neural network prediction procedure to complete.
- * - * @param waitToComplete if true waits for neural network execution complete otherwise returns function prior prediction completion. - * @return predicted values (neural network output). - * @throws NeuralNetworkException throws exception if prediction fails. - */ - public TreeMap predict(boolean waitToComplete) throws NeuralNetworkException { - return predict(null, waitToComplete); - } - - /** - * Predicts values based on given inputs.
- * Optionally waits neural network prediction procedure to complete.
- * - * @param inputs test input set for prediction. - * @param waitToComplete if true waits for neural network execution complete otherwise returns function prior prediction completion. - * @return predicted values (neural network output). - * @throws NeuralNetworkException throws exception if prediction fails. - */ - public TreeMap predict(TreeMap inputs, boolean waitToComplete) throws NeuralNetworkException { - checkNotStarted(); - waitToComplete(); - if (inputs == null) throw new NeuralNetworkException("No prediction inputs set"); - executeLock.lock(); - predictInputs.clear(); - predictInputs.putAll(inputs); - nextState(ExecutionState.PREDICT); - return waitToComplete ? getOutput() : null; - } - - /** - * Sets verbosing for validation phase.
- * Follows training verbosing cycle.
+ * Returns total neural network training iterations count. * + * @return total neural network training iterations count. */ - public void verboseValidation() { + public int getTotalTrainingIterations() { waitToComplete(); - verboseValidation = true; + return totalTrainingIterations; } /** - * Unsets verbosing for validation phase. + * Sets validation data. * + * @param validationSampler validation sampler containing validation data set. + * @throws NeuralNetworkException throws exception if setting of validation data fails. */ - public void unverboseValidation() { + public void setValidationData(Sampler validationSampler) throws NeuralNetworkException { waitToComplete(); - verboseValidation = false; + if (validationSampler == null) throw new NeuralNetworkException("Validation sampler is not set."); + this.validationSampler = validationSampler; } /** @@ -974,24 +1046,6 @@ public void validate(Sampler validationSampler) throws NeuralNetworkException { validate(validationSampler, true); } - /** - * Sets validation data. - * - * @param validationSampler validation sampler containing validation data set. - * @throws NeuralNetworkException throws exception if setting of validation data fails. - */ - public void setValidationData(Sampler validationSampler) throws NeuralNetworkException { - checkNotStarted(); - waitToComplete(); - executeLock.lock(); - if (validationSampler == null) { - executeLock.unlock(); - throw new NeuralNetworkException("Validation sampler is not set."); - } - this.validationSampler = validationSampler; - executeLock.unlock(); - } - /** * Validates neural network.
* Sets specific input and output (actual true values) test samples.
@@ -1004,166 +1058,20 @@ public void setValidationData(Sampler validationSampler) throws NeuralNetworkExc public void validate(Sampler validationSampler, boolean waitToComplete) throws NeuralNetworkException { checkNotStarted(); waitToComplete(); - executeLock.lock(); if (validationSampler != null) this.validationSampler = validationSampler; if (this.validationSampler == null) throw new NeuralNetworkException("Validation sampler is not set."); nextState(ExecutionState.VALIDATE); if (waitToComplete) waitToComplete(); } - /** - * Returns output of neural network (output layer). - * - * @return output of neural network. - */ - public TreeMap getOutput() { - waitToComplete(); - TreeMap outputs = new TreeMap<>(); - for (Map.Entry entry : getOutputLayers().entrySet()) outputs.put(entry.getKey(), entry.getValue().getLayerOutputs()); - return outputs; - } - - /** - * Returns total training time. - * - * @return total training time. - */ - public double getTotalTrainingTime() { - double totalTrainingTime = 0; - for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) { - double layerTrainingTime = neuralNetworkLayer.getTrainingExecutionTime(); - totalTrainingTime += layerTrainingTime; - } - return totalTrainingTime; - } - - /** - * Returns total prediction time. - * - * @return total prediction time. - */ - public double getTotalPredictionTime() { - double totalPredictionTime = 0; - for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) totalPredictionTime += neuralNetworkLayer.getPredictExecutionTime(); - return totalPredictionTime; - } - - /** - * Thread run function.
- * Executes given neural network procedures and synchronizes their execution via neural network thread execution lock.
- * - */ - public void run() { - while (true) { - executeLock.lock(); - try { - switch (executionState) { - case TRAIN -> trainIterations(); - case PREDICT -> predictInput(); - case VALIDATE -> validateInput(true); - case TERMINATED -> { - return; - } - } - } - catch (Exception exception) { - exception.printStackTrace(); - System.exit(-1); - } - executeLock.unlock(); - } - } - - /** - * Trains neural network with defined number of iterations. - * - * @throws MatrixException throws exception if matrix operation fails. - * @throws IOException throws exception if neural network persistence operation fails. - * @throws NeuralNetworkException throws exception if neural network training fails. - * @throws DynamicParamException throws exception if parameter (params) setting fails. - */ - private void trainIterations() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException { - trainingSampler.reset(); - int numberOfIterations = trainingSampler.getNumberOfIterations(); - for (int iteration = 0; iteration < numberOfIterations; iteration++) { - trainIteration(); - if (stoppedExecution()) break; - if (!earlyStoppingMap.isEmpty()) { - boolean stopTraining = true; - for (EarlyStopping earlyStopping : earlyStoppingMap.values()) if (!earlyStopping.stopTraining()) stopTraining = false; - if (stopTraining) break; - } - } - stateCompleted(); - } - - /** - * Trains single neural network iteration. - * - * @throws MatrixException throws exception if matrix operation fails. - * @throws IOException throws exception if neural network persistence operation fails. - * @throws NeuralNetworkException throws exception if neural network training fails. - * @throws DynamicParamException throws exception if parameter (params) setting fails. - */ - private void trainIteration() throws MatrixException, IOException, NeuralNetworkException, DynamicParamException { - long startTime = System.nanoTime(); - for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) if (reset) neuralNetworkLayer.resetOptimizer(); - TreeMap inputSequences = new TreeMap<>(); - TreeMap outputSequences = new TreeMap<>(); - trainingSampler.getSamples(inputSequences, outputSequences); - for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setTargets(outputSequences.get(entry.getKey())); - for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().train(inputSequences.get(entry.getKey())); - for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().backward(); - for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().update(); - long endTime = System.nanoTime(); - trainingTime += endTime - startTime; - for (Map.Entry entry : trainingMetrics.entrySet()) entry.getValue().report(getOutputLayers().get(entry.getKey()).getTotalError()); - totalIterations++; - if (autoValidationCycle > 0) { - autoValidationCount++; - if (autoValidationCount >= autoValidationCycle) { - validateInput(false); - if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateValidationCondition(totalIterations); - autoValidationCount = 0; - } - } - if (!earlyStoppingMap.isEmpty()) for (EarlyStopping earlyStopping : earlyStoppingMap.values()) earlyStopping.evaluateTrainingCondition(totalIterations); - if (verboseTraining) verboseTrainingStatus(); - if (persistence != null) persistence.cycle(); - } - - /** - * Verboses (prints to console) neural network training status.
- * Prints number of iteration, neural network training time and training error.
- * - */ - private void verboseTrainingStatus() { - StringBuilder meanSquaredError = new StringBuilder("[ "); - for (SingleRegressionMetric trainingMetric : trainingMetrics.values()) { - meanSquaredError.append(String.format("%.4f", trainingMetric.getLastAbsoluteError())).append(" "); - } - meanSquaredError.append("]"); - if (totalIterations % verboseCycle == 0) System.out.println((neuralNetworkName != null ? neuralNetworkName + ": " : "") + "Training error (iteration #" + totalIterations +"): " + meanSquaredError + ", Training time in seconds: " + (int)(getTotalTrainingTime() / Math.pow(10, 9))); - } - - /** - * Predicts using given test set inputs. - * - */ - private void predictInput() { - for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().predict(predictInputs.get(entry.getKey())); - stateCompleted(); - } - /** * Validates with given test set inputs and outputs. * - * @param stateCompleted if flag is sets calls stateCompleted function. * @throws MatrixException throws exception if matrix operation fails. * @throws NeuralNetworkException throws exception if validation fails. * @throws DynamicParamException throws exception if parameter (params) setting fails. */ - private void validateInput(boolean stateCompleted) throws MatrixException, NeuralNetworkException, DynamicParamException { + private void validateInput() throws MatrixException, NeuralNetworkException, DynamicParamException { for (Metric validationMetric : validationMetrics.values()) validationMetric.reset(); validationSampler.reset(); int numberOfIterations = validationSampler.getNumberOfIterations(); @@ -1174,53 +1082,7 @@ private void validateInput(boolean stateCompleted) throws MatrixException, Neura for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().predict(inputSequences.get(entry.getKey())); for (Map.Entry entry : getOutputLayers().entrySet()) validationMetrics.get(entry.getKey()).report(entry.getValue().getLayerOutputs(), outputSequences.get(entry.getKey())); } - if (verboseValidation && (totalIterations % verboseCycle == 0)) verboseValidationStatus(); - if (stateCompleted) stateCompleted(); - } - - /** - * Returns training metrics instance. - * - * @return training metrics instance. - */ - public TreeMap getTrainingMetrics() { - waitToComplete(); - return trainingMetrics; - } - - /** - * Sets if training metrics is shown. - * - * @param showTrainingMetrics if true training metrics is shown otherwise not. - * @throws NeuralNetworkException throws exception if parameter is attempted to be set when neural network is already started. - */ - public void setShowTrainingMetrics(boolean showTrainingMetrics) throws NeuralNetworkException { - if(isStarted()) throw new NeuralNetworkException("Training metrics can be only enabled / disabled when neural network is not started."); - this.showTrainingMetrics = showTrainingMetrics; - } - - /** - * Returns total neural network training iterations count. - * - * @return total neural network training iterations count. - */ - public int getTotalIterations() { - waitToComplete(); - return totalIterations; - } - - /** - * Returns neural network output error. - * - * @throws DynamicParamException throws exception if parameter (params) setting fails. - * @throws MatrixException throws exception if matrix operation fails. - * @return neural network output error. - */ - public TreeMap getOutputError() throws MatrixException, DynamicParamException { - waitToComplete(); - TreeMap outputErrors = new TreeMap<>(); - for (Map.Entry entry : getOutputLayers().entrySet()) outputErrors.put(entry.getKey(), entry.getValue().getTotalError()); - return outputErrors; + if (verboseValidation && (totalTrainingIterations % verboseCycle == 0)) verboseValidationStatus(); } /** @@ -1229,9 +1091,7 @@ public TreeMap getOutputError() throws MatrixException, Dynamic * @param showMetric if true shows metric otherwise not. */ public void setAsRegression(boolean showMetric) { - waitToComplete(); - validationMetrics.clear(); - for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new RegressionMetric(showMetric)); + setAsRegression(false, showMetric); } /** @@ -1241,8 +1101,7 @@ public void setAsRegression(boolean showMetric) { * @param showMetric if true shows metric otherwise not. */ public void setAsRegression(int outputLayerIndex, boolean showMetric) { - waitToComplete(); - validationMetrics.put(outputLayerIndex, new RegressionMetric(showMetric)); + setAsRegression(outputLayerIndex, false, showMetric); } /** @@ -1254,7 +1113,7 @@ public void setAsRegression(int outputLayerIndex, boolean showMetric) { public void setAsRegression(boolean useR2AsLastError, boolean showMetric) { waitToComplete(); validationMetrics.clear(); - for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new RegressionMetric(useR2AsLastError, showMetric)); + for (Integer outputLayerIndex : getOutputLayers().keySet()) setAsRegression(outputLayerIndex, useR2AsLastError, showMetric); } /** @@ -1266,7 +1125,7 @@ public void setAsRegression(boolean useR2AsLastError, boolean showMetric) { */ public void setAsRegression(int outputLayerIndex, boolean useR2AsLastError, boolean showMetric) { waitToComplete(); - validationMetrics.put(outputLayerIndex, new RegressionMetric(useR2AsLastError, showMetric)); + setValidationMetric(outputLayerIndex, new RegressionMetric(useR2AsLastError, showMetric)); } /** @@ -1275,9 +1134,7 @@ public void setAsRegression(int outputLayerIndex, boolean useR2AsLastError, bool * @param showMetric if true shows metric otherwise not. */ public void setAsClassification(boolean showMetric) { - waitToComplete(); - validationMetrics.clear(); - for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new ClassificationMetric(showMetric)); + setAsClassification(false, showMetric); } /** @@ -1287,8 +1144,7 @@ public void setAsClassification(boolean showMetric) { * @param showMetric if true shows metric otherwise not. */ public void setAsClassification(int outputLayerIndex, boolean showMetric) { - waitToComplete(); - validationMetrics.put(outputLayerIndex, new ClassificationMetric(showMetric)); + setAsClassification(outputLayerIndex, false, showMetric); } /** @@ -1300,7 +1156,7 @@ public void setAsClassification(int outputLayerIndex, boolean showMetric) { public void setAsClassification(boolean multiClass, boolean showMetric) { waitToComplete(); validationMetrics.clear(); - for (Integer outputLayerIndex : getOutputLayers().keySet()) validationMetrics.put(outputLayerIndex, new ClassificationMetric(multiClass, showMetric)); + for (Integer outputLayerIndex : getOutputLayers().keySet()) setAsClassification(outputLayerIndex, multiClass, showMetric); } /** @@ -1312,7 +1168,58 @@ public void setAsClassification(boolean multiClass, boolean showMetric) { */ public void setAsClassification(int outputLayerIndex, boolean multiClass, boolean showMetric) { waitToComplete(); - validationMetrics.put(outputLayerIndex, new ClassificationMetric(multiClass, showMetric)); + setValidationMetric(outputLayerIndex, new ClassificationMetric(multiClass, showMetric)); + } + + /** + * Sets validation metric. + * + * @param outputLayerIndex output layer index. + * @param metric metric. + */ + private void setValidationMetric(int outputLayerIndex, Metric metric) { + validationMetrics.put(outputLayerIndex, metric); + if (earlyStoppingMap.get(outputLayerIndex) != null) earlyStoppingMap.get(outputLayerIndex).setValidationMetric(metric); + } + + /** + * Sets auto validation on. + * + * @param autoValidationCycle validation cycle in iterations. + * @throws NeuralNetworkException throws exception if number of auto validation cycles are below 1. + */ + public void setAutoValidate(int autoValidationCycle) throws NeuralNetworkException { + waitToComplete(); + if (autoValidationCycle < 1) throw new NeuralNetworkException("Auto validation cycle size must be at least 1."); + this.autoValidationCycle = autoValidationCycle; + } + + /** + * Unsets auto validation. + * + */ + public void unsetAutoValidate() { + waitToComplete(); + autoValidationCycle = 0; + } + + /** + * Sets verbosing for validation phase.
+ * Follows training verbosing cycle.
+ * + */ + public void verboseValidation() { + waitToComplete(); + verboseValidation = true; + } + + /** + * Unsets verbosing for validation phase. + * + */ + public void unverboseValidation() { + waitToComplete(); + verboseValidation = false; } /** @@ -1360,6 +1267,58 @@ private void verboseValidationStatus() throws MatrixException, DynamicParamExcep for (Metric validationMetric : validationMetrics.values()) validationMetric.printReport(); } + /** + * Predicts values based on given input. + * + * @param inputs inputs for prediction. + * @return predicted values (neural network outputs). + * @throws NeuralNetworkException throws exception if prediction fails. + */ + public TreeMap predictMatrix(TreeMap inputs) throws NeuralNetworkException { + if (inputs.isEmpty()) throw new NeuralNetworkException("No prediction inputs set"); + TreeMap outputs = new TreeMap<>(); + for (Map.Entry entry : predict(Sequence.getSequencesFromMatrices(inputs), true).entrySet()) outputs.put(entry.getKey(), entry.getValue().get(0)); + return outputs; + } + + /** + * Predicts values based on current test set inputs.
+ * Sets specific inputs for prediction.
+ * + * @param inputs test input set for prediction. + * @return predicted values (neural network output). + * @throws NeuralNetworkException throws exception if prediction fails. + */ + public TreeMap predict(TreeMap inputs) throws NeuralNetworkException { + return predict(inputs, true); + } + + /** + * Predicts values based on given inputs.
+ * Optionally waits neural network prediction procedure to complete.
+ * + * @param inputs test input set for prediction. + * @param waitToComplete if true waits for neural network execution complete otherwise returns function prior prediction completion. + * @return predicted values (neural network output). + * @throws NeuralNetworkException throws exception if prediction fails. + */ + public TreeMap predict(TreeMap inputs, boolean waitToComplete) throws NeuralNetworkException { + checkNotStarted(); + waitToComplete(); + if (inputs == null) throw new NeuralNetworkException("No prediction inputs set"); + predictInputs = inputs; + nextState(ExecutionState.PREDICT); + return waitToComplete ? getOutput() : null; + } + + /** + * Predicts using given test set inputs. + * + */ + private void predictInput() { + for (Map.Entry entry : getInputLayers().entrySet()) entry.getValue().predict(predictInputs.get(entry.getKey())); + } + /** * Makes deep copy of neural network by using object serialization. * @@ -1370,7 +1329,6 @@ private void verboseValidationStatus() throws MatrixException, DynamicParamExcep */ public NeuralNetwork copy() throws IOException, ClassNotFoundException, MatrixException { waitToComplete(); - predictInputs.clear(); for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.reset(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream); @@ -1407,11 +1365,13 @@ public NeuralNetwork reference() throws IOException, ClassNotFoundException, Mat */ public void reinitialize() throws MatrixException, DynamicParamException { waitToComplete(); - predictInputs.clear(); + predictInputs = new TreeMap<>(); + trainingSampler = null; + validationSampler = null; for (NeuralNetworkLayer neuralNetworkLayer : neuralNetworkLayers.values()) neuralNetworkLayer.reinitialize(); for (Map.Entry entry : earlyStoppingMap.entrySet()) earlyStoppingMap.put(entry.getKey(), entry.getValue().reference()); for (Metric validationMetric : validationMetrics.values()) validationMetric.reinitialize(); - totalIterations = 0; + totalTrainingIterations = 0; trainingTime = 0; } @@ -1429,20 +1389,6 @@ public void append(NeuralNetwork otherNeuralNetwork, double tau) throws MatrixEx } } - /** - * Sets importance sampling weights to output layer. - * - * @param importanceSamplingWeights importance sampling weights - * @throws NeuralNetworkException throws exception if neural network operation fails. - */ - public void setImportanceSamplingWeights(TreeMap> importanceSamplingWeights) throws NeuralNetworkException { - checkNotStarted(); - waitToComplete(); - executeLock.lock(); - for (Map.Entry entry : getOutputLayers().entrySet()) entry.getValue().setImportanceSamplingWeights(importanceSamplingWeights.get(entry.getKey())); - executeLock.unlock(); - } - /** * Prints structure and metadata of neural network.