diff --git a/src/NeuralNetworkVisualizer/ConnectionSidebar.tsx b/src/NeuralNetworkVisualizer/ConnectionSidebar.tsx index 6ff38b4..6161fd6 100644 --- a/src/NeuralNetworkVisualizer/ConnectionSidebar.tsx +++ b/src/NeuralNetworkVisualizer/ConnectionSidebar.tsx @@ -1,4 +1,4 @@ -import { Component } from "solid-js"; +import { Component, Show } from "solid-js"; import { css } from "@emotion/css"; import { colors } from "../styles/colors"; import { VisualConnection } from "../types"; @@ -9,38 +9,40 @@ interface ConnectionSidebarProps { } const ConnectionSidebar: Component = (props) => { - if (!props.connection) return null; + return ( -
- -

Connection Details

- -
- From: {props.connection.from} -
-
- To: {props.connection.to} -
-
- Weight: {props.connection.weight.toFixed(4)} -
-
- Bias: {props.connection.bias.toFixed(4)} -
-
- Weight Gradient: {props.connection.weightGradient?.toFixed(4) || 'N/A'} -
-
- Bias Gradient: {props.connection.biasGradient?.toFixed(4) || 'N/A'} + +
+ +

Connection Details

+ +
+ From: {props.connection!.from} +
+
+ To: {props.connection!.to} +
+
+ Weight: {props.connection!.weight.toFixed(4)} +
+
+ Bias: {props.connection!.bias.toFixed(4)} +
+
+ Weight Gradient: {props.connection!.weightGradient?.toFixed(4) || 'N/A'} +
+
+ Bias Gradient: {props.connection!.biasGradient?.toFixed(4) || 'N/A'} +
-
+ ); } const styles = { sidebar: css` - position: absolute; + position: fixed; /* Changed from absolute to fixed */ right: 0; top: 0; width: 250px; @@ -50,7 +52,7 @@ const styles = { padding: 1rem; box-shadow: -2px 0 5px rgba(0,0,0,0.1); overflow-y: auto; - z-index: 1000; + z-index: 1001; /* Ensure it's above other elements */ `, closeButton: css` position: absolute; diff --git a/src/NeuralNetworkVisualizer/NetworkVisualizer.tsx b/src/NeuralNetworkVisualizer/NetworkVisualizer.tsx index cd68713..27a6d5e 100644 --- a/src/NeuralNetworkVisualizer/NetworkVisualizer.tsx +++ b/src/NeuralNetworkVisualizer/NetworkVisualizer.tsx @@ -8,7 +8,7 @@ import { canvasStyle, containerStyle, tooltipStyle } from "./NetworkVisualizerSt import { debounce } from "@solid-primitives/scheduled"; import { css } from "@emotion/css"; import { colors } from "../styles/colors"; - +import ConnectionSidebar from "./ConnectionSidebar"; interface NetworkVisualizerProps { includeLossNode: boolean; @@ -35,7 +35,7 @@ const NetworkVisualizer: Component = (props) => { let draggedNode: VisualNode | null = null; let mouseDownTimer: number | null = null; - let lastPanPosition: { x: number; y: 0 } = { x: 0, y: 0 }; + let lastPanPosition: { x: number; y: number } = { x: 0, y: 0 }; const visualData = createMemo(() => { const layoutCalculatorValue = layoutCalculator(); @@ -72,7 +72,6 @@ const NetworkVisualizer: Component = (props) => { } }); - const setupEventListeners = () => { const canvas = canvasRef(); if (canvas) { @@ -224,13 +223,14 @@ const NetworkVisualizer: Component = (props) => { const connection = rendererValue.getConnectionAtPoint(x, y); if (connection) { setSelectedConnection(connection); + props.onSidebarToggle(true); + rendererValue.highlightConnection(connection); } else { setSelectedConnection(null); props.onSidebarToggle(false); + rendererValue.clearHighlightedConnection(); } - - // Optionally handle neuron selection if needed }; const handleMouseUp = () => { @@ -278,6 +278,14 @@ const NetworkVisualizer: Component = (props) => { renderer()?.render(visualData(), null); }} /> + { + setSelectedConnection(null); + renderer()?.clearHighlightedConnection(); + props.onSidebarToggle(false); + }} + /> {(tooltipAccessor) => { const data = tooltipAccessor(); @@ -301,11 +309,8 @@ const NetworkVisualizer: Component = (props) => { ); }} - -
); }; - export default NetworkVisualizer; \ No newline at end of file diff --git a/src/NeuralNetworkVisualizer/renderer.ts b/src/NeuralNetworkVisualizer/renderer.ts index 44cf631..995f179 100644 --- a/src/NeuralNetworkVisualizer/renderer.ts +++ b/src/NeuralNetworkVisualizer/renderer.ts @@ -1,6 +1,7 @@ import { throttle } from "@solid-primitives/scheduled"; import { colors } from "../styles/colors"; import { Point, VisualConnection, VisualNetworkData, VisualNode } from "../types"; + export class NetworkRenderer { private ctx: CanvasRenderingContext2D; public scale: number = 1; @@ -10,13 +11,16 @@ export class NetworkRenderer { public nodeWidth: number; public nodeHeight: number; private highlightedNodeId: string | null = null; - lastRenderedData: VisualNetworkData | undefined; - lastRenderedSelectedNode: VisualNode | null; + private lastRenderedData: VisualNetworkData | undefined; + private lastRenderedSelectedNode: VisualNode | null = null; private onConnectionClick: (connection: VisualConnection) => void = () => { }; private connectionPositions: { connection: VisualConnection; path: Path2D }[] = []; private selectedConnection: VisualConnection | null = null; private readonly epsilon: number = 5; // pixels private connectionControlPoints: { connection: VisualConnection; p0: Point; p1: Point; p2: Point; p3: Point }[] = []; + private highlightedConnection: VisualConnection | null = null; + private pulse: number = 0; + private animationFrameId: number | null = null; constructor(private canvas: HTMLCanvasElement) { this.ctx = canvas.getContext('2d')!; @@ -25,19 +29,19 @@ export class NetworkRenderer { this.debouncedRender = throttle((data: VisualNetworkData | null, selectedNode: VisualNode | null) => { this._render(data, selectedNode); }, 16); // Debounce to ~60fps - this.lastRenderedSelectedNode = null; - } - render(data: VisualNetworkData, selectedNode: VisualNode | null) { - this.lastRenderedData = data; - this.lastRenderedSelectedNode = selectedNode; - this._render(data, selectedNode); + this.startAnimation(); + } - // Initialize event listeners once - // if (!this.eventListenersInitialized) { - // this.initializeEventListeners(); - // this.eventListenersInitialized = true; - // } + render(data?: VisualNetworkData, selectedNode?: VisualNode | null) { + if (data && selectedNode !== undefined) { + this.lastRenderedData = data; + this.lastRenderedSelectedNode = selectedNode; + } else if (!this.lastRenderedData) { + console.warn('No data available to render.'); + return; + } + this._render(this.lastRenderedData, this.lastRenderedSelectedNode); } setHighlightedNode(nodeId: string | null) { @@ -307,8 +311,32 @@ export class NetworkRenderer { const p2 = { x: toX - controlPointOffset, y: toY }; const p3 = { x: toX, y: toY }; - // Draw the curved arrow - this.drawCurvedArrow(p0.x, p0.y, p3.x, p3.y, this.getConnectionColor(conn.weight)); + // Set styles for connections + if (this.highlightedConnection && this.highlightedConnection === conn) { + // Pulsing Glow Effect + const glowIntensity = Math.abs(Math.sin(this.pulse)) * 15; // Varies between 0 and 15 + this.ctx.lineWidth = 4 + Math.abs(Math.sin(this.pulse)) * 2; // Pulsates line width between 4 and 6 + this.ctx.strokeStyle = colors.highlight; + this.ctx.shadowColor = colors.highlight; + this.ctx.shadowBlur = glowIntensity; + + // Moving Gradient Effect + const gradient = this.ctx.createLinearGradient(p0.x, p0.y, p3.x, p3.y); + const gradientPosition = (Math.sin(this.pulse) + 1) / 2; // Normalize to [0,1] + gradient.addColorStop(0, colors.gradientStart); + gradient.addColorStop(gradientPosition, colors.gradientEnd); + gradient.addColorStop(1, colors.gradientStart); + this.ctx.strokeStyle = gradient; + } else { + this.ctx.lineWidth = 1; + this.ctx.strokeStyle = this.getConnectionColor(conn.weight); + this.ctx.shadowBlur = 0; + } + + this.drawCurvedArrow(p0.x, p0.y, p3.x, p3.y); + + // Reset shadow after drawing + this.ctx.shadowBlur = 0; // Store control points for hit detection this.connectionControlPoints.push({ connection: conn, p0, p1, p2, p3 }); @@ -343,14 +371,10 @@ export class NetworkRenderer { this.ctx.restore(); } - private drawCurvedArrow(fromX: number, fromY: number, toX: number, toY: number, color: string) { + private drawCurvedArrow(fromX: number, fromY: number, toX: number, toY: number) { const headLength = 10; const controlPointOffset = Math.abs(toX - fromX) * 0.2; - this.ctx.strokeStyle = color; - this.ctx.fillStyle = color; - this.ctx.lineWidth = 2; - // Draw the curved line this.ctx.beginPath(); this.ctx.moveTo(fromX, fromY); @@ -399,7 +423,6 @@ export class NetworkRenderer { } } - private _render(data: VisualNetworkData, selectedNode: VisualNode | null) { this.clear(); this.ctx.save(); @@ -420,7 +443,6 @@ export class NetworkRenderer { this.onConnectionClick = callback; } - // Method to get connection at a specific point within epsilon radius getConnectionAtPoint(x: number, y: number): VisualConnection | null { for (const { connection, p0, p1, p2, p3 } of this.connectionControlPoints) { @@ -432,7 +454,6 @@ export class NetworkRenderer { return null; } - setSelectedConnection(connection: VisualConnection | null) { this.selectedConnection = connection; } @@ -471,4 +492,37 @@ export class NetworkRenderer { return minDist; } + + highlightConnection(connection: VisualConnection): void { + this.highlightedConnection = connection; + this.render(this.lastRenderedData!, this.lastRenderedSelectedNode); + } + + clearHighlightedConnection(): void { + this.highlightedConnection = null; + this.render(this.lastRenderedData!, this.lastRenderedSelectedNode); + } + + // Start the animation loop + private startAnimation() { + const animate = () => { + this.pulse += 0.05; // Increment pulse + this.render(this.lastRenderedData, this.lastRenderedSelectedNode); + this.animationFrameId = requestAnimationFrame(animate); + }; + this.animationFrameId = requestAnimationFrame(animate); + } + + // Stop the animation loop + private stopAnimation() { + if (this.animationFrameId !== null) { + cancelAnimationFrame(this.animationFrameId); + this.animationFrameId = null; + } + } + + // Clean up when the renderer is no longer needed + destroy() { + this.stopAnimation(); + } } \ No newline at end of file diff --git a/src/styles/colors.ts b/src/styles/colors.ts index 1a416b6..b277217 100644 --- a/src/styles/colors.ts +++ b/src/styles/colors.ts @@ -17,4 +17,11 @@ export const colors = { textDark: 'var(--text-light)', // Example value success: '#28a745', // Add this line + + // Add the highlight color + highlight: '#FFD700', // Gold color for highlighting + + // Add gradient colors + gradientStart: '#FFD700', // Start of the gradient (same as highlight) + gradientEnd: '#FFA500', // End of the gradient (Orange) }; \ No newline at end of file