Skip to content

Commit

Permalink
WIP: show gradients per conection
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 20, 2024
1 parent 3d8b7d8 commit 0a9e702
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 19 deletions.
22 changes: 7 additions & 15 deletions src/TrainingControl/TrainingStepsVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@ import { Component, For, Show, createMemo } from "solid-js";
import { css } from "@emotion/css";
import { FaSolidArrowRight, FaSolidCalculator, FaSolidArrowLeft, FaSolidArrowDown, FaSolidLayerGroup } from 'solid-icons/fa';
import WeightUpdateStep from './WeightUpdateStep';
import { TrainingStepResult } from "../types";
import { BackwardStepGradientsPerConnecrion, TrainingStepResult } from "../types";
import { colors } from '../styles/colors';

interface TrainingStepsVisualizerProps {
forwardStepResults: { input: number[], output: number[] }[];
backwardStepResults: {
neuron: number;
weights: number;
bias: number;
gradients: number[];
}[];
backwardStepResults: BackwardStepGradientsPerConnecrion[];
currentLoss: number | null;
weightUpdateResults: TrainingStepResult;
}
Expand Down Expand Up @@ -229,19 +224,16 @@ const TrainingStepsVisualizer: Component<TrainingStepsVisualizerProps> = (props)
<For each={props.backwardStepResults}>
{(element, neuronIndex) => (
<div class={styles.neuronGradients}>
<h4 class={styles.neuronLabel}>Neuron {neuronIndex()}</h4>
<h4 class={styles.neuronLabel}>Connection {element.connectionId}</h4>
<div class={styles.gradientGroup}>
<For each={element.gradients.slice(0, element.weights)}>
{(gradient, weightIndex) => (

<div class={styles.gradientItem}>
<span class={styles.gradientLabel}>W{weightIndex() + 1}:</span>
<span class={styles.gradientValue}>{gradient.toFixed(4)}</span>
<span class={styles.gradientLabel}>W:</span>
<span class={styles.gradientValue}>B: {element.weightGradient.toFixed(4)}</span>
</div>
)}
</For>
<div class={styles.gradientItem}>
<span class={styles.gradientLabel}>B:</span>
<span class={styles.gradientValue}>{element.gradients[element.gradients.length - 1].toFixed(4)}</span>
<span class={styles.gradientValue}>B: {element.biasGradient.toFixed(4)}</span>
</div>
</div>
</div>
Expand Down
4 changes: 2 additions & 2 deletions src/trainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export class Trainer {
return result;
}

stepBackwardAndGetGradientsGroupedByConnection(): BackwardStepGradientsPerConnecrion | null {
stepBackwardAndGetGradientsGroupedByConnection(): BackwardStepGradientsPerConnecrion[] {
// Recalculate the loss before each backward step
this.calculateLoss();

Expand All @@ -126,7 +126,7 @@ export class Trainer {
// Perform backpropagation
this.currentLoss.backward();

const gradients: BackwardStepGradientsPerConnecrion = [];
const gradients: BackwardStepGradientsPerConnecrion[] = [];

// Iterate through each layer and neuron to collect gradients
this._network.layers.forEach((layer, layerIndex) => {
Expand Down
4 changes: 2 additions & 2 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export type AppState = {
iteration: number,
currentLoss: null | number,
forwardStepResults: Prediction,
backwardStepGradients: BackwardStepGradientsPerConnecrion,
backwardStepGradients: BackwardStepGradientsPerConnecrion[],
weightUpdateResults: [],
lossHistory: [],
};
Expand Down Expand Up @@ -46,7 +46,7 @@ export type BackwardStepGradientsPerConnecrion = {
connectionId: string,
weightGradient: number,
biasGradient: number,
}[];
};

export interface TrainingStepResult {
gradients: number[] | null;
Expand Down

0 comments on commit 0a9e702

Please sign in to comment.