Skip to content

Commit

Permalink
Connection sidebar
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 18, 2024
1 parent a726925 commit 184bf19
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 53 deletions.
54 changes: 28 additions & 26 deletions src/NeuralNetworkVisualizer/ConnectionSidebar.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Component } from "solid-js";
import { Component, Show } from "solid-js";
import { css } from "@emotion/css";
import { colors } from "../styles/colors";
import { VisualConnection } from "../types";
Expand All @@ -9,38 +9,40 @@ interface ConnectionSidebarProps {
}

const ConnectionSidebar: Component<ConnectionSidebarProps> = (props) => {
if (!props.connection) return null;


return (
<div class={styles.sidebar}>
<button class={styles.closeButton} onClick={props.onClose}>×</button>
<h2 class={styles.title}>Connection Details</h2>

<div class={styles.detail}>
<strong>From:</strong> {props.connection.from}
</div>
<div class={styles.detail}>
<strong>To:</strong> {props.connection.to}
</div>
<div class={styles.detail}>
<strong>Weight:</strong> {props.connection.weight.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Bias:</strong> {props.connection.bias.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Weight Gradient:</strong> {props.connection.weightGradient?.toFixed(4) || 'N/A'}
</div>
<div class={styles.detail}>
<strong>Bias Gradient:</strong> {props.connection.biasGradient?.toFixed(4) || 'N/A'}
<Show when={props.connection}>
<div class={styles.sidebar}>
<button class={styles.closeButton} onClick={props.onClose}>×</button>
<h2 class={styles.title}>Connection Details</h2>

<div class={styles.detail}>
<strong>From:</strong> {props.connection!.from}
</div>
<div class={styles.detail}>
<strong>To:</strong> {props.connection!.to}
</div>
<div class={styles.detail}>
<strong>Weight:</strong> {props.connection!.weight.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Bias:</strong> {props.connection!.bias.toFixed(4)}
</div>
<div class={styles.detail}>
<strong>Weight Gradient:</strong> {props.connection!.weightGradient?.toFixed(4) || 'N/A'}
</div>
<div class={styles.detail}>
<strong>Bias Gradient:</strong> {props.connection!.biasGradient?.toFixed(4) || 'N/A'}
</div>
</div>
</div>
</Show>
);
}

const styles = {
sidebar: css`
position: absolute;
position: fixed; /* Changed from absolute to fixed */
right: 0;
top: 0;
width: 250px;
Expand All @@ -50,7 +52,7 @@ const styles = {
padding: 1rem;
box-shadow: -2px 0 5px rgba(0,0,0,0.1);
overflow-y: auto;
z-index: 1000;
z-index: 1001; /* Ensure it's above other elements */
`,
closeButton: css`
position: absolute;
Expand Down
21 changes: 13 additions & 8 deletions src/NeuralNetworkVisualizer/NetworkVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { canvasStyle, containerStyle, tooltipStyle } from "./NetworkVisualizerSt
import { debounce } from "@solid-primitives/scheduled";
import { css } from "@emotion/css";
import { colors } from "../styles/colors";

import ConnectionSidebar from "./ConnectionSidebar";

interface NetworkVisualizerProps {
includeLossNode: boolean;
Expand All @@ -35,7 +35,7 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {

let draggedNode: VisualNode | null = null;
let mouseDownTimer: number | null = null;
let lastPanPosition: { x: number; y: 0 } = { x: 0, y: 0 };
let lastPanPosition: { x: number; y: number } = { x: 0, y: 0 };

const visualData = createMemo(() => {
const layoutCalculatorValue = layoutCalculator();
Expand Down Expand Up @@ -72,7 +72,6 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
}
});


const setupEventListeners = () => {
const canvas = canvasRef();
if (canvas) {
Expand Down Expand Up @@ -224,13 +223,14 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
const connection = rendererValue.getConnectionAtPoint(x, y);
if (connection) {
setSelectedConnection(connection);

props.onSidebarToggle(true);
rendererValue.highlightConnection(connection);
} else {
setSelectedConnection(null);
props.onSidebarToggle(false);
rendererValue.clearHighlightedConnection();
}

// Optionally handle neuron selection if needed
};

const handleMouseUp = () => {
Expand Down Expand Up @@ -278,6 +278,14 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
renderer()?.render(visualData(), null);
}}
/>
<ConnectionSidebar
connection={selectedConnection()}
onClose={() => {
setSelectedConnection(null);
renderer()?.clearHighlightedConnection();
props.onSidebarToggle(false);
}}
/>
<Show when={tooltipData()}>
{(tooltipAccessor) => {
const data = tooltipAccessor();
Expand All @@ -301,11 +309,8 @@ const NetworkVisualizer: Component<NetworkVisualizerProps> = (props) => {
);
}}
</Show>


</div>
);
};


export default NetworkVisualizer;
51 changes: 32 additions & 19 deletions src/NeuralNetworkVisualizer/renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ export class NetworkRenderer {
public nodeWidth: number;
public nodeHeight: number;
private highlightedNodeId: string | null = null;
lastRenderedData: VisualNetworkData | undefined;
lastRenderedSelectedNode: VisualNode | null;
private lastRenderedData: VisualNetworkData | undefined;
private lastRenderedSelectedNode: VisualNode | null = null;
private onConnectionClick: (connection: VisualConnection) => void = () => { };
private connectionPositions: { connection: VisualConnection; path: Path2D }[] = [];
private selectedConnection: VisualConnection | null = null;
private readonly epsilon: number = 5; // pixels
private connectionControlPoints: { connection: VisualConnection; p0: Point; p1: Point; p2: Point; p3: Point }[] = [];
private highlightedConnection: VisualConnection | null = null;

constructor(private canvas: HTMLCanvasElement) {
this.ctx = canvas.getContext('2d')!;
Expand All @@ -25,19 +26,17 @@ export class NetworkRenderer {
this.debouncedRender = throttle((data: VisualNetworkData | null, selectedNode: VisualNode | null) => {
this._render(data, selectedNode);
}, 16); // Debounce to ~60fps
this.lastRenderedSelectedNode = null;
}

render(data: VisualNetworkData, selectedNode: VisualNode | null) {
this.lastRenderedData = data;
this.lastRenderedSelectedNode = selectedNode;
this._render(data, selectedNode);

// Initialize event listeners once
// if (!this.eventListenersInitialized) {
// this.initializeEventListeners();
// this.eventListenersInitialized = true;
// }
render(data?: VisualNetworkData, selectedNode?: VisualNode | null) {
if (data && selectedNode !== undefined) {
this.lastRenderedData = data;
this.lastRenderedSelectedNode = selectedNode;
} else if (!this.lastRenderedData) {
console.warn('No data available to render.');
return;
}
this._render(this.lastRenderedData, this.lastRenderedSelectedNode);
}

setHighlightedNode(nodeId: string | null) {
Expand Down Expand Up @@ -308,7 +307,15 @@ export class NetworkRenderer {
const p3 = { x: toX, y: toY };

// Draw the curved arrow
this.drawCurvedArrow(p0.x, p0.y, p3.x, p3.y, this.getConnectionColor(conn.weight));
if (this.highlightedConnection && this.highlightedConnection === conn) {
this.ctx.lineWidth = 3;
this.ctx.strokeStyle = colors.highlight;
} else {
this.ctx.lineWidth = 1;
this.ctx.strokeStyle = this.getConnectionColor(conn.weight);
}

this.drawCurvedArrow(p0.x, p0.y, p3.x, p3.y);

// Store control points for hit detection
this.connectionControlPoints.push({ connection: conn, p0, p1, p2, p3 });
Expand Down Expand Up @@ -343,14 +350,10 @@ export class NetworkRenderer {
this.ctx.restore();
}

private drawCurvedArrow(fromX: number, fromY: number, toX: number, toY: number, color: string) {
private drawCurvedArrow(fromX: number, fromY: number, toX: number, toY: number) {
const headLength = 10;
const controlPointOffset = Math.abs(toX - fromX) * 0.2;

this.ctx.strokeStyle = color;
this.ctx.fillStyle = color;
this.ctx.lineWidth = 2;

// Draw the curved line
this.ctx.beginPath();
this.ctx.moveTo(fromX, fromY);
Expand Down Expand Up @@ -471,4 +474,14 @@ export class NetworkRenderer {

return minDist;
}

highlightConnection(connection: VisualConnection): void {
this.highlightedConnection = connection;
this.render();
}

clearHighlightedConnection(): void {
this.highlightedConnection = null;
this.render();
}
}

0 comments on commit 184bf19

Please sign in to comment.