diff --git a/src/components/drawing.tsx b/src/components/drawing.tsx index ce9dddf..7ca093d 100644 --- a/src/components/drawing.tsx +++ b/src/components/drawing.tsx @@ -1,7 +1,7 @@ import React, { useCallback, useEffect, useRef, useState } from "react"; import { nanoid } from "nanoid"; -import { DrawingMode, Graph, Point, RubberBand } from "./graph"; +import { DrawingMode, Graph, Point, RubberBand, Transform } from "./graph"; import { Edge, GraphData, Node } from "../type"; import { DragIcon } from "./drawing/drag-icon"; @@ -38,6 +38,7 @@ export const Drawing = (props: Props) => { const [selectedNodeForModal, setSelectedNodeForModal] = useState(undefined); const widthRef = useRef(0); const heightRef = useRef(0); + const transformRef = useRef(); const setSelectedNodeId = useCallback((id?: string, skipToggle?: boolean) => { if (drawingMode === "select") { @@ -50,11 +51,16 @@ export const Drawing = (props: Props) => { heightRef.current = height; }; + const handleTransformed = (transform: Transform) => { + transformRef.current = transform; + }; + const translateToGraphPoint = (e: MouseEvent|React.MouseEvent): Point => { // the offsets were determined visually to put the state centered on the mouse + const {x, y, k} = transformRef.current ?? {x: 0, y: 0, k: 1}; return { - x: e.clientX - 50 - (widthRef.current / 2), - y: e.clientY - 10 - (heightRef.current / 2), + x: ((e.clientX - 50 - (widthRef.current / 2)) - x) / k, + y: ((e.clientY - 10 - (heightRef.current / 2)) - y) / k, }; }; @@ -106,7 +112,6 @@ export const Drawing = (props: Props) => { */ const addNode = useCallback(({x, y}: {x: number, y: number}) => { - console.log("ADD NODE"); setGraph(prev => { const id = nanoid(); const label = `State ${prev.nodes.length + 1}`; @@ -279,6 +284,7 @@ export const Drawing = (props: Props) => { onDragStop={handleDragStop} setSelectedNodeId={setSelectedNodeId} onDimensions={handleDimensionChange} + onTransformed={handleTransformed} /> void; onDragStop?: (id: string, pos: Point) => void; onDimensions?: (dimensions: {width: number, height: number}) => void; + onTransformed?: (transform: Transform) => void; setSelectedNodeId: (id?: string, skipToggle?: boolean) => void; }; @@ -237,7 +240,7 @@ export const Graph = (props: Props) => { const {graph, highlightNode, highlightLoopOnNode, highlightEdge, highlightAllNextNodes, allowDragging, autoArrange, rubberBand, drawingMode, onClick, onNodeClick, onNodeDoubleClick, onEdgeClick, onDragStop, - selectedNodeId, setSelectedNodeId, animating, onDimensions} = props; + selectedNodeId, setSelectedNodeId, animating, onDimensions, onTransformed} = props; const svgRef = useRef(null); const wrapperRef = useRef(null); const dimensions = useResizeObserver(wrapperRef); @@ -247,8 +250,9 @@ export const Graph = (props: Props) => { const lastClickTimeRef = useRef(undefined); const lastClickIdRef = useRef(undefined); const draggedRef = useRef(false); + const transformRef = useRef(undefined); - const highlightSelected = useCallback((svg: d3.Selection) => { + const highlightSelected = useCallback((svg: d3.Selection) => { if (animating || !selectedNodeId) { return; } @@ -260,7 +264,7 @@ export const Graph = (props: Props) => { // highlight selected node svg - .selectAll("g") + .selectAll("g.node") .selectAll("ellipse") .style("opacity", unselectedOpacity) .filter((d: any) => connectedNodeIds.includes(d.id)) @@ -305,7 +309,7 @@ export const Graph = (props: Props) => { // highlight text svg - .selectAll("g") + .selectAll("g.node") .selectAll("text") .style("opacity", unselectedOpacity) .filter((d: any) => connectedNodeIds.includes(d.id)) @@ -373,6 +377,22 @@ export const Graph = (props: Props) => { } }, [d3Graph, graph]); + // enable zooming + useEffect(() => { + const svg = d3.select(svgRef.current); + + // add zoom + const zoom = d3.zoom() + //.scaleExtent([0.5, 10]) // This defines the zoom levels (min, max) + .on("zoom", (e) => { + const root = svg.select("g.root"); + root.attr("transform", e.transform); + transformRef.current = e.transform; + onTransformed?.(e.transform); + }); + svg.call(zoom as any); + }); + // draw the graph useEffect(() => { if (!svgRef.current) { @@ -385,8 +405,15 @@ export const Graph = (props: Props) => { // clear the existing items svg.selectAll("*").remove(); + // add the root element + const root = svg + .append("g") + .attr("class", "root") + .attr("transform", transformRef.current) + ; + const addArrowMarker = (id: string, color: string, opacity?: number) => { - svg + root .append("svg:defs") .append("svg:marker") .attr("id", id) @@ -415,11 +442,12 @@ export const Graph = (props: Props) => { addArrowMarker("unselectedLoopArrow", "black", unselectedOpacity / lineAndLoopOpacity); // 0 // draw nodes - const nodes = svg - .selectAll("g") + const nodes = root + .selectAll("g.node") .data(d3Graph.nodes) .enter() - .append("g"); + .append("g") + .attr("class", "node"); const dragStart = (d: any) => { simulation?.alphaTarget(0.5).restart(); @@ -565,7 +593,7 @@ export const Graph = (props: Props) => { ].join(" "); // draw backgrounds for edges to increase click area - const lineBackgrounds = svg + const lineBackgrounds = root .selectAll("line.edge-background") .data(d3Graph.edges) .enter() @@ -585,7 +613,7 @@ export const Graph = (props: Props) => { ; // draw edges - const lines = svg + const lines = root .selectAll("line.edge") .data(d3Graph.edges) .enter() @@ -609,7 +637,7 @@ export const Graph = (props: Props) => { const loopStyle = drawingMode === "delete" ? "cursor: pointer" : "pointer-events: none"; - const loopBackgrounds = svg + const loopBackgrounds = root .selectAll("path.loop-background") .data(d3Graph.nodes.filter(n => n.loops)) .enter() @@ -626,7 +654,7 @@ export const Graph = (props: Props) => { }) ; - const loops = svg + const loops = root .selectAll("path.loop") .data(d3Graph.nodes.filter(n => n.loops)) .enter() @@ -647,7 +675,7 @@ export const Graph = (props: Props) => { const rubberBandNode = d3Graph.nodes.find(n => n.id === rubberBand?.from); if (rubberBand && rubberBandNode) { const data = [{x1: rubberBandNode.x, x2: rubberBand.to.x, y1: rubberBandNode.y, y2: rubberBand.to.y}]; - svg + root .selectAll("line.rubberband") .data(data) .enter() @@ -664,7 +692,7 @@ export const Graph = (props: Props) => { // add loopback "ghost" with background if (!rubberBandNode.loops) { - svg + root .selectAll("path.ghost-loop-background") .data([rubberBandNode]) .enter() @@ -680,7 +708,7 @@ export const Graph = (props: Props) => { onNodeClick?.(rubberBandNode.id); }) ; - svg + root .selectAll("path.ghost-loop") .data([rubberBandNode]) .enter() @@ -704,7 +732,8 @@ export const Graph = (props: Props) => { highlightSelected(svg); }, [svgRef, d3Graph, allowDragging, autoArrange, rubberBand, drawingMode, - onNodeClick, onNodeDoubleClick, onEdgeClick, onDragStop, setSelectedNodeId, selectedNodeId, highlightSelected]); + onNodeClick, onNodeDoubleClick, onEdgeClick, onDragStop, setSelectedNodeId, + selectedNodeId, highlightSelected]); // animate the node if needed useEffect(() => { @@ -713,22 +742,23 @@ export const Graph = (props: Props) => { } const svg = d3.select(svgRef.current); + const root = svg.select("g.root"); // de-highlight all nodes - svg - .selectAll("g") + root + .selectAll("g.node") .selectAll("ellipse") .attr("fill", "#fff"); // highlight animated node - svg - .selectAll("g") + root + .selectAll("g.node") .selectAll("ellipse") .filter((d: any) => highlightNode?.id === d.id) .attr("fill", animatedNodeColor); // highlight animated edges - svg + root .selectAll("line") .attr("stroke", "#999") .attr("stroke-dasharray", (d: any) => lineDashArray(d)) @@ -741,7 +771,7 @@ export const Graph = (props: Props) => { .attr("stroke-dasharray", highlightAllNextNodes ? "4" : "") .attr("marker-end", animatedArrowUrl); - svg + root .selectAll("path.loop") .attr("stroke", "#999") .attr("stroke-dasharray", "") @@ -751,7 +781,7 @@ export const Graph = (props: Props) => { .attr("stroke-dasharray", highlightAllNextNodes ? "4" : "") .attr("marker-end", animatedArrowUrl); - highlightSelected(svg); + highlightSelected(root); }, [svgRef, d3Graph.nodes, selectedNodeId, highlightNode, highlightLoopOnNode, highlightEdge, highlightAllNextNodes, highlightSelected]);