Skip to content

Commit

Permalink
putting cart before horse, and go ahead and design lstm, gru, and van…
Browse files Browse the repository at this point in the history
…illa recurrent nets using the planned api for recurrent nets using layer composition, and it is beautiful
  • Loading branch information
robertleeplummerjr committed Dec 19, 2017
1 parent be35fe9 commit 3c27691
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 92 deletions.
1 change: 0 additions & 1 deletion src/layer/add.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ export default class Add extends Base {
}

predict() {
console.log(this.inputLayers[0].weights, this.inputLayers[1].weights);
this.weights = this.predictKernel(this.inputLayers[0].weights, this.inputLayers[1].weights);
}

Expand Down
88 changes: 88 additions & 0 deletions src/layer/gru.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import {
add,
cloneNegative,
multiply,
multiplyElement,
ones,
sigmoid,
random,
tanh
} from './';

export default (settings, input, recurrentInput) => {
const updateGateWeights = random();
const updateGatePeepholes = random();
const updateGateBias = random();
const updateGate = sigmoid(
add(
add(
multiply(
updateGateWeights,
input
),
multiply(
updateGatePeepholes,
recurrentInput
)
),
updateGateBias
)
);

const resetGateWeights = random();
const resetGatePeepholes = random();
const resetGateBias = random();
let resetGate = sigmoid(
add(
add(
multiply(
resetGateWeights,
input
),
multiply(
resetGatePeepholes,
recurrentInput
)
),
resetGateBias
)
);

const memoryWeights = random();
const memoryPeepholes = random();
const memoryBias = random();
let cell = tanh(
add(
add(
multiply(
memoryWeights,
input
),
multiply(
memoryPeepholes,
multiplyElement(
resetGate,
recurrentInput
)
)
),
memoryBias
)
);

// compute hidden state as gated, saturated cell activations
// negate updateGate
return add(
multiplyElement(
add(
ones(updateGate.rows, updateGate.columns),
cloneNegative(updateGate)
),
cell
),
multiplyElement(
recurrentInput,
updateGate
)
);
}
177 changes: 86 additions & 91 deletions src/layer/lstm.js
Original file line number Diff line number Diff line change
@@ -1,102 +1,97 @@
import Group from './base';
import { sigmoid, add, multiply, multiplyElement, tanh } from './index';
import {
add,
multiply,
multiplyElement,
random,
sigmoid,
tanh
} from './index';

export default class LSTM extends Group {
constructor(settings) {
super(settings);

this.inputGate = new LSTMCell();
this.forgetGate = new LSTMCell();
this.outputGate = new LSTMCell();
this.memory = new LSTMCell();
}

static createKernel(settings) {
return (layer, inputLayer, previousOutputs) => {
const inputGate = sigmoid(
add(
add(
multiply(
layer.inputGate.inputWeights,
inputLayer
),
multiply(
layer.inputGate.peepholeWeights,
previousOutputs
)
),
layer.inputGate.bias
export default (settings, input, recurrentInput) => {
const inputGateWeights = random();
const inputGatePeepholes = random();
const inputGateBias = random();
const inputGate = sigmoid(
add(
add(
multiply(
inputGateWeights,
input
),
multiply(
inputGatePeepholes,
recurrentInput
)
);
),
inputGateBias
)
);

const forgetGate = sigmoid(
add(
add(
multiply(
layer.forgetGate.inputWeights,
inputLayer
),
multiply(
layer.forgetGate.peepholeWeights,
previousOutputs
)
),
layer.forgetGate.bias
const forgetGateWeights = random();
const forgetGatePeepholes = random();
const forgetGateBias = random();
const forgetGate = sigmoid(
add(
add(
multiply(
forgetGateWeights,
input
),
multiply(
forgetGatePeepholes,
recurrentInput
)
);
),
forgetGateBias
)
);

// output gate
const outputGate = sigmoid(
add(
add(
multiply(
layer.outputGate.inputWeights,
inputLayer
),
multiply(
layer.outputGate.peepholeWeights,
previousOutputs
)
),
layer.outputGate.bias
const outputGateWeights = random();
const outputGatePeepholes = random();
const outputGateBias = random();
const outputGate = sigmoid(
add(
add(
multiply(
outputGateWeights,
input
),
multiply(
outputGatePeepholes,
recurrentInput
)
);
),
outputGateBias
)
);

// write operation on cells
const memory = tanh(
add(
add(
multiply(
layer.memory.inputWeights,
inputLayer
),
multiply(
layer.memory.peepholeWeights,
previousOutputs
)
),
layer.memory.bias
const memoryWeights = random();
const memoryPeepholes = random();
const memoryBias = random();
const memory = tanh(
add(
add(
multiply(
memoryWeights,
input
),
multiply(
memoryPeepholes,
recurrentInput
)
);

// compute new cell activation
const retainCell = multiplyElement(forgetGate, inputLayer); // what do we keep from cell
const writeCell = multiplyElement(inputGate, memory); // what do we write to cell
const cell = add(retainCell, writeCell); // new cell contents
),
memoryBias
)
);

// compute hidden state as gated, saturated cell activations
return multiplyElement(
outputGate,
tanh(cell)
);
};
}
}
// compute new cell activation
const retainCell = multiplyElement(forgetGate, input); // what do we keep from cell
const writeCell = multiplyElement(inputGate, memory); // what do we write to cell
const cell = add(retainCell, writeCell); // new cell contents

class LSTMCell {
constructor() {
this.inputWeights = {};
this.peepholeWeights = {};
this.bias = {};
}
// compute hidden state as gated, saturated cell activations
return multiplyElement(
outputGate,
tanh(cell)
);
}
26 changes: 26 additions & 0 deletions src/layer/recurrent.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { relu, add, multiply } from './';

export default (settings, recurrentInput, input) => {
//wxh
const weight = new Random(hiddenSize, prevSize, 0.08);
//whh
const transition = new Random(hiddenSize, hiddenSize, 0.08);
//bhh
const bias = new Zeros(hiddenSize, 1);

return relu(
add(
add(
multiply(
weight,
input
),
multiply(
transition,
recurrentInput
)
),
bias
)
);
}

0 comments on commit 3c27691

Please sign in to comment.