Skip to content

Commit

Permalink
Clickable weights
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 18, 2024
1 parent 232bc43 commit 76aa3a1
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 88 deletions.
4 changes: 2 additions & 2 deletions src/NeuralNetworkVisualizer/ConnectionSidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const ConnectionSidebar: Component<ConnectionSidebarProps> = (props) => {

const styles = {
sidebar: css`
position: fixed; /* Changed from absolute to fixed */
position: fixed;
right: 0;
top: 0;
width: 250px;
Expand All @@ -49,7 +49,7 @@ const styles = {
padding: 1rem;
box-shadow: -2px 0 5px rgba(0,0,0,0.1);
overflow-y: auto;
z-index: 1001; /* Ensure it's above other elements */
z-index: 1001;
`,
closeButton: css`
position: absolute;
Expand Down
66 changes: 37 additions & 29 deletions src/NeuralNetworkVisualizer/NetworkVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ import { Component, createEffect, onCleanup, createSignal, createMemo, Show, bat
import { NetworkRenderer } from "./renderer";
import NeuronInfoSidebar from "./NeuronInfoSidebar";
import { store } from "../store";
import { VisualConnection, VisualNode } from "../types";
import { VisualNode } from "../types";
import { useCanvasSetup } from "./useCanvasSetup";
import { canvasStyle, containerStyle, tooltipStyle } from "./NetworkVisualizerStyles";
import { debounce } from "@solid-primitives/scheduled";
import { css } from "@emotion/css";
import { colors } from "../styles/colors";
import ConnectionSidebar from "./ConnectionSidebar";

interface NetworkVisualizerProps {
Expand Down Expand Up @@ -76,7 +75,7 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
const rendererValue = renderer();
const selected = selectedNeuron();
if (rendererValue) {
rendererValue.render(visualData(), selected);
rendererValue.render(visualData(), selected, selectedConnectionId());
}
});

Expand Down Expand Up @@ -203,6 +202,7 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
const layoutCalculatorValue = layoutCalculator();
if (!layoutCalculatorValue) return;

// Check if hovering over a node
const hoveredNode = layoutCalculatorValue.findNodeAt(
x,
y,
Expand All @@ -212,9 +212,24 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
rendererValue.offsetY
);

if (hoveredNode) {
// Check if hovering over a connection or its label
const hoveredConnection = rendererValue.getConnectionAtPoint(x, y);

if (hoveredNode || hoveredConnection) {
canvasRef()!.style.cursor = 'pointer';
showTooltip(x, y, `Node: ${hoveredNode.label}\nOutput: ${hoveredNode.outputValue?.toFixed(4) || 'N/A'}`);
if (hoveredNode) {
showTooltip(
x,
y,
`Node: ${hoveredNode.label}\nOutput: ${hoveredNode.outputValue?.toFixed(4) || 'N/A'}`
);
} else if (hoveredConnection) {
showTooltip(
x,
y,
`Connection: ${hoveredConnection.from}${hoveredConnection.to}\nWeight: ${hoveredConnection.weight.toFixed(4)}`
);
}
} else {
canvasRef()!.style.cursor = 'grab';
hideTooltip();
Expand Down Expand Up @@ -243,22 +258,21 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {

const handleMouseUp = () => {
const canvas = canvasRef();
if (canvas) {
batch(() => {
if (mouseDownTimer !== null) {
clearTimeout(mouseDownTimer);
if (draggedNode) {
setSelectedNeuron(draggedNode);
renderer()?.render(visualData(), draggedNode); // Trigger re-render with selected node
}

if (mouseDownTimer !== null) {
clearTimeout(mouseDownTimer);
if (draggedNode) {
setSelectedNeuron(draggedNode);
renderer()?.render(visualData(), draggedNode); // Trigger re-render with selected node
}
setIsPanning(false);
draggedNode = null;
mouseDownTimer = null;
lastPanPosition = { x: 0, y: 0 };
});
canvas.style.cursor = 'grab';
}
}
setIsPanning(false);
draggedNode = null;
mouseDownTimer = null;
lastPanPosition = { x: 0, y: 0 };
if (canvas) {
canvas.style.cursor = 'grab';
}
};

return (
Expand Down Expand Up @@ -301,16 +315,10 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
<div
class={css`
${tooltipStyle}
background-color: ${colors.surface};
color: ${colors.text};
border: 1px solid ${colors.border};
box-shadow: 0 2px 4px ${colors.shadow};
left: ${data.x + 10}px;
top: ${data.y + 10}px;
display: block;
`}
style={{
left: `${data.x}px`,
top: `${data.y}px`,
display: 'block'
}}
>
{data.text}
</div>
Expand Down
128 changes: 72 additions & 56 deletions src/NeuralNetworkVisualizer/renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,25 @@ export class NetworkRenderer {
public scale: number = 1;
public offsetX: number = 0;
public offsetY: number = 0;
private debouncedRender: (data: VisualNetworkData, selectedNode: VisualNode | null) => void;
private debouncedRender: (data: VisualNetworkData, selectedNode: VisualNode | null, highlightedConnectionId: string | null) => void;
public nodeWidth: number;
public nodeHeight: number;
private highlightedNodeId: string | null = null;
private lastRenderedData: VisualNetworkData | undefined;
private lastRenderedSelectedNode: VisualNode | null = null;
private onConnectionClick: (connection: VisualConnection) => void = () => { };
private connectionPositions: { connection: VisualConnection; path: Path2D }[] = [];
private connectionControlPoints: { connection: VisualConnection; p0: Point; p1: Point; p2: Point; p3: Point }[] = [];
private selectedConnection: VisualConnection | null = null;
private readonly epsilon: number = 5; // pixels
private connectionControlPoints: { connection: VisualConnection; p0: Point; p1: Point; p2: Point; p3: Point }[] = [];
private highlightedConnectionId: string | null = null;
private pulse: number = 0;
private animationFrameId: number | null = null;
private labelBoundingBoxes: { connection: VisualConnection; rect: { x: number; y: number; width: number; height: number } }[] = [];

constructor(private canvas: HTMLCanvasElement) {
this.ctx = canvas.getContext('2d')!;
this.nodeWidth = 60; // or whatever default value you prefer
this.nodeHeight = 40; // or whatever default value you prefer
this.debouncedRender = throttle((data: VisualNetworkData | null, selectedNode: VisualNode | null) => {
this._render(data, selectedNode);
this.nodeWidth = 60;
this.nodeHeight = 40;
this.debouncedRender = throttle((data: VisualNetworkData, selectedNode: VisualNode | null, highlightedConnectionId: string | null) => {
this._render(data, selectedNode, highlightedConnectionId);
}, 16); // Debounce to ~60fps

this.startAnimation();
}

render(data?: VisualNetworkData, selectedNode?: VisualNode | null) {
Expand All @@ -41,12 +36,12 @@ export class NetworkRenderer {
console.warn('No data available to render.');
return;
}
this._render(this.lastRenderedData, this.lastRenderedSelectedNode);
this._render(this.lastRenderedData, this.lastRenderedSelectedNode, this.highlightedConnectionId);
}

setHighlightedNode(nodeId: string | null) {
this.highlightedNodeId = nodeId;
this.debouncedRender(this.lastRenderedData, this.lastRenderedSelectedNode);
this.debouncedRender(this.lastRenderedData, this.lastRenderedSelectedNode, this.highlightedConnectionId);
}

pan(dx: number, dy: number) {
Expand Down Expand Up @@ -242,7 +237,7 @@ export class NetworkRenderer {
this.ctx.lineTo(biasX, biasY + 8);
this.ctx.lineTo(biasX - 8, biasY);
this.ctx.closePath();
this.ctx.fillStyle = node.bias >= 0 ? colors.secondary : colors.textLight; // Use different color for negative bias
this.ctx.fillStyle = node.bias >= 0 ? colors.secondary : colors.textLight;
this.ctx.fill();
this.ctx.strokeStyle = colors.border;
this.ctx.lineWidth = 1;
Expand All @@ -258,7 +253,13 @@ export class NetworkRenderer {
// Add "Bias" label
this.ctx.fillStyle = '#fff';
this.ctx.font = '7px Arial';
this.ctx.fillText('Bias', biasX, biasY + 15);
const biasLabel = 'Bias';
const textWidth = this.ctx.measureText(biasLabel).width;
const padding = 2;
this.ctx.fillText(biasLabel, biasX, biasY + 15);

// **Do not store the bounding box for bias labels**
// This ensures that clicking on bias labels does not trigger the ConnectionSidebar
}

private drawOutputValue(node: VisualNode) {
Expand Down Expand Up @@ -293,8 +294,14 @@ export class NetworkRenderer {
this.ctx.fillText('Output', outputX, outputY + 15);
}

private drawConnections(connections: VisualConnection[], nodes: VisualNode[]) {
this.connectionControlPoints = []; // Reset on each render
private drawConnections(
connections: VisualConnection[],
nodes: VisualNode[],
highlightedConnectionId: string | null
) {
this.labelBoundingBoxes = []; // Reset on each render
this.connectionControlPoints = [];

connections.forEach(conn => {
const fromNode = nodes.find(n => n.id === conn.from)!;
const toNode = nodes.find(n => n.id === conn.to)!;
Expand All @@ -312,21 +319,11 @@ export class NetworkRenderer {
const p3 = { x: toX, y: toY };

// Set styles for connections
if (this.highlightedConnectionId && this.highlightedConnectionId === conn.id) {
// Pulsing Glow Effect
const glowIntensity = Math.abs(Math.sin(this.pulse)) * 15; // Varies between 0 and 15
this.ctx.lineWidth = 4 + Math.abs(Math.sin(this.pulse)) * 2; // Pulsates line width between 4 and 6
if (highlightedConnectionId && highlightedConnectionId === conn.id) {
this.ctx.lineWidth = 4;
this.ctx.strokeStyle = colors.highlight;
this.ctx.shadowColor = colors.highlight;
this.ctx.shadowBlur = glowIntensity;

// Moving Gradient Effect
const gradient = this.ctx.createLinearGradient(p0.x, p0.y, p3.x, p3.y);
const gradientPosition = (Math.sin(this.pulse) + 1) / 2; // Normalize to [0,1]
gradient.addColorStop(0, colors.gradientStart);
gradient.addColorStop(gradientPosition, colors.gradientEnd);
gradient.addColorStop(1, colors.gradientStart);
this.ctx.strokeStyle = gradient;
this.ctx.shadowBlur = 15;
} else {
this.ctx.lineWidth = 1;
this.ctx.strokeStyle = this.getConnectionColor(conn.weight);
Expand All @@ -349,17 +346,28 @@ export class NetworkRenderer {
const offsetY = -10;

// Draw weight label
this.drawLabel(midX, midY + offsetY, `x${conn.weight.toFixed(2)}`, conn.weight);
this.drawLabel(midX, midY + offsetY, `x${conn.weight.toFixed(2)}`, conn.weight, conn);
});

// Optionally, draw bias labels separately if needed
nodes.forEach(node => {
this.drawBias(node);
});
}

private drawLabel(x: number, y: number, text: string, weight: number) {
private drawLabel(
x: number,
y: number,
text: string,
weight: number,
connection: VisualConnection
) {
this.ctx.save();
this.ctx.font = '9px Arial Bold';
const textWidth = this.ctx.measureText(text).width;
const padding = 2;

// Semi-transparent background
// Semi-transparent background based on weight
this.ctx.fillStyle = weight >= 0 ? 'rgba(0, 255, 0, 0.6)' : 'rgba(255, 0, 0, 0.6)';
this.ctx.fillRect(x - textWidth / 2 - padding, y - 7 - padding, textWidth + padding * 2, 14 + padding * 2);

Expand All @@ -368,7 +376,19 @@ export class NetworkRenderer {
this.ctx.textAlign = 'center';
this.ctx.textBaseline = 'middle';
this.ctx.fillText(text, x, y);

this.ctx.restore();

// Store the bounding box for hit detection (only for weight labels)
this.labelBoundingBoxes.push({
connection,
rect: {
x: x - textWidth / 2 - padding,
y: y - 7 - padding,
width: textWidth + padding * 2,
height: 14 + padding * 2,
},
});
}

private drawCurvedArrow(fromX: number, fromY: number, toX: number, toY: number) {
Expand Down Expand Up @@ -423,12 +443,12 @@ export class NetworkRenderer {
}
}

private _render(data: VisualNetworkData, selectedNode: VisualNode | null) {
private _render(data: VisualNetworkData, selectedNode: VisualNode | null, highlightedConnectionId: string | null) {
this.clear();
this.ctx.save();
this.ctx.translate(this.offsetX, this.offsetY);
this.ctx.scale(this.scale, this.scale);
this.drawConnections(data.connections, data.nodes);
this.drawConnections(data.connections, data.nodes, highlightedConnectionId);
this.drawNodes(data.nodes, selectedNode);
this.ctx.restore();
}
Expand All @@ -443,14 +463,28 @@ export class NetworkRenderer {
this.onConnectionClick = callback;
}

// Method to get connection at a specific point within epsilon radius
// Adjusted getConnectionAtPoint to exclude bias labels
getConnectionAtPoint(x: number, y: number): VisualConnection | null {
// Check proximity to connection lines
for (const { connection, p0, p1, p2, p3 } of this.connectionControlPoints) {
const distance = this.calculateDistanceToBezier(x, y, p0, p1, p2, p3);
if (distance <= this.epsilon) {
return connection;
}
}

// Check if click is within any weight label bounding box
for (const { connection, rect } of this.labelBoundingBoxes) {
if (
x >= rect.x &&
x <= rect.x + rect.width &&
y >= rect.y &&
y <= rect.y + rect.height
) {
return connection;
}
}

return null;
}

Expand Down Expand Up @@ -503,26 +537,8 @@ export class NetworkRenderer {
this.render(this.lastRenderedData!, this.lastRenderedSelectedNode);
}

// Start the animation loop
private startAnimation() {
const animate = () => {
this.pulse += 0.05; // Increment pulse
this.render(this.lastRenderedData, this.lastRenderedSelectedNode);
this.animationFrameId = requestAnimationFrame(animate);
};
this.animationFrameId = requestAnimationFrame(animate);
}

// Stop the animation loop
private stopAnimation() {
if (this.animationFrameId !== null) {
cancelAnimationFrame(this.animationFrameId);
this.animationFrameId = null;
}
}

// Clean up when the renderer is no longer needed
destroy() {
this.stopAnimation();
// No cleanup needed
}
}
2 changes: 1 addition & 1 deletion src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export const CONFIG = {
INITIAL_NETWORK: {
inputSize: 1,
layers: [5, 3, 1],
activations: ['leaky-relu',, 'leaky-relu', 'identity']
activations: ['leaky-relu', 'leaky-relu', 'identity']
} as MLPConfig,
INITIAL_TRAINING: {
learningRate: 0.005
Expand Down

0 comments on commit 76aa3a1

Please sign in to comment.