Skip to content

Commit

Permalink
Connection gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 18, 2024
1 parent 1316f06 commit 9c24759
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/trainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,45 @@ export class Trainer {
return result;
}

stepBackwardAndFetchGradientsGroupedByConnection(): BackwardStepGradientsPerConnecrion | null {
// Recalculate the loss before each backward step
this.calculateLoss();

if (!this.currentLoss) {
console.error("Loss not calculated");
return null;
}

// Zero out existing gradients
this._network.zeroGrad();

// Perform backpropagation
this.currentLoss.backward();

const gradients: BackwardStepGradientsPerConnecrion = [];

// Iterate through each layer and neuron to collect gradients
this._network.layers.forEach((layer, layerIndex) => {
layer.neurons.forEach((neuron, neuronIndex) => {
neuron.w.forEach((weight, weightIndex) => {
const fromNodeId = layerIndex === 0 ? `input_${weightIndex}` : `layer${layerIndex - 1}_neuron${weightIndex}`;
const toNodeId = `layer${layerIndex}_neuron${neuronIndex}`;
const connectionId = `conn_${fromNodeId}_to_${toNodeId}`;

gradients.push({
connectionId,
weightGradient: weight.grad,
biasGradient: neuron.b.grad,
});
});
});
});

console.log('Gradients after backward pass:', gradients);

return gradients;
}

updateWeights(learningRate: number): TrainingStepResult {
const oldWeights = this._network.parameters().map(p => p.data);
this._network.parameters().forEach(p => {
Expand Down
6 changes: 6 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ export type BackwardStepGradients = {
gradients: number[];
}[];

export type BackwardStepGradientsPerConnecrion = {
connectionId: string,
weightGradient: number,
biasGradient: number,
}[];

export interface TrainingStepResult {
gradients: number[] | null;
oldWeights: number[] | null;
Expand Down

0 comments on commit 9c24759

Please sign in to comment.