Skip to content

Commit

Permalink
uniform connectionId
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 20, 2024
1 parent 4e281ee commit a50e291
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 70 deletions.
3 changes: 3 additions & 0 deletions src/NeuralNetwork/MLP.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ export class MLP {
});
});

// Ensure connection IDs are regenerated consistently
// Assuming connections are handled in the layout generation

return newMLP;
}

Expand Down
29 changes: 17 additions & 12 deletions src/NeuralNetworkVisualizer/ConnectionSidebar.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Component, createMemo, Show } from "solid-js";
import { css } from "@emotion/css";
import { colors } from "../styles/colors";
import { VisualConnection } from "../types";
import { store } from "../store";

interface ConnectionSidebarProps {
Expand All @@ -11,10 +10,17 @@ interface ConnectionSidebarProps {

const ConnectionSidebar: Component<ConnectionSidebarProps> = (props) => {
const connectionObject = createMemo(() => {
if (!props.connection) return null;
const connections = store.trainingState.backwardStepGradients;
return connections.find((c) => c.connectionId === props.connection);
})
const propsConnectionId = props.connection;
if (!propsConnectionId) return null;
const visualConnectionData = store.visualData.connections.find((c) => c.id === propsConnectionId);
const connectionGradients = store.trainingState.backwardStepGradients.find((c) => c.connectionId === propsConnectionId);
console.log('here connectionGradients', {
connectionGradients: store.trainingState.backwardStepGradients,
connectionId: propsConnectionId,
});
return { ...visualConnectionData, ...connectionGradients };
});

return (
<Show when={connectionObject()}>
{(connection) => (
Expand All @@ -23,24 +29,23 @@ const ConnectionSidebar: Component<ConnectionSidebarProps> = (props) => {
<h2 class={styles.title}>Connection Details</h2>

<div class={styles.detail}>
<strong>From:</strong> {connectionObject().from}
<strong>From:</strong> {connectionObject()?.from}
</div>
<div class={styles.detail}>
<strong>To:</strong> {connectionObject().to}
<strong>To:</strong> {connectionObject()?.to}
</div>
<div class={styles.detail}>
<strong>Weight:</strong> {connectionObject().weight.toFixed(4)}
<strong>Weight:</strong> {connectionObject()?.weight?.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Weight Gradient:</strong> {connection().weightGradient?.toFixed(4) || 'N/A'}
<strong>Weight Gradient:</strong> {connectionObject()?.weightGradient?.toFixed(4) || 'N/A'}
</div>
<div class={styles.detail}>
<strong>Bias:</strong> {connection().bias.toFixed(4)}
<strong>Bias:</strong> {connectionObject()?.bias?.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Bias Gradient:</strong> {connection().biasGradient?.toFixed(4)}
<strong>Bias Gradient:</strong> {connectionObject()?.biasGradient?.toFixed(4) || 'N/A'}
</div>

</div>
)}
</Show>
Expand Down
15 changes: 1 addition & 14 deletions src/NeuralNetworkVisualizer/NeuronInfoSidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ interface NeuronInfoSidebarProps {
}

const NeuronInfoSidebar: Component<NeuronInfoSidebarProps> = (props) => {

createEffect(() => {
if (props.neuron) {
renderActivationFunctionChart();
Expand Down Expand Up @@ -126,19 +125,7 @@ const NeuronInfoSidebar: Component<NeuronInfoSidebarProps> = (props) => {
ax: 0,
ay: 40,
font: { color: '#FFFFFF' }
},
// {
// x: xMin,
// y: neuronOutput,
// xref: 'x',
// yref: 'y',
// text: `Output: ${neuronOutput.toFixed(4)}`,
// showarrow: true,
// arrowhead: 4,
// ax: 0,
// ay: 40,
// font: { color: '#FFFFFF' }
// }
}
],
shapes: [
{
Expand Down
18 changes: 9 additions & 9 deletions src/NeuralNetworkVisualizer/layout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ export class NetworkLayout {
const startY = (this.canvasHeight - totalHeight) / 2;
for (let i = 0; i < inputSize; i++) {
nodes.push({
id: `input_${i}`,
label: 'ChatGPT Usage',
id: `neuron_-1_${i}`, // Standardized ID
label: `Input ${i}`,
layerId: 'input',
x: this.inputValuesSpacing,
y: startY + i * (this.nodeHeight + this.nodeSpacing),
Expand Down Expand Up @@ -83,19 +83,19 @@ export class NetworkLayout {
// Connect input nodes to first layer
for (let i = 0; i < inputSize; i++) {
connections.push({
id: `conn_${nodeId}_to_${neuron.id}`, // Unique and consistent ID
from: `input_${i}`,
to: neuron.id,
id: `from_neuron_-1_${i}_to_neuron_${layerIndex}_${neuronIndex}`, // Standardized format
from: `neuron_-1_${i}`,
to: nodeId,
weight: neuron.weights[i],
bias: neuron.bias
});
}
} else {
network.layers[layerIndex - 1].neurons.forEach((prevNeuron, prevIndex) => {
connections.push({
id: `conn_${prevNeuron.id}_to_${neuron.id}`, // Unique and consistent ID
from: prevNeuron.id,
to: neuron.id,
id: `from_neuron_${layerIndex - 1}_${prevIndex}_to_neuron_${layerIndex}_${neuronIndex}`, // Standardized format
from: `neuron_${layerIndex - 1}_${prevIndex}`,
to: nodeId,
weight: neuron.weights[prevIndex],
bias: neuron.bias
});
Expand All @@ -114,7 +114,7 @@ export class NetworkLayout {
const layerIndex = parseInt(layerIndexStr);
const nodeIndex = parseInt(nodeIndexStr);

if (nodeType === 'input' && input[nodeIndex] !== undefined) {
if (nodeType === 'neuron' && layerIndex === -1 && input[nodeIndex] !== undefined) {
node.outputValue = input[nodeIndex];
} else if (nodeType === 'neuron') {
if (layerOutputs[layerIndex] && layerOutputs[layerIndex][nodeIndex] !== undefined) {
Expand Down
10 changes: 7 additions & 3 deletions src/NeuralNetworkVisualizer/renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ export class NetworkRenderer {
private lastRenderedSelectedNode: VisualNode | null = null;
private onConnectionClick: (connection: VisualConnection) => void = () => { };
private connectionControlPoints: { connection: VisualConnection; p0: Point; p1: Point; p2: Point; p3: Point }[] = [];
private selectedConnection: VisualConnection | null = null;
private readonly epsilon: number = 5; // pixels
private labelBoundingBoxes: { connection: VisualConnection; rect: { x: number; y: number; width: number; height: number } }[] = [];

Expand Down Expand Up @@ -302,8 +301,13 @@ export class NetworkRenderer {
this.connectionControlPoints = [];

connections.forEach(conn => {
const fromNode = nodes.find(n => n.id === conn.from)!;
const toNode = nodes.find(n => n.id === conn.to)!;
const fromNode = nodes.find(n => n.id === conn.from);
const toNode = nodes.find(n => n.id === conn.to);

if (!fromNode || !toNode) {
console.error(`Node not found for connection: ${conn.id}`);
return; // Skip this connection
}

const fromX = fromNode.x + this.nodeWidth;
const fromY = fromNode.y + this.nodeHeight / 2;
Expand Down
12 changes: 2 additions & 10 deletions src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ function stepBackward() {
}

console.log("Calling trainer.stepBackward()");
let result;
let result: BackwardStepGradientsPerConnection[];
try {
result = store.trainer.stepBackward();
result = store.trainer.stepBackwardAndGetGradientsGroupedByConnection();
} catch (error) {
console.error("Error in stepBackward:", error);
return;
Expand All @@ -197,7 +197,6 @@ function stepBackward() {
if (result && Array.isArray(result)) {
console.log("Updating store with result");


batch(() => {
setStore('trainingState', 'currentPhase', 'backward');
setStore('trainingState', 'backwardStepGradients', result);
Expand All @@ -223,13 +222,6 @@ function updateWeights() {
setStore('trainingState', 'currentPhase', 'idle');
setStore('networkUpdateTrigger', store.networkUpdateTrigger + 1);


// setStore('trainingState', {
// forwardStepResults: [],
// backwardStepGradients: [],
// weightUpdateResults: [],
// });

console.log("Weights updated successfully");
});
}
Expand Down
22 changes: 12 additions & 10 deletions src/trainer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { MLP } from "./NeuralNetwork/mlp";
import { Value } from "./NeuralNetwork/value";
import { BackwardStepGradients, BackwardStepGradientsPerConnecrion, Prediction, TrainingConfig, TrainingStepResult } from "./types";
import { BackwardStepGradients, BackwardStepGradientsPerConnection, Prediction, TrainingConfig, TrainingStepResult } from "./types";

export class Trainer {
_network: MLP;
Expand All @@ -14,7 +14,7 @@ export class Trainer {
private currentBatchInputs: number[][] = [];
private currentBatchTargets: number[] = [];

stepBackward: () => BackwardStepGradientsPerConnecrion[];
stepBackward: () => BackwardStepGradientsPerConnection[];

constructor(network: MLP, config: TrainingConfig) {
this._network = network.clone();
Expand Down Expand Up @@ -115,12 +115,12 @@ export class Trainer {
return result;
}

stepBackwardAndGetGradientsGroupedByConnection(): BackwardStepGradientsPerConnecrion[] {

stepBackwardAndGetGradientsGroupedByConnection(): BackwardStepGradientsPerConnection[] {
// Recalculate the loss before each backward step
this.calculateLoss();

if (!this.currentLoss) {
// Recalculate the loss before each backward step
this.calculateLoss();
console.error("Loss not calculated");
return [];
}

Expand All @@ -130,15 +130,17 @@ export class Trainer {
// Perform backpropagation
this.currentLoss.backward();

const gradients: BackwardStepGradientsPerConnecrion[] = [];
const gradients: BackwardStepGradientsPerConnection[] = [];

// Iterate through each layer and neuron to collect gradients
this._network.layers.forEach((layer, layerIndex) => {
layer.neurons.forEach((neuron, neuronIndex) => {
neuron.w.forEach((weight, weightIndex) => {
const fromNodeId = layerIndex === 0 ? `input_${weightIndex}` : `layer${layerIndex - 1}_neuron${weightIndex}`;
const toNodeId = `layer${layerIndex}_neuron${neuronIndex}`;
const connectionId = `conn_${fromNodeId}_to_${toNodeId}`;
const fromNodeId = layerIndex === 0
? `neuron_-1_${weightIndex}` // Consistent with input node IDs
: `neuron_${layerIndex - 1}_${weightIndex}`;
const toNodeId = `neuron_${layerIndex}_${neuronIndex}`;
const connectionId = `from_neuron_${layerIndex - 1}_${weightIndex}_to_neuron_${layerIndex}_${neuronIndex}`; // Standardized format

gradients.push({
connectionId,
Expand Down
19 changes: 7 additions & 12 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export type AppState = {
iteration: number,
currentLoss: null | number,
forwardStepResults: Prediction,
backwardStepGradients: BackwardStepGradientsPerConnecrion[],
backwardStepGradients: BackwardStepGradientsPerConnection[],
weightUpdateResults: [],
lossHistory: [],
};
Expand All @@ -35,18 +35,13 @@ export type AppState = {

type TrainingRun = any

export type BackwardStepGradients = {
neuron: number;
weights: number;
bias: number;
gradients: number[];
}[];
export interface BackwardStepGradientsPerConnection {
connectionId: string;
weightGradient: number;
biasGradient: number;
}

export type BackwardStepGradientsPerConnecrion = {
connectionId: string,
weightGradient: number,
biasGradient: number,
};
export type BackwardStepGradients = BackwardStepGradientsPerConnection[];

export interface TrainingStepResult {
gradients: number[] | null;
Expand Down

0 comments on commit a50e291

Please sign in to comment.