Skip to content

Commit

Permalink
fix transforms for VLine and HLine
Browse files Browse the repository at this point in the history
  • Loading branch information
ianhi committed Oct 12, 2022
1 parent d6a9c07 commit d67c2eb
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions mpl_draggable_line/_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from matplotlib.backend_bases import MouseEvent
from matplotlib.cbook import CallbackRegistry
from matplotlib.lines import Line2D
from matplotlib.transforms import IdentityTransform, blended_transform_factory
from matplotlib.widgets import AxesWidget

__all__ = [
Expand All @@ -16,7 +17,16 @@


class DraggableLine(AxesWidget):
def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None:
def __init__(
self,
ax,
x,
y,
grab_range=10,
useblit=False,
grab_range_transform=None,
**kwargs,
) -> None:
"""
Parameters
----------
Expand All @@ -29,6 +39,9 @@ def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None:
Whether to use blitting for faster drawing (if supported by the
backend). See the tutorial :doc:`/tutorials/advanced/blitting`
for details.
grab_range_transform : matplotlib.transform.Transform, optional
The transform to use for the handle positions when calculating
if a handle has been grabbed.
**kwargs :
Passed on to Line2D for styling
"""
Expand All @@ -43,6 +56,8 @@ def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None:
marker = kwargs.pop("marker", "o")
color = kwargs.pop("color", "k")
transform = kwargs.pop("transform", self.ax.transData)
self._grab_range_transform = grab_range_transform or self.ax.transLimits

self._handles = Line2D(
[x[0], center_x, x[1]],
[y[0], center_y, y[1]],
Expand Down Expand Up @@ -108,12 +123,11 @@ def _on_press(self, event: MouseEvent):
if not self.canvas.widgetlock.available(self):
return
# figure out if any handles are being grabbed
# maybe possible to do this with a pick event?

x, y = self._handles.get_data()
# this is taken pretty much directly from the implementation
# in matplotlib.widget.ToolHandles.closest
pts = self.ax.transLimits.transform(np.column_stack([x, y]))
# this is a modified version of
# matplotlib.widget.ToolHandles.closest
pts = self._grab_range_transform.transform(np.column_stack([x, y]))
diff = pts - self.ax.transLimits.transform((event.xdata, event.ydata))
dist = np.hypot(*diff.T)
idx = np.argmin(dist)
Expand Down Expand Up @@ -227,6 +241,9 @@ def __init__(self, ax, x, grab_range=0.1, useblit=False, **kwargs) -> None:
grab_range=grab_range,
useblit=useblit,
transform=ax.get_xaxis_transform(),
grab_range_transform=blended_transform_factory(
ax.transLimits, IdentityTransform()
),
**kwargs,
)
self._y_lock = True
Expand Down Expand Up @@ -292,6 +309,9 @@ def __init__(self, ax, y, grab_range=0.1, useblit=False, **kwargs) -> None:
grab_range=grab_range,
useblit=useblit,
transform=ax.get_yaxis_transform(),
grab_range_transform=blended_transform_factory(
IdentityTransform(), ax.transLimits
),
**kwargs,
)
self._x_lock = True
Expand Down

0 comments on commit d67c2eb

Please sign in to comment.