Skip to content

Commit

Permalink
Batch reset
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 22, 2024
1 parent 59397c4 commit 958a3ed
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
11 changes: 9 additions & 2 deletions src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,6 @@ function trainingStateReset() {
setStore('trainingState', {
forwardStepResults: [],
backwardStepGradients: [],
lossHistory: [],
currentLoss: null,
weightUpdateResults: [],
currentPhase: 'idle'
});
Expand Down Expand Up @@ -281,6 +279,15 @@ function updateWeights() {
setStore('networkUpdateTrigger', store.networkUpdateTrigger + 1);

console.log("Action: Weights updated successfully");

// ** Reset Batch After Weight Update **
store.trainer.resetBatch();
batch(() => {
setStore('trainingState', 'forwardStepResults', []);
setStore('trainingState', 'backwardStepGradients', []);
setStore('trainingState', 'currentPhase', 'idle');
});
console.log('Batch has been reset for the next training cycle.');
});
}

Expand Down
7 changes: 7 additions & 0 deletions src/trainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,11 @@ export class Trainer {
return this.currentIteration;
}

resetBatch(): void {
this.currentBatchInputs = [];
this.currentBatchTargets = [];
this.currentDataIndex = 0;
console.log('Trainer batch has been reset.');
}

}
12 changes: 3 additions & 9 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,12 @@ export interface BackwardStepGradientsPerConnection {
export type BackwardStepGradients = BackwardStepGradientsPerConnection[];

export interface TrainingStepResult {
gradients: number[] | null;
oldWeights: number[] | null;
newWeights: number[] | null;
gradients: BackwardStepGradientsPerConnection[] | null;
oldWeights: number[];
newWeights: number[];
}

export interface ForwardStepResults {
input: number[];
output: number[];
}

export interface TrainingStepResult {
gradients: number[] | null;
oldWeights: number[] | null;
newWeights: number[] | null;
}

0 comments on commit 958a3ed

Please sign in to comment.