Skip to content

Commit

Permalink
Batch reset
Browse files — browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 22, 2024
1 parent 59397c4 commit fec9ef3
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 147 deletions.
31 changes: 14 additions & 17 deletions src/FunctionVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,78 +17,75 @@ const FunctionVisualizer: Component = () => {
const { xs, ys } = store.trainingData;

// Generate points for the true function
const trueX = Array.from({ length: 100 }, (_, i) => i);
const trueX = Array.from({ length: 100 }, (_, i) => i / 100); // 0 to 1 with step 0.01
const trueY = trueX.map(getTrueFunction);

// Generate points for the learned function
const learnedY = trueX.map(x => {
// Assuming store.network.forward accepts a regular number input
// and returns a regular number output
const output = store.network.forward([x]);
return output[0].data;
});

// Prepare data for the neural network predictions
const nnX = xs.map(x => x[0]);
const nnY = ys;
const nnX = xs.map(x => x[0]); // Already in 0-1 range
const nnY = ys; // Already in 0-1 range

const data = [
{
x: trueX,
y: trueY,
type: 'scatter',
mode: 'lines+text', // Include text labels
mode: 'lines+markers', // Changed from 'lines+text' to 'lines+markers' for better visibility
name: 'True Function',
line: { color: colors.primary, width: 3 },

marker: { color: colors.primary, size: 6 },
hoverinfo: 'x+y',
},
{
x: nnX,
y: nnY,
type: 'scatter',
mode: 'markers+text', // Include text labels
mode: 'markers', // Changed from 'markers+text' to 'markers' to reduce clutter
name: 'Training Data',
marker: { color: colors.error, size: 8 },

hoverinfo: 'x+y',
},
{
x: trueX,
y: learnedY,
type: 'scatter',
mode: 'lines+text', // Include text labels
mode: 'lines',
name: 'Learned Function',
line: { color: colors.success, width: 3, dash: 'dash' },
visible: showLearnedFunction() ? true : 'legendonly',

hoverinfo: 'x+y',
}
];

const layout = {
xaxis: {
title: 'ChatGPT Usage (%)',
range: [0, 100],
title: 'ChatGPT Usage (0-1)',
range: [0, 1],
gridcolor: colors.border,
zerolinecolor: colors.border,
tickformat: '.2f',
},
yaxis: {
title: 'Productivity Score',
range: [0, 100],
title: 'Productivity Score (0-1)',
range: [0, 1],
gridcolor: colors.border,
zerolinecolor: colors.border,
tickformat: '.2f',
},
legend: {
x: 1,
xanchor: 'right',
y: 1,

bordercolor: colors.border,
borderwidth: 1,
},
hovermode: 'closest',
plot_bgcolor: '#1B213D', // Ensure this matches the desired background color
plot_bgcolor: '#1B213D',
paper_bgcolor: '#1B213D',
font: {
family: typography.fontFamily,
Expand Down
9 changes: 3 additions & 6 deletions src/NeuralNetwork/MLP.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,16 @@ export class MLP {
};
const newMLP = new MLP(config);

// Copy weights and biases
// Deep copy weights and biases
this.layers.forEach((layer, i) => {
layer.neurons.forEach((neuron, j) => {
neuron.w.forEach((w, k) => {
newMLP.layers[i].neurons[j].w[k].data = w.data;
newMLP.layers[i].neurons[j].w[k] = new Value(w.data, [], 'weight');
});
newMLP.layers[i].neurons[j].b.data = neuron.b.data;
newMLP.layers[i].neurons[j].b = new Value(neuron.b.data, [], 'bias');
});
});

// Ensure connection IDs are regenerated consistently
// Assuming connections are handled in the layout generation

return newMLP;
}

Expand Down
2 changes: 1 addition & 1 deletion src/NeuralNetwork/Neuron.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export class Neuron {

// Creates a neuron with `nin` inputs. Weights are drawn uniformly from
// [-1, 1); the bias gets a small random value in [-0.05, 0.05) so neurons
// do not start perfectly symmetric.
// (Removed a dead store: a stale `this.b = new Value(0)` line — a diff
// artifact — was immediately overwritten by the random initialization.)
constructor(nin: number, activation: ActivationFunction = 'tanh') {
this.w = Array(nin).fill(0).map(() => new Value(Math.random() * 2 - 1));
this.b = new Value(Math.random() * 0.1 - 0.05); // Initialize bias with a small random value
this.activation = activation;
console.log(`Neuron created with ${nin} inputs and ${this.activation} activation`);
}
Expand Down
2 changes: 1 addition & 1 deletion src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export const CONFIG = {
INITIAL_NETWORK: {
inputSize: 1,
layers: [5, 3, 1],
activations: ['tanh', 'leaky-relu', 'identity']
activations: ['leaky-relu', 'leaky-relu', 'identity']
} as MLPConfig,
INITIAL_TRAINING: {
learningRate: 0.005
Expand Down
158 changes: 79 additions & 79 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,82 +1,82 @@
import { MLP } from "./NeuralNetwork/mlp";
import { Value } from "./NeuralNetwork/value";

// Example usage
const xs: number[][] = [
[2.0, 3.0, -1.0],
[3.0, -1.0, 0.5],
[0.5, 1.0, 1.0],
[1.0, 1.0, -1.0],
];
const yt: number[] = [1.0, -1.0, -1.0, 1.0];

// Create MLP
const n = new MLP({
inputSize: 1,
layers: [4, 4, 1],
activations: ['tanh', 'tanh', 'tanh']
});

// Hyperparameters
const learningRate = 0.01;
const iterations = 100;
const batchSize = 2;

// Training loop
for (let iteration = 0; iteration < iterations; iteration++) {
let totalLoss = new Value(0);

// Mini-batch training
for (let i = 0; i < xs.length; i += batchSize) {
const batchXs = xs.slice(i, i + batchSize);
const batchYt = yt.slice(i, i + batchSize);

const ypred = batchXs.map(x => n.forward(x.map(val => new Value(val)))[0]);

const loss = ypred.reduce((sum, ypred_el, j) => {
const target = new Value(batchYt[j]);
const diff = ypred_el.sub(target);
const squaredError = diff.mul(diff);
return sum.add(squaredError);
}, new Value(0));

// Accumulate total loss
totalLoss = totalLoss.add(loss);

// Backward pass
n.zeroGrad();
loss.backward();

// Update parameters
n.parameters().forEach(p => {
p.data -= learningRate * p.grad;
});

// Inside the training loop, after calculating the loss
console.log("Loss function tree:");
console.log(loss.toDot());
}

// Log average loss for the iteration
console.log(`Iteration ${iteration + 1}, Average Loss: ${totalLoss.data / xs.length}`);

// Early stopping (optional)
if (totalLoss.data / xs.length < 0.01) {
console.log(`Converged at iteration ${iteration + 1}`);
break;
}
}

// Evaluation
function evaluate(x: number[]): number {
const result = n.forward(x.map(val => new Value(val)));
return result[0].data;
}

console.log("Evaluation:");
xs.forEach((x, i) => {
console.log(`Input: [${x}], Predicted: ${evaluate(x).toFixed(4)}, Actual: ${yt[i]}`);
});
// import { MLP } from "./NeuralNetwork/mlp";
// import { Value } from "./NeuralNetwork/value";

// // Example usage
// const xs: number[][] = [
// [2.0, 3.0, -1.0],
// [3.0, -1.0, 0.5],
// [0.5, 1.0, 1.0],
// [1.0, 1.0, -1.0],
// ];
// const yt: number[] = [1.0, -1.0, -1.0, 1.0];

// // Create MLP
// const n = new MLP({
// inputSize: 1,
// layers: [4, 4, 1],
// activations: ['tanh', 'tanh', 'tanh']
// });

// // Hyperparameters
// const learningRate = 0.01;
// const iterations = 100;
// const batchSize = 2;

// // Training loop
// for (let iteration = 0; iteration < iterations; iteration++) {
// let totalLoss = new Value(0);

// // Mini-batch training
// for (let i = 0; i < xs.length; i += batchSize) {
// const batchXs = xs.slice(i, i + batchSize);
// const batchYt = yt.slice(i, i + batchSize);

// const ypred = batchXs.map(x => n.forward(x.map(val => new Value(val)))[0]);

// const loss = ypred.reduce((sum, ypred_el, j) => {
// const target = new Value(batchYt[j]);
// const diff = ypred_el.sub(target);
// const squaredError = diff.mul(diff);
// return sum.add(squaredError);
// }, new Value(0));

// // Accumulate total loss
// totalLoss = totalLoss.add(loss);

// // Backward pass
// n.zeroGrad();
// loss.backward();

// // Update parameters
// n.parameters().forEach(p => {
// p.data -= learningRate * p.grad;
// });

// // Inside the training loop, after calculating the loss
// console.log("Loss function tree:");
// console.log(loss.toDot());
// }

// // Log average loss for the iteration
// console.log(`Iteration ${iteration + 1}, Average Loss: ${totalLoss.data / xs.length}`);

// // Early stopping (optional)
// if (totalLoss.data / xs.length < 0.01) {
// console.log(`Converged at iteration ${iteration + 1}`);
// break;
// }
// }

// // Evaluation
// function evaluate(x: number[]): number {
// const result = n.forward(x.map(val => new Value(val)));
// return result[0].data;
// }

// console.log("Evaluation:");
// xs.forEach((x, i) => {
// console.log(`Input: [${x}], Predicted: ${evaluate(x).toFixed(4)}, Actual: ${yt[i]}`);
// });



13 changes: 10 additions & 3 deletions src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,6 @@ function trainingStateReset() {
setStore('trainingState', {
forwardStepResults: [],
backwardStepGradients: [],
lossHistory: [],
currentLoss: null,
weightUpdateResults: [],
currentPhase: 'idle'
});
Expand Down Expand Up @@ -217,7 +215,7 @@ function calculateLoss() {
}

const result = store.trainer.calculateLoss();
console.log("After calling trainer.calculateLoss(). Result:", result);
// console.log("After calling trainer.calculateLoss(). Result:", result);

let currentLoss: number;
if (result === null) {
Expand Down Expand Up @@ -281,6 +279,15 @@ function updateWeights() {
setStore('networkUpdateTrigger', store.networkUpdateTrigger + 1);

console.log("Action: Weights updated successfully");

// ** Reset Batch After Weight Update **
store.trainer.resetBatch();
batch(() => {
setStore('trainingState', 'forwardStepResults', []);
setStore('trainingState', 'backwardStepGradients', []);
setStore('trainingState', 'currentPhase', 'idle');
});
console.log('Batch has been reset for the next training cycle.');
});
}

Expand Down
9 changes: 7 additions & 2 deletions src/trainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,7 @@ export class Trainer {
totalLoss = totalLoss.add(squaredDiff);
}

// Label the batch size value
const batchSizeValue = new Value(inputs.length, []);

const avgLoss = totalLoss.div(batchSizeValue);
console.log(`Total Loss: ${totalLoss.data}, Avg Loss: ${avgLoss.data}`);
return avgLoss;
Expand Down Expand Up @@ -249,4 +247,11 @@ export class Trainer {
return this.currentIteration;
}

/** Discards all accumulated batch state so the next training cycle starts fresh. */
resetBatch(): void {
  // Rewind the data cursor before clearing the buffers; the order of these
  // independent assignments does not matter.
  this.currentDataIndex = 0;
  this.currentBatchInputs = [];
  this.currentBatchTargets = [];
  console.log('Trainer batch has been reset.');
}

}
12 changes: 3 additions & 9 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,12 @@ export interface BackwardStepGradientsPerConnection {
// Gradients produced by one backward step — one entry per network connection.
export type BackwardStepGradients = BackwardStepGradientsPerConnection[];

// Result of a single weight-update step.
// Cleaned up a diff-rendering artifact: the scraped text contained the
// pre-change members (number[] gradients, nullable weight arrays) duplicated
// alongside the new ones — invalid TypeScript (duplicate identifiers) — plus
// a stale second declaration of this same interface whose members conflicted
// under declaration merging. One authoritative declaration remains.
export interface TrainingStepResult {
  // Per-connection gradients from the backward pass, or null when no
  // backward pass has produced gradients yet.
  gradients: BackwardStepGradientsPerConnection[] | null;
  oldWeights: number[];
  newWeights: number[];
}

// Input/output pair captured from a single forward pass.
export interface ForwardStepResults {
  input: number[];
  output: number[];
}
Loading

0 comments on commit fec9ef3

Please sign in to comment.