vol3:java: Finalized structure for conv-net; however, some cleanup and likely debugging needed. Convergence rates could be better. #31
jeffheaton committed Jan 23, 2016
1 parent 935eb0b commit 582a38a
Showing 8 changed files with 128 additions and 21 deletions.
2 changes: 1 addition & 1 deletion vol3/README.md
@@ -28,7 +28,7 @@ or improve an example, please consider pushing a change via GitHub.
* vol3-java-examples
* vol3-csharp-examples

*Note: I am in the process of refactoring the conv-nets in the Java and C# versions. This is to fix a few bugs that were reported. Sorry for any inconvenience, I hope to have the new version complete by the end of January 2016.*
**Note: I am in the process of refactoring the conv-nets in the Java and C# versions. This is to fix a few bugs that were reported. Sorry for any inconvenience, I hope to have the new version complete by the end of January 2016.**

## Getting Help

@@ -200,9 +200,9 @@ public void finalizeStructure() {
this.inputCount = this.layers.get(0).getCount();
this.outputCount = this.layers.get(layerCount - 1).getCount();

for (int i = this.layers.size() - 1; i >= 0; i--) {
final Layer layer = this.layers.get(i);
layer.finalizeStructure(this, i);
int i = 0;
for(Layer layer: this.layers) {
layer.finalizeStructure(this, i++);
this.layerOutput.addFlatObject(layer.getLayerSums());
this.layerOutput.addFlatObject(layer.getLayerOutput());
if( layer.getWeightMatrix()!=null ) {
@@ -213,6 +213,7 @@
}

this.layerOutput.finalizeStructure();
this.weights.reverseOrder();
this.weights.finalizeStructure();

clearOutput();
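The hunk above switches layer finalization to forward (input-to-output) order while still handing out sequential indices, and the new weights.reverseOrder() call restores the output-first weight ordering the rest of the code appears to expect. A minimal, self-contained sketch of that pattern, with purely illustrative names:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class OrderSketch {
    public static void main(String[] args) {
        List<String> layers = Arrays.asList("input", "hidden", "output");
        List<String> weights = new ArrayList<>();
        int i = 0;
        for (String layer : layers) {
            // Each layer is finalized with its forward index, as in the loop above.
            System.out.println("finalize " + layer + " as layer " + i++);
            weights.add("W" + i); // weight objects collected input-to-output
        }
        Collections.reverse(weights); // analogue of this.weights.reverseOrder()
        System.out.println(weights);  // [W3, W2, W1]: output-first, as before
    }
}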
@@ -32,9 +32,12 @@
import com.heatonresearch.aifh.ann.activation.ActivationFunction;
import com.heatonresearch.aifh.ann.train.GradientCalc;
import com.heatonresearch.aifh.flat.FlatMatrix;
import com.heatonresearch.aifh.flat.FlatObject;
import com.heatonresearch.aifh.flat.FlatVolume;
import com.heatonresearch.aifh.randomize.GenerateRandom;

import java.util.Arrays;

/**
* A 2D convolution layer.
*
@@ -66,12 +69,12 @@ public class Conv2DLayer implements Layer {
/**
* The output columns.
*/
private int outColumns;
private int scanCols;

/**
* The output rows.
*/
private int outRows;
private int scanRows;

/**
* The input depth.
@@ -128,13 +131,17 @@ public void finalizeStructure(BasicNetwork theOwner, int theLayerIndex) {
throw new AIFHError("Conv2DLayer must have a previous layer (cannot be used as the input layer).");
}

this.outRows = (int)Math.floor((prevLayer.getDimensionCounts()[0] + this.padding * 2 - this.filterSize) / this.stride + 1);
this.outColumns = (int)Math.floor((prevLayer.getDimensionCounts()[1] + this.padding * 2 - this.filterSize) / this.stride + 1);
this.scanRows = (int)Math.floor((prevLayer.getDimensionCounts()[0] + this.padding * 2 - this.filterSize) / this.stride + 1);
this.scanCols = (int)Math.floor((prevLayer.getDimensionCounts()[1] + this.padding * 2 - this.filterSize) / this.stride + 1);

int[] shape = {this.outRows, this.outColumns, this.numFilters};
int[] shape = {this.scanRows, this.scanCols, this.numFilters};

this.layerOutput = new FlatVolume(shape, true);
this.layerSums = new FlatVolume(shape, true);
this.weightMatrix = new FlatMatrix[this.numFilters];
for(int i=0;i<this.weightMatrix.length;i++) {
this.weightMatrix[i] = new FlatMatrix(getCount(), prevLayer.getTotalCount());
}
}

@Override
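The scanRows/scanCols values above come from the standard convolution output-size formula: out = floor((in + 2 * padding - filterSize) / stride) + 1. A small self-contained check of the arithmetic (the input sizes are illustrative only):

public class ConvShapeSketch {
    // Integer division floors here because all operands are non-negative.
    static int outDim(int in, int pad, int filter, int stride) {
        return (in + 2 * pad - filter) / stride + 1;
    }

    public static void main(String[] args) {
        System.out.println(outDim(28, 0, 5, 1)); // 24: 5x5 filter over 28x28, stride 1
        System.out.println(outDim(28, 2, 5, 2)); // 14: same filter, padding 2, stride 2
    }
}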
@@ -160,7 +167,7 @@ public FlatMatrix[] getWeightMatrix() {
*/
@Override
public int getCount() {
return this.outRows * this.outColumns;
return this.filterSize * this.filterSize;
}

/**
@@ -185,9 +192,41 @@ public ActivationFunction getActivation() {
@Override
public void computeLayer() {
Layer prev = getOwner().getPreviousLayer(this);
int prevRows = prev.getDimensionCounts()[0];
int prevColumns = prev.getDimensionCounts()[1];
int prevDepth = prev.getDimensionCounts()[2];

// Loop over every filter
for(int currentFilter=0;currentFilter<this.numFilters;currentFilter++) {

int y = -this.padding;

// Scan each filter over the previous layer (shared weights). Handle rows.
for(int filterRow = 0; filterRow<this.scanRows; y+=this.stride,filterRow++) {
int x = -this.padding;

// Scan each filter over the previous layer (shared weights). Handle columns.
for (int filterCol = 0; filterCol < this.scanCols; x += this.stride, filterCol++) {

// Now process the previous layer's image at each scan position.
double sum = 0.0;
// Process the rows at each scan point.
for(int prevRowIndex = 0; prevRowIndex<this.scanRows; prevRowIndex++) {
int prevRow = y+prevRowIndex;
// Process the columns at each scan point.
for(int prevColIndex = 0; prevColIndex<this.scanCols; prevColIndex++) {
int prevCol = x+prevColIndex;
if(prevRow>=0 && prevRow<prevRows && prevCol>=0 && prevCol<prevColumns) {
// Process each element of the previous level's depth.
for(int currentPrevDepth=0;currentPrevDepth<prevDepth;currentPrevDepth++) {
sum += this.weightMatrix[currentFilter].get(prevRowIndex,prevColIndex)
* prev.getLayerOutput().get(prevRow,prevCol,currentPrevDepth);
}
}
}
}
this.layerOutput.set(filterRow,filterCol,currentFilter,sum);
}
}
}
}
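As a reference for the scan above, here is a minimal, self-contained single-filter, single-channel version of the same zero-padded, strided pass. One deliberate difference worth flagging: in a standard convolution the inner window loops run over the filter extent (filterSize), whereas the loop above bounds them by scanRows/scanCols; given the commit message's note that debugging is still needed, the sketch uses the filter extent.

public class Conv2DSketch {
    // Zero-padded, strided 2D convolution of one filter over a single-channel
    // input. Out-of-bounds taps contribute zero, matching the bounds check above.
    static double[][] conv2d(double[][] input, double[][] filter, int pad, int stride) {
        int inRows = input.length, inCols = input[0].length;
        int f = filter.length;
        int outRows = (inRows + 2 * pad - f) / stride + 1;
        int outCols = (inCols + 2 * pad - f) / stride + 1;
        double[][] out = new double[outRows][outCols];
        int y = -pad;
        for (int r = 0; r < outRows; r++, y += stride) {
            int x = -pad;
            for (int c = 0; c < outCols; c++, x += stride) {
                double sum = 0;
                for (int fr = 0; fr < f; fr++) {      // window rows: filter extent
                    for (int fc = 0; fc < f; fc++) {  // window cols: filter extent
                        int pr = y + fr, pc = x + fc;
                        if (pr >= 0 && pr < inRows && pc >= 0 && pc < inCols) {
                            sum += filter[fr][fc] * input[pr][pc];
                        }
                    }
                }
                out[r][c] = sum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] img = new double[6][6];
        img[2][2] = 1.0; // single bright pixel
        double[][] k = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};
        System.out.println(conv2d(img, k, 0, 1)[1][1]); // 1.0: pixel falls under the window
    }
}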

@@ -196,11 +235,48 @@ public void computeLayer() {
*/
@Override
public void computeGradient(GradientCalc calc) {
final Layer prev = getOwner().getPreviousLayer(this);
final FlatVolume prevLayerDelta = (FlatVolume)calc.getLayerDelta().get(getLayerIndex()-1);
final FlatVolume layerDelta = (FlatVolume)calc.getLayerDelta().get(getLayerIndex());

final ActivationFunction activation = getActivation();
int totalLayers = getOwner().getLayers().size()-1;
FlatMatrix gradientMatrix = (FlatMatrix)calc.getGradientMatrix().getFlatObjects().get(totalLayers-getLayerIndex());

// Calculate the output for each filter (depth).
for(int dOutput=0;dOutput<this.numFilters;dOutput++) {
for (int dInput = 0; dInput < this.inDepth; dInput++) {
//computeGradient(calc);
int prevRows = prev.getDimensionCounts()[0];
int prevColumns = prev.getDimensionCounts()[1];
int prevDepth = prev.getDimensionCounts()[2];

// Loop over every filter
for(int currentFilter=0;currentFilter<this.numFilters;currentFilter++) {
int y = -this.padding;

// Scan each filter over the previous layer (shared weights). Handle rows.
for(int filterRow = 0; filterRow<this.scanRows; y+=this.stride,filterRow++) {
int x = -this.padding;
// Scan each filter over the previous layer (shared weights). Handle columns.
for (int filterCol = 0; filterCol < this.scanCols; x += this.stride, filterCol++) {

// Now process the previous layer's image at each scan position.
double sum = 0.0;
// Process the rows at each scan point.
for(int prevRowIndex = 0; prevRowIndex<this.scanRows; prevRowIndex++) {
int prevRow = y+prevRowIndex;
// Process the columns at each scan point.
for(int prevColIndex = 0; prevColIndex<this.scanCols; prevColIndex++) {
int prevCol = x+prevColIndex;
if(prevRow>=0 && prevRow<prevRows && prevCol>=0 && prevCol<prevColumns) {
// Process each element of the previous level's depth.
for(int currentPrevDepth=0;currentPrevDepth<prevDepth;currentPrevDepth++) {
//gradientMatrix.add(prevRowIndex,prevColIndex, -(output * layerDelta.get(xi)));
sum += this.weightMatrix[currentFilter].get(prevRowIndex,prevColIndex)
* prevLayerDelta.get(prevRow,prevCol, currentPrevDepth);
}
}
}
}
this.layerOutput.set(filterRow,filterCol,currentFilter,sum);
}
}
}
}
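The commented-out gradientMatrix.add(...) line suggests the weight-gradient accumulation here is still in flux, consistent with the commit message. For reference, one common formulation as a hedged standalone sketch (not necessarily the approach this code will settle on): each filter weight's gradient is the sum, over scan positions, of the output delta times the input activation under that tap.

public class ConvGradientSketch {
    // Accumulate dE/dW for one filter of a single-channel convolution:
    // gradient[fr][fc] += outDelta[r][c] * input[y + fr][x + fc].
    static void accumulateFilterGradient(double[][] input, double[][] outDelta,
                                         double[][] gradient, int pad, int stride) {
        int f = gradient.length;
        int y = -pad;
        for (int r = 0; r < outDelta.length; r++, y += stride) {
            int x = -pad;
            for (int c = 0; c < outDelta[0].length; c++, x += stride) {
                for (int fr = 0; fr < f; fr++) {
                    for (int fc = 0; fc < f; fc++) {
                        int pr = y + fr, pc = x + fc;
                        if (pr >= 0 && pr < input.length && pc >= 0 && pc < input[0].length) {
                            gradient[fr][fc] += outDelta[r][c] * input[pr][pc];
                        }
                    }
                }
            }
        }
    }

    public static void main(String[] args) {
        double[][] input = new double[4][4];
        input[0][0] = 2.0;
        double[][] delta = new double[2][2]; // output deltas for a 3x3 filter, stride 1
        delta[0][0] = 0.5;
        double[][] grad = new double[3][3];
        accumulateFilterGradient(input, delta, grad, 0, 1);
        System.out.println(grad[0][0]); // 1.0 = 0.5 * 2.0
    }
}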
@@ -267,5 +343,16 @@ public int getStride() {
return this.stride;
}

@Override
public String toString() {
StringBuilder result = new StringBuilder();
result.append("[");
result.append(this.getClass().getSimpleName());
result.append(":dimensions:"+ Arrays.toString(getDimensionCounts()));
result.append(", totalCount:" + getTotalCount());
result.append("]");
return result.toString();
}


}
@@ -33,6 +33,8 @@
import com.heatonresearch.aifh.flat.FlatMatrix;
import com.heatonresearch.aifh.flat.FlatObject;

import java.util.Arrays;

/**
* Base class for all layers (used with BasicNetwork) that have weights.
*/
@@ -166,6 +168,7 @@ public String toString() {
final StringBuilder result = new StringBuilder();
result.append("[");
result.append(this.getClass().getSimpleName());
result.append(",dimensions=").append(Arrays.toString(getDimensionCounts()));
result.append(",count=").append(getCount());

result.append("]");
@@ -29,6 +29,7 @@
package com.heatonresearch.aifh.ann.randomize;

import com.heatonresearch.aifh.ann.BasicNetwork;
import com.heatonresearch.aifh.flat.FlatMatrix;

/**
* The Xavier (aka Glorot) weight initialization. A very good weight initialization method that provides very
@@ -51,8 +52,14 @@ private void randomizeLayer(BasicNetwork network, int fromLayer) {
for (int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
double sigma = Math.sqrt(2.0/(fromCount+toCount));
double w = this.getRnd().nextGaussian() * sigma;
network.setWeight(fromLayer, fromNeuron, toNeuron, w);
FlatMatrix[] matrixes = network.getLayers().get(fromLayer+1).getWeightMatrix();

// Handle layers with multiple weight matrixes (e.g. convolution layers)
for(int i=0;i<matrixes.length;i++) {
double w = this.getRnd().nextGaussian() * sigma;
matrixes[i].set(toNeuron,fromNeuron, w);
}

}
}
}
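The draw above is the Xavier/Glorot rule: w ~ N(0, sigma) with sigma = sqrt(2 / (fanIn + fanOut)), which keeps activation variance roughly stable across layers. A tiny self-contained version (the fan sizes and seed are arbitrary illustrations):

import java.util.Random;

public class XavierSketch {
    public static void main(String[] args) {
        Random rnd = new Random(1234); // arbitrary seed
        int fanIn = 100, fanOut = 50;  // illustrative layer sizes
        double sigma = Math.sqrt(2.0 / (fanIn + fanOut));
        double w = rnd.nextGaussian() * sigma;
        System.out.println("sigma=" + sigma + ", sample w=" + w);
    }
}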
@@ -39,6 +39,8 @@
import com.heatonresearch.aifh.flat.FlatObject;
import com.heatonresearch.aifh.flat.FlatVolume;

import java.util.Arrays;

/**
* A utility class used to help calculate the gradient of the error function for neural networks.
*/
@@ -86,7 +88,11 @@ public GradientCalc(final BasicNetwork theNetwork,
this.errorFunction = ef;

for(Layer layer: this.network.getLayers()) {
this.layerDelta.addFlatObject(new FlatVolume(layer.getTotalCount(),1,1,false));
int r = layer.getDimensionCounts()[0];
int c = layer.getDimensionCounts()[1];
int d = layer.getDimensionCounts()[2];
this.layerDelta.addFlatObject(new FlatVolume(r,c,d,false));

if( layer.getWeightMatrix() != null ) {
for(FlatMatrix matrix: layer.getWeightMatrix()) {
this.gradients.addFlatObject(new FlatMatrix(matrix));
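With this change the per-layer deltas mirror each layer's full three-dimensional shape (rows, columns, depth) rather than a flattened (count, 1, 1) vector. Assuming a row-major-with-depth layout (an assumption for illustration; the real indexing lives in FlatVolume), both views address the same rows * cols * depth values:

public class VolumeIndexSketch {
    public static void main(String[] args) {
        int rows = 4, cols = 3, depth = 2;
        double[] flat = new double[rows * cols * depth]; // one buffer, two views
        int r = 2, c = 1, d = 1;
        // Hypothetical layout; FlatVolume's actual ordering may differ.
        int index = (r * cols + c) * depth + d;
        flat[index] = 42.0;
        System.out.println(flat[index]); // 42.0
    }
}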
@@ -52,12 +52,12 @@ public void process() {

int outputCount = trainingReader.getData().get(0).getIdeal().length;

int[] inputShape = new int[] {trainingReader.getNumCols(),trainingReader.getNumRows(),3};
int[] inputShape = new int[] {trainingReader.getNumCols(),trainingReader.getNumRows(),1};

BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null,true,inputShape));
network.addLayer(new Conv2DLayer(new ActivationReLU(),32,5));
network.addLayer(new BasicLayer(new ActivationReLU(),true,100));
network.addLayer(new Conv2DLayer(new ActivationReLU(),8,5));
//network.addLayer(new BasicLayer(new ActivationReLU(),true,100));
network.addLayer(new BasicLayer(new ActivationSoftMax(),false,outputCount));
network.finalizeStructure();
network.reset();
@@ -107,6 +107,9 @@ public double get(final int index) {
*/
@Override
public void set(final int index, final double d) {
if( index>=this.length ) {
throw new AIFHError("Length ("+this.length+") exceeded: "+index);
}
this.data[this.offset+index] = d;
}
