Skip to content

Commit

Permalink
Connection sidebar
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 18, 2024
1 parent a726925 commit 3621810
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 57 deletions.
54 changes: 28 additions & 26 deletions src/NeuralNetworkVisualizer/ConnectionSidebar.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Component } from "solid-js";
import { Component, Show } from "solid-js";
import { css } from "@emotion/css";
import { colors } from "../styles/colors";
import { VisualConnection } from "../types";
Expand All @@ -9,38 +9,40 @@ interface ConnectionSidebarProps {
}

const ConnectionSidebar: Component<ConnectionSidebarProps> = (props) => {
if (!props.connection) return null;


return (
<div class={styles.sidebar}>
<button class={styles.closeButton} onClick={props.onClose}>×</button>
<h2 class={styles.title}>Connection Details</h2>

<div class={styles.detail}>
<strong>From:</strong> {props.connection.from}
</div>
<div class={styles.detail}>
<strong>To:</strong> {props.connection.to}
</div>
<div class={styles.detail}>
<strong>Weight:</strong> {props.connection.weight.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Bias:</strong> {props.connection.bias.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Weight Gradient:</strong> {props.connection.weightGradient?.toFixed(4) || 'N/A'}
</div>
<div class={styles.detail}>
<strong>Bias Gradient:</strong> {props.connection.biasGradient?.toFixed(4) || 'N/A'}
<Show when={props.connection}>
<div class={styles.sidebar}>
<button class={styles.closeButton} onClick={props.onClose}>×</button>
<h2 class={styles.title}>Connection Details</h2>

<div class={styles.detail}>
<strong>From:</strong> {props.connection!.from}
</div>
<div class={styles.detail}>
<strong>To:</strong> {props.connection!.to}
</div>
<div class={styles.detail}>
<strong>Weight:</strong> {props.connection!.weight.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Bias:</strong> {props.connection!.bias.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Weight Gradient:</strong> {props.connection!.weightGradient?.toFixed(4) || 'N/A'}
</div>
<div class={styles.detail}>
<strong>Bias Gradient:</strong> {props.connection!.biasGradient?.toFixed(4) || 'N/A'}
</div>
</div>
</div>
</Show>
);
}

const styles = {
sidebar: css`
position: absolute;
position: fixed; /* Changed from absolute to fixed */
right: 0;
top: 0;
width: 250px;
Expand All @@ -50,7 +52,7 @@ const styles = {
padding: 1rem;
box-shadow: -2px 0 5px rgba(0,0,0,0.1);
overflow-y: auto;
z-index: 1000;
z-index: 1001; /* Ensure it's above other elements */
`,
closeButton: css`
position: absolute;
Expand Down
21 changes: 13 additions & 8 deletions src/NeuralNetworkVisualizer/NetworkVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { canvasStyle, containerStyle, tooltipStyle } from "./NetworkVisualizerSt
import { debounce } from "@solid-primitives/scheduled";
import { css } from "@emotion/css";
import { colors } from "../styles/colors";

import ConnectionSidebar from "./ConnectionSidebar";

interface NetworkVisualizerProps {
includeLossNode: boolean;
Expand All @@ -35,7 +35,7 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {

let draggedNode: VisualNode | null = null;
let mouseDownTimer: number | null = null;
let lastPanPosition: { x: number; y: 0 } = { x: 0, y: 0 };
let lastPanPosition: { x: number; y: number } = { x: 0, y: 0 };

const visualData = createMemo(() => {
const layoutCalculatorValue = layoutCalculator();
Expand Down Expand Up @@ -72,7 +72,6 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
}
});


const setupEventListeners = () => {
const canvas = canvasRef();
if (canvas) {
Expand Down Expand Up @@ -224,13 +223,14 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
const connection = rendererValue.getConnectionAtPoint(x, y);
if (connection) {
setSelectedConnection(connection);

props.onSidebarToggle(true);
rendererValue.highlightConnection(connection);
} else {
setSelectedConnection(null);
props.onSidebarToggle(false);
rendererValue.clearHighlightedConnection();
}

// Optionally handle neuron selection if needed
};

const handleMouseUp = () => {
Expand Down Expand Up @@ -278,6 +278,14 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
renderer()?.render(visualData(), null);
}}
/>
<ConnectionSidebar
connection={selectedConnection()}
onClose={() => {
setSelectedConnection(null);
renderer()?.clearHighlightedConnection();
props.onSidebarToggle(false);
}}
/>
<Show when={tooltipData()}>
{(tooltipAccessor) => {
const data = tooltipAccessor();
Expand All @@ -301,11 +309,8 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
);
}}
</Show>


</div>
);
};


export default NetworkVisualizer;
100 changes: 77 additions & 23 deletions src/NeuralNetworkVisualizer/renderer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { throttle } from "@solid-primitives/scheduled";
import { colors } from "../styles/colors";
import { Point, VisualConnection, VisualNetworkData, VisualNode } from "../types";

export class NetworkRenderer {
private ctx: CanvasRenderingContext2D;
public scale: number = 1;
Expand All @@ -10,13 +11,16 @@ export class NetworkRenderer {
public nodeWidth: number;
public nodeHeight: number;
private highlightedNodeId: string | null = null;
lastRenderedData: VisualNetworkData | undefined;
lastRenderedSelectedNode: VisualNode | null;
private lastRenderedData: VisualNetworkData | undefined;
private lastRenderedSelectedNode: VisualNode | null = null;
private onConnectionClick: (connection: VisualConnection) => void = () => { };
private connectionPositions: { connection: VisualConnection; path: Path2D }[] = [];
private selectedConnection: VisualConnection | null = null;
private readonly epsilon: number = 5; // pixels
private connectionControlPoints: { connection: VisualConnection; p0: Point; p1: Point; p2: Point; p3: Point }[] = [];
private highlightedConnection: VisualConnection | null = null;
private pulse: number = 0;
private animationFrameId: number | null = null;

constructor(private canvas: HTMLCanvasElement) {
this.ctx = canvas.getContext('2d')!;
Expand All @@ -25,19 +29,19 @@ export class NetworkRenderer {
this.debouncedRender = throttle((data: VisualNetworkData | null, selectedNode: VisualNode | null) => {
this._render(data, selectedNode);
}, 16); // Debounce to ~60fps
this.lastRenderedSelectedNode = null;
}

render(data: VisualNetworkData, selectedNode: VisualNode | null) {
this.lastRenderedData = data;
this.lastRenderedSelectedNode = selectedNode;
this._render(data, selectedNode);
this.startAnimation();
}

// Initialize event listeners once
// if (!this.eventListenersInitialized) {
// this.initializeEventListeners();
// this.eventListenersInitialized = true;
// }
render(data?: VisualNetworkData, selectedNode?: VisualNode | null) {
if (data && selectedNode !== undefined) {
this.lastRenderedData = data;
this.lastRenderedSelectedNode = selectedNode;
} else if (!this.lastRenderedData) {
console.warn('No data available to render.');
return;
}
this._render(this.lastRenderedData, this.lastRenderedSelectedNode);
}

setHighlightedNode(nodeId: string | null) {
Expand Down Expand Up @@ -307,8 +311,32 @@ export class NetworkRenderer {
const p2 = { x: toX - controlPointOffset, y: toY };
const p3 = { x: toX, y: toY };

// Draw the curved arrow
this.drawCurvedArrow(p0.x, p0.y, p3.x, p3.y, this.getConnectionColor(conn.weight));
// Set styles for connections
if (this.highlightedConnection && this.highlightedConnection === conn) {
// Pulsing Glow Effect
const glowIntensity = Math.abs(Math.sin(this.pulse)) * 15; // Varies between 0 and 15
this.ctx.lineWidth = 4 + Math.abs(Math.sin(this.pulse)) * 2; // Pulsates line width between 4 and 6
this.ctx.strokeStyle = colors.highlight;
this.ctx.shadowColor = colors.highlight;
this.ctx.shadowBlur = glowIntensity;

// Moving Gradient Effect
const gradient = this.ctx.createLinearGradient(p0.x, p0.y, p3.x, p3.y);
const gradientPosition = (Math.sin(this.pulse) + 1) / 2; // Normalize to [0,1]
gradient.addColorStop(0, colors.gradientStart);
gradient.addColorStop(gradientPosition, colors.gradientEnd);
gradient.addColorStop(1, colors.gradientStart);
this.ctx.strokeStyle = gradient;
} else {
this.ctx.lineWidth = 1;
this.ctx.strokeStyle = this.getConnectionColor(conn.weight);
this.ctx.shadowBlur = 0;
}

this.drawCurvedArrow(p0.x, p0.y, p3.x, p3.y);

// Reset shadow after drawing
this.ctx.shadowBlur = 0;

// Store control points for hit detection
this.connectionControlPoints.push({ connection: conn, p0, p1, p2, p3 });
Expand Down Expand Up @@ -343,14 +371,10 @@ export class NetworkRenderer {
this.ctx.restore();
}

private drawCurvedArrow(fromX: number, fromY: number, toX: number, toY: number, color: string) {
private drawCurvedArrow(fromX: number, fromY: number, toX: number, toY: number) {
const headLength = 10;
const controlPointOffset = Math.abs(toX - fromX) * 0.2;

this.ctx.strokeStyle = color;
this.ctx.fillStyle = color;
this.ctx.lineWidth = 2;

// Draw the curved line
this.ctx.beginPath();
this.ctx.moveTo(fromX, fromY);
Expand Down Expand Up @@ -399,7 +423,6 @@ export class NetworkRenderer {
}
}


private _render(data: VisualNetworkData, selectedNode: VisualNode | null) {
this.clear();
this.ctx.save();
Expand All @@ -420,7 +443,6 @@ export class NetworkRenderer {
this.onConnectionClick = callback;
}


// Method to get connection at a specific point within epsilon radius
getConnectionAtPoint(x: number, y: number): VisualConnection | null {
for (const { connection, p0, p1, p2, p3 } of this.connectionControlPoints) {
Expand All @@ -432,7 +454,6 @@ export class NetworkRenderer {
return null;
}


setSelectedConnection(connection: VisualConnection | null) {
this.selectedConnection = connection;
}
Expand Down Expand Up @@ -471,4 +492,37 @@ export class NetworkRenderer {

return minDist;
}

highlightConnection(connection: VisualConnection): void {
this.highlightedConnection = connection;
this.render(this.lastRenderedData!, this.lastRenderedSelectedNode);
}

clearHighlightedConnection(): void {
this.highlightedConnection = null;
this.render(this.lastRenderedData!, this.lastRenderedSelectedNode);
}

// Start the animation loop
private startAnimation() {
const animate = () => {
this.pulse += 0.05; // Increment pulse
this.render(this.lastRenderedData, this.lastRenderedSelectedNode);
this.animationFrameId = requestAnimationFrame(animate);
};
this.animationFrameId = requestAnimationFrame(animate);
}

// Stop the animation loop
private stopAnimation() {
if (this.animationFrameId !== null) {
cancelAnimationFrame(this.animationFrameId);
this.animationFrameId = null;
}
}

// Clean up when the renderer is no longer needed
destroy() {
this.stopAnimation();
}
}
7 changes: 7 additions & 0 deletions src/styles/colors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,11 @@ export const colors = {
textDark: 'var(--text-light)', // Example value

success: '#28a745', // Add this line

// Add the highlight color
highlight: '#FFD700', // Gold color for highlighting

// Add gradient colors
gradientStart: '#FFD700', // Start of the gradient (same as highlight)
gradientEnd: '#FFA500', // End of the gradient (Orange)
};

0 comments on commit 3621810

Please sign in to comment.