diff --git a/mpl_draggable_line/_line.py b/mpl_draggable_line/_line.py index 1431b40..489bc39 100644 --- a/mpl_draggable_line/_line.py +++ b/mpl_draggable_line/_line.py @@ -6,6 +6,7 @@ from matplotlib.backend_bases import MouseEvent from matplotlib.cbook import CallbackRegistry from matplotlib.lines import Line2D +from matplotlib.transforms import IdentityTransform, blended_transform_factory from matplotlib.widgets import AxesWidget __all__ = [ @@ -16,7 +17,16 @@ class DraggableLine(AxesWidget): - def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None: + def __init__( + self, + ax, + x, + y, + grab_range=10, + useblit=False, + grab_range_transform=None, + **kwargs, + ) -> None: """ Parameters ---------- @@ -29,6 +39,9 @@ def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None: Whether to use blitting for faster drawing (if supported by the backend). See the tutorial :doc:`/tutorials/advanced/blitting` for details. + grab_range_transform : matplotlib.transform.Transform, optional + The transform to use for the handle positions when calculating + if a handle has been grabbed. **kwargs : Passed on to Line2D for styling """ @@ -43,6 +56,8 @@ def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None: marker = kwargs.pop("marker", "o") color = kwargs.pop("color", "k") transform = kwargs.pop("transform", self.ax.transData) + self._grab_range_transform = grab_range_transform or self.ax.transLimits + self._handles = Line2D( [x[0], center_x, x[1]], [y[0], center_y, y[1]], @@ -108,12 +123,11 @@ def _on_press(self, event: MouseEvent): if not self.canvas.widgetlock.available(self): return # figure out if any handles are being grabbed - # maybe possible to do this with a pick event? x, y = self._handles.get_data() - # this is taken pretty much directly from the implementation - # in matplotlib.widget.ToolHandles.closest - pts = self.ax.transLimits.transform(np.column_stack([x, y])) + # this is a modified version of + # matplotlib.widget.ToolHandles.closest + pts = self._grab_range_transform.transform(np.column_stack([x, y])) diff = pts - self.ax.transLimits.transform((event.xdata, event.ydata)) dist = np.hypot(*diff.T) idx = np.argmin(dist) @@ -227,6 +241,9 @@ def __init__(self, ax, x, grab_range=0.1, useblit=False, **kwargs) -> None: grab_range=grab_range, useblit=useblit, transform=ax.get_xaxis_transform(), + grab_range_transform=blended_transform_factory( + ax.transLimits, IdentityTransform() + ), **kwargs, ) self._y_lock = True @@ -292,6 +309,9 @@ def __init__(self, ax, y, grab_range=0.1, useblit=False, **kwargs) -> None: grab_range=grab_range, useblit=useblit, transform=ax.get_yaxis_transform(), + grab_range_transform=blended_transform_factory( + IdentityTransform(), ax.transLimits + ), **kwargs, ) self._x_lock = True