-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Putting the cart before the horse: go ahead and design LSTM, GRU, and vanilla recurrent nets using the planned API for recurrent nets built from layer composition, and it is beautiful
- Loading branch information
1 parent
be35fe9
commit 3c27691
Showing
4 changed files
with
200 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import { | ||
add, | ||
cloneNegative, | ||
multiply, | ||
multiplyElement, | ||
ones, | ||
sigmoid, | ||
random, | ||
tanh | ||
} from './'; | ||
|
||
/**
 * Composes one step of a gated-recurrent-unit style layer out of the
 * primitive layer operations imported at the top of this file.
 *
 * Three groups of parameters (weights, "peepholes", bias) are created with
 * `random()`: one for the update gate, one for the reset gate and one for
 * the candidate cell. The result blends the candidate cell with the
 * recurrent input using the update gate.
 *
 * NOTE(review): `random()` is called without arguments, so the size of each
 * parameter is not fixed here — confirm how the planned API sizes them.
 *
 * @param {*} settings - Not read by this function.
 * @param {*} input - Value fed into every `multiply(weights, input)` term.
 * @param {*} recurrentInput - Value fed into the recurrent terms and into the
 *   final blend; presumably the previous step's output — TODO confirm.
 * @returns {*} `add((1 - updateGate) * cell, recurrentInput * updateGate)`,
 *   expressed with `ones`, `cloneNegative` and `multiplyElement`.
 */
export default (settings, input, recurrentInput) => {
  // update gate: sigmoid(weights * input + peepholes * recurrentInput + bias)
  const updateGateWeights = random();
  const updateGatePeepholes = random();
  const updateGateBias = random();
  const updateGateInputTerm = multiply(updateGateWeights, input);
  const updateGateRecurrentTerm = multiply(updateGatePeepholes, recurrentInput);
  const updateGate = sigmoid(
    add(
      add(updateGateInputTerm, updateGateRecurrentTerm),
      updateGateBias
    )
  );

  // reset gate: same form as the update gate, with its own parameters
  const resetGateWeights = random();
  const resetGatePeepholes = random();
  const resetGateBias = random();
  const resetGateInputTerm = multiply(resetGateWeights, input);
  const resetGateRecurrentTerm = multiply(resetGatePeepholes, recurrentInput);
  const resetGate = sigmoid(
    add(
      add(resetGateInputTerm, resetGateRecurrentTerm),
      resetGateBias
    )
  );

  // candidate cell: the recurrent input is gated by resetGate before it is
  // multiplied by the recurrent parameters, then squashed with tanh
  const memoryWeights = random();
  const memoryPeepholes = random();
  const memoryBias = random();
  const memoryInputTerm = multiply(memoryWeights, input);
  const memoryRecurrentTerm = multiply(
    memoryPeepholes,
    multiplyElement(resetGate, recurrentInput)
  );
  const cell = tanh(
    add(
      add(memoryInputTerm, memoryRecurrentTerm),
      memoryBias
    )
  );

  // compute hidden state as gated, saturated cell activations
  // (ones - updateGate) is built by adding the negated gate to a ones value
  // sized from updateGate.rows / updateGate.columns
  const inverseUpdateGate = add(
    ones(updateGate.rows, updateGate.columns),
    cloneNegative(updateGate)
  );

  return add(
    multiplyElement(inverseUpdateGate, cell),
    multiplyElement(recurrentInput, updateGate)
  );
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,102 +1,97 @@ | ||
import Group from './base'; | ||
import { sigmoid, add, multiply, multiplyElement, tanh } from './index'; | ||
import { | ||
add, | ||
multiply, | ||
multiplyElement, | ||
random, | ||
sigmoid, | ||
tanh | ||
} from './index'; | ||
|
||
export default class LSTM extends Group { | ||
constructor(settings) { | ||
super(settings); | ||
|
||
this.inputGate = new LSTMCell(); | ||
this.forgetGate = new LSTMCell(); | ||
this.outputGate = new LSTMCell(); | ||
this.memory = new LSTMCell(); | ||
} | ||
|
||
static createKernel(settings) { | ||
return (layer, inputLayer, previousOutputs) => { | ||
const inputGate = sigmoid( | ||
add( | ||
add( | ||
multiply( | ||
layer.inputGate.inputWeights, | ||
inputLayer | ||
), | ||
multiply( | ||
layer.inputGate.peepholeWeights, | ||
previousOutputs | ||
) | ||
), | ||
layer.inputGate.bias | ||
export default (settings, input, recurrentInput) => { | ||
const inputGateWeights = random(); | ||
const inputGatePeepholes = random(); | ||
const inputGateBias = random(); | ||
const inputGate = sigmoid( | ||
add( | ||
add( | ||
multiply( | ||
inputGateWeights, | ||
input | ||
), | ||
multiply( | ||
inputGatePeepholes, | ||
recurrentInput | ||
) | ||
); | ||
), | ||
inputGateBias | ||
) | ||
); | ||
|
||
const forgetGate = sigmoid( | ||
add( | ||
add( | ||
multiply( | ||
layer.forgetGate.inputWeights, | ||
inputLayer | ||
), | ||
multiply( | ||
layer.forgetGate.peepholeWeights, | ||
previousOutputs | ||
) | ||
), | ||
layer.forgetGate.bias | ||
const forgetGateWeights = random(); | ||
const forgetGatePeepholes = random(); | ||
const forgetGateBias = random(); | ||
const forgetGate = sigmoid( | ||
add( | ||
add( | ||
multiply( | ||
forgetGateWeights, | ||
input | ||
), | ||
multiply( | ||
forgetGatePeepholes, | ||
recurrentInput | ||
) | ||
); | ||
), | ||
forgetGateBias | ||
) | ||
); | ||
|
||
// output gate | ||
const outputGate = sigmoid( | ||
add( | ||
add( | ||
multiply( | ||
layer.outputGate.inputWeights, | ||
inputLayer | ||
), | ||
multiply( | ||
layer.outputGate.peepholeWeights, | ||
previousOutputs | ||
) | ||
), | ||
layer.outputGate.bias | ||
const outputGateWeights = random(); | ||
const outputGatePeepholes = random(); | ||
const outputGateBias = random(); | ||
const outputGate = sigmoid( | ||
add( | ||
add( | ||
multiply( | ||
outputGateWeights, | ||
input | ||
), | ||
multiply( | ||
outputGatePeepholes, | ||
recurrentInput | ||
) | ||
); | ||
), | ||
outputGateBias | ||
) | ||
); | ||
|
||
// write operation on cells | ||
const memory = tanh( | ||
add( | ||
add( | ||
multiply( | ||
layer.memory.inputWeights, | ||
inputLayer | ||
), | ||
multiply( | ||
layer.memory.peepholeWeights, | ||
previousOutputs | ||
) | ||
), | ||
layer.memory.bias | ||
const memoryWeights = random(); | ||
const memoryPeepholes = random(); | ||
const memoryBias = random(); | ||
const memory = tanh( | ||
add( | ||
add( | ||
multiply( | ||
memoryWeights, | ||
input | ||
), | ||
multiply( | ||
memoryPeepholes, | ||
recurrentInput | ||
) | ||
); | ||
|
||
// compute new cell activation | ||
const retainCell = multiplyElement(forgetGate, inputLayer); // what do we keep from cell | ||
const writeCell = multiplyElement(inputGate, memory); // what do we write to cell | ||
const cell = add(retainCell, writeCell); // new cell contents | ||
), | ||
memoryBias | ||
) | ||
); | ||
|
||
// compute hidden state as gated, saturated cell activations | ||
return multiplyElement( | ||
outputGate, | ||
tanh(cell) | ||
); | ||
}; | ||
} | ||
} | ||
// compute new cell activation | ||
const retainCell = multiplyElement(forgetGate, input); // what do we keep from cell | ||
const writeCell = multiplyElement(inputGate, memory); // what do we write to cell | ||
const cell = add(retainCell, writeCell); // new cell contents | ||
|
||
class LSTMCell { | ||
constructor() { | ||
this.inputWeights = {}; | ||
this.peepholeWeights = {}; | ||
this.bias = {}; | ||
} | ||
// compute hidden state as gated, saturated cell activations | ||
return multiplyElement( | ||
outputGate, | ||
tanh(cell) | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import { relu, add, multiply } from './'; | ||
|
||
/**
 * Composes one step of a plain recurrent layer:
 * `relu(weight * input + transition * recurrentInput + bias)`.
 *
 * NOTE(review): `Random`, `Zeros`, `hiddenSize` and `prevSize` are neither
 * imported nor declared in the visible part of this file; as written,
 * evaluating `new Random(...)` would throw a ReferenceError unless they are
 * provided elsewhere — confirm where they are meant to come from.
 * NOTE(review): the parameter order here is (settings, recurrentInput, input);
 * verify callers pass the arguments in this order.
 *
 * @param {*} settings - Not read by this function.
 * @param {*} recurrentInput - Multiplied by the `transition` parameters.
 * @param {*} input - Multiplied by the `weight` parameters.
 * @returns {*} Result of applying `relu` to the biased sum of both terms.
 */
export default (settings, recurrentInput, input) => {
  // wxh
  const weight = new Random(hiddenSize, prevSize, 0.08);
  // whh
  const transition = new Random(hiddenSize, hiddenSize, 0.08);
  // bhh
  const bias = new Zeros(hiddenSize, 1);

  const inputTerm = multiply(weight, input);
  const recurrentTerm = multiply(transition, recurrentInput);
  const preActivation = add(add(inputTerm, recurrentTerm), bias);

  return relu(preActivation);
};