-
Notifications
You must be signed in to change notification settings - Fork 0
/
Brain.js
54 lines (50 loc) · 1.62 KB
/
Brain.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/**
 * A small feed-forward network (one hidden dense layer) built on a
 * TensorFlow.js sequential model, with helpers for neuroevolution-style use:
 * forward prediction, copying, and random weight mutation.
 *
 * Depends on the global `tf` (TensorFlow.js). `mutate` additionally calls
 * global `random()` and `randomGaussian()` functions that are not defined
 * here — presumably p5.js globals (TODO confirm against the page setup).
 */
class Brain {
  /**
   * @param {number} inputLayers - Length of the input vector (the first layer's `inputShape`).
   * @param {number} hiddenLayers - Number of units in the single ReLU hidden layer.
   * @param {number} outputLayers - Number of units in the softmax output layer.
   */
  constructor(inputLayers, hiddenLayers, outputLayers) {
    this.inputLayers = inputLayers;
    this.hiddenLayers = hiddenLayers;
    this.outputLayers = outputLayers;
    this.model = tf.sequential();
    this.model.add(tf.layers.dense({ inputShape: [inputLayers], units: hiddenLayers, activation: 'relu' }));
    this.model.add(tf.layers.dense({ units: outputLayers, activation: 'softmax' }));
  }

  /**
   * Runs a forward pass on a single sample.
   *
   * @param {number[]} input - One sample; its length should equal `inputLayers`.
   * @returns {TypedArray} The `outputLayers` softmax outputs for this sample.
   */
  think = (input) => {
    // tf.tidy releases both the input tensor and the prediction tensor once
    // the values have been read; without it every call leaked two tensors.
    return tf.tidy(() => {
      const xs = tf.tensor2d([input]);
      const ys = this.model.predict(xs);
      // Copy defensively so the returned array stays valid after `ys` is disposed.
      return ys.dataSync().slice();
    });
  };

  /**
   * Creates a new Brain with the same layer sizes and a copy of this
   * model's current weights.
   *
   * @returns {Brain} The copy; release it with `clear()` when no longer needed.
   */
  clone = () => {
    return tf.tidy(() => {
      const brainCopy = new Brain(this.inputLayers, this.hiddenLayers, this.outputLayers);
      // The temporary clones are released by tidy once setWeights has consumed them.
      const weightCopies = this.model.getWeights().map((weight) => weight.clone());
      brainCopy.model.setWeights(weightCopies);
      return brainCopy;
    });
  };

  /**
   * Perturbs this model's weights in place. Each individual weight value is
   * selected when `random()` < `rate` (i.e. with probability `rate`, assuming
   * `random()` is uniform on [0, 1)), and a selected value has
   * `randomGaussian()` added to it.
   *
   * @param {number} rate - Per-weight mutation probability.
   */
  mutate = (rate) => {
    tf.tidy(() => {
      const mutatedWeights = this.model.getWeights().map((tensor) => {
        // Work on a copy of the values rather than the array dataSync() hands back.
        const values = tensor.dataSync().slice();
        for (let j = 0; j < values.length; j++) {
          if (rate > random()) {
            values[j] += randomGaussian();
          }
        }
        return tf.tensor(values, tensor.shape);
      });
      this.model.setWeights(mutatedWeights);
    });
  };

  /** Disposes the underlying model; this Brain should not be used afterwards. */
  clear = () => {
    this.model.dispose();
  };
}