Skip to content

Commit

Permalink
Swap button added to the track refinement GUI (DeepLabCut#2680)
Browse files Browse the repository at this point in the history
* Swap feature

* Swap button added

* black formatting

---------

Co-authored-by: Julian ALVAREZ DE GIORGI <[email protected]>
  • Loading branch information
JulianAlvarezdeGiorgi and Julian ALVAREZ DE GIORGI authored Aug 13, 2024
1 parent 9e9d677 commit 165462f
Show file tree
Hide file tree
Showing 44 changed files with 336 additions and 155 deletions.
5 changes: 1 addition & 4 deletions deeplabcut/benchmark/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ def _validate_predictions(self, name: str, predictions: dict) -> dict:
"individuals were detected in those images."
)

return {
img: predictions.get(img, tuple())
for img in test_images
}
return {img: predictions.get(img, tuple()) for img in test_images}


@dataclasses.dataclass
Expand Down
4 changes: 1 addition & 3 deletions deeplabcut/benchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ def calc_rmse_from_obj(
kpts.pop(ind)

test_objects = {
k: v
for k, v in eval_results_obj.items()
if k in gt["annotations"].keys()
k: v for k, v in eval_results_obj.items() if k in gt["annotations"].keys()
}
if len(gt["annotations"]) != len(test_objects):
gt_images = list(gt["annotations"].keys())
Expand Down
6 changes: 3 additions & 3 deletions deeplabcut/create_project/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,9 @@ def create_new_project(
cfg_file["x2"] = 640
cfg_file["y1"] = 277
cfg_file["y2"] = 624
cfg_file[
"batch_size"
] = 8 # batch size during inference (video - analysis); see https://www.biorxiv.org/content/early/2018/10/30/457242
cfg_file["batch_size"] = (
8 # batch size during inference (video - analysis); see https://www.biorxiv.org/content/early/2018/10/30/457242
)
cfg_file["corner2move2"] = (50, 50)
cfg_file["move2corner"] = True
cfg_file["skeleton_color"] = "black"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,7 @@ def create_multianimaltraining_dataset(
# Loading the encoder (if necessary downloading from TF)
dlcparent_path = auxiliaryfunctions.get_deeplabcut_path()
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
model_path = auxfun_models.check_for_weights(
net_type, Path(dlcparent_path)
)
model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))

if Shuffles is None:
Shuffles = range(1, num_shuffles + 1, 1)
Expand Down Expand Up @@ -425,9 +423,9 @@ def create_multianimaltraining_dataset(
"multi_step": [[1e-4, 7500], [5 * 1e-5, 12000], [1e-5, 200000]],
"save_iters": 10000,
"display_iters": 500,
"num_idchannel": len(cfg["individuals"])
if cfg.get("identity", False)
else 0,
"num_idchannel": (
len(cfg["individuals"]) if cfg.get("identity", False) else 0
),
"crop_size": list(crop_size),
"crop_sampling": crop_sampling,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -988,9 +988,7 @@ def create_training_dataset(
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
elif posecfg_template:
defaultconfigfile = posecfg_template
model_path = auxfun_models.check_for_weights(
net_type, Path(dlcparent_path)
)
model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))

if Shuffles is None:
Shuffles = range(1, num_shuffles + 1)
Expand Down
7 changes: 5 additions & 2 deletions deeplabcut/gui/tabs/extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def _set_page(self):

self.main_layout.addWidget(
_create_label_widget(
"Frame extraction from a video subset (optional for automatic extraction)", "font:bold"
"Frame extraction from a video subset (optional for automatic extraction)",
"font:bold",
)
)
self.video_selection_widget = VideoSelectionWidget(self.root, self)
Expand Down Expand Up @@ -206,7 +207,9 @@ def extract_frames(self):
return
first_video = videos[0]
if len(videos) > 1:
self.root.writer.write(f"Only the first video ({first_video}) will be opened.")
self.root.writer.write(
f"Only the first video ({first_video}) will be opened."
)
video_path_in_folder = self._check_symlink(first_video)
_ = launch_napari(str(video_path_in_folder))
return
Expand Down
4 changes: 3 additions & 1 deletion deeplabcut/gui/tabs/modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def _set_page(self):

tooltip_label = QtWidgets.QLabel()
tooltip_label.setPixmap(
QPixmap(os.path.join(BASE_DIR, "assets", "icons", "help2.png")).scaledToWidth(30)
QPixmap(
os.path.join(BASE_DIR, "assets", "icons", "help2.png")
).scaledToWidth(30)
)
tooltip_label.setToolTip(
"Approximate animal sizes in pixels, for spatial pyramid search. If left blank, defaults to video height +/- 50 pixels",
Expand Down
87 changes: 81 additions & 6 deletions deeplabcut/gui/tracklet_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from deeplabcut.utils.auxfun_videos import VideoReader
from deeplabcut.utils.auxiliaryfunctions import attempt_to_make_folder
from matplotlib.path import Path
from matplotlib.widgets import Slider, LassoSelector, Button, CheckButtons
from matplotlib.widgets import Slider, LassoSelector, Button, CheckButtons, TextBox
from PySide6.QtWidgets import QMessageBox
from PySide6.QtCore import QMutex

Expand Down Expand Up @@ -327,6 +327,9 @@ def __init__(self, manager, videoname, trail_len=50):

self.dps = []

self.swap_id1 = None
self.swap_id2 = None

def _prepare_canvas(self, manager, fig):
params = {
"keymap.save": "s",
Expand Down Expand Up @@ -358,7 +361,7 @@ def _prepare_canvas(self, manager, fig):

img = self.video.read_frame()
self.im = self.ax1.imshow(img)
self.scat = self.ax1.scatter([], [], s=self.dotsize ** 2, picker=True)
self.scat = self.ax1.scatter([], [], s=self.dotsize**2, picker=True)
self.scat.set_offsets(manager.xy[:, 0])
self.scat.set_color(self.colors)
self.trails = sum(
Expand All @@ -374,6 +377,7 @@ def _prepare_canvas(self, manager, fig):
)
self.vline_x = self.ax2.axvline(0, 0, 1, c="k", ls=":")
self.vline_y = self.ax3.axvline(0, 0, 1, c="k", ls=":")

custom_lines = [
plt.Line2D([0], [0], color=self.cmap(i), lw=4)
for i in range(len(manager.individuals))
Expand Down Expand Up @@ -420,10 +424,15 @@ def _prepare_canvas(self, manager, fig):
self.ax_flag = self.fig.add_axes([0.75, 0.1, 0.05, 0.03])
self.ax_save = self.fig.add_axes([0.80, 0.1, 0.05, 0.03])
self.ax_help = self.fig.add_axes([0.85, 0.1, 0.05, 0.03])
self.ax_swap = self.fig.add_axes([0.90, 0.1, 0.05, 0.03]) # New button

self.save_button = Button(self.ax_save, "Save", color="darkorange")
self.save_button.on_clicked(self.save)
self.help_button = Button(self.ax_help, "Help")
self.help_button.on_clicked(self.display_help)
self.swap_button = Button(self.ax_swap, "Swap") # New button
self.swap_button.on_clicked(self.swap_tracklets) # Placeholder action

self.drag_toggle = CheckButtons(self.ax_drag, ["Drag"])
self.drag_toggle.on_clicked(self.toggle_draggable_points)
self.flag_button = Button(self.ax_flag, "Flag")
Expand All @@ -441,9 +450,75 @@ def _prepare_canvas(self, manager, fig):
self.ax1_background = self.fig.canvas.copy_from_bbox(self.ax1.bbox)
self.fig.show()

# Create dropdowns for selecting tracklets to swap, placing them near the swap button
self.ax_dropdown1 = self.fig.add_axes([0.9, 0.15, 0.05, 0.03])
self.ax_dropdown2 = self.fig.add_axes([0.9, 0.20, 0.05, 0.03])
self.textbox1 = TextBox(self.ax_dropdown1, "ID 1")
self.textbox2 = TextBox(self.ax_dropdown2, "ID 2")
self.textbox1.on_submit(self.set_swap_id1)
self.textbox2.on_submit(self.set_swap_id2)

def show(self, fig=None):
self._prepare_canvas(self.manager, fig)

def swap_tracklets(self, event):
if self.swap_id1 is not None and self.swap_id2 is not None:

# Get tracklet indices for each individual
inds1 = [
k
for k in range(len(self.manager.tracklet2id))
if self.manager.tracklet2id[k] == self.swap_id1
]
inds2 = [
k
for k in range(len(self.manager.tracklet2id))
if self.manager.tracklet2id[k] == self.swap_id2
]

print(f"Swapping tracklets {self.swap_id1} and {self.swap_id2}")

# Frames to swap
frames = []
if len(self.cuts) == 2:
frames = list(range(min(self.cuts), max(self.cuts) + 1))
elif len(self.cuts) == 1:
frames = [self.cuts[0]]
else:
frames = list(range(self.curr_frame, self.manager.nframes))

# Swap the tracklets
for i in range(min(len(inds1), len(inds2))):
self.manager.swap_tracklets(inds1[i], inds2[i], frames)
self.display_traces()
self.slider.set_val(self.curr_frame)

def set_swap_id1(self, val):
# check that the input is a valid from the list of individuals
if int(val) in self.manager.tracklet2id:
self.swap_id1 = int(val)
print("ID 1 set.")
else:
print(
f"Invalid ID. Please select a valid ID from the list of individuals: {set(self.manager.tracklet2id)}"
)
self.swap_id1 = None

def set_swap_id2(self, val):
# check that the input is a valid from the list of individuals
if int(val) in self.manager.tracklet2id:
self.swap_id2 = int(val)
print("ID 2 set.")
else:
print(
f"Invalid ID. Please select a valid ID from the list of individuals: {set(self.manager.tracklet2id)}"
)
self.swap_id2 = None

def terminate(self, event):
plt.close(self.fig)
self.player.terminate()

def fill_shaded_areas(self):
self.clean_collections()
if self.picked_pair:
Expand Down Expand Up @@ -587,9 +662,9 @@ def on_press(self, event):
if len(self.cuts) > 1:
self.cuts.sort()
if self.picked_pair:
self.manager.tracklet_swaps[self.picked_pair][
self.cuts
] = ~self.manager.tracklet_swaps[self.picked_pair][self.cuts]
self.manager.tracklet_swaps[self.picked_pair][self.cuts] = (
~self.manager.tracklet_swaps[self.picked_pair][self.cuts]
)
self.fill_shaded_areas()
self.cuts = []
for line in self.ax_slider.lines:
Expand Down Expand Up @@ -807,7 +882,7 @@ def on_change(self, val):

def update_dotsize(self, val):
self.dotsize = val
self.scat.set_sizes([self.dotsize ** 2])
self.scat.set_sizes([self.dotsize**2])

@staticmethod
def calc_distance(x1, y1, x2, y2):
Expand Down
4 changes: 1 addition & 3 deletions deeplabcut/gui/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def _check_for_updates(silent=True):
_ = msg.addButton("Skip", msg.RejectRole)
msg.exec_()
if msg.clickedButton() is update_btn:
subprocess.check_call(
[sys.executable, "-m", *command]
)
subprocess.check_call([sys.executable, "-m", *command])


class MainWindow(QMainWindow):
Expand Down
4 changes: 3 additions & 1 deletion deeplabcut/modelzoo/api/superanimal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ def video_inference(
customized_test_config="",
):
if superanimal_name not in MODELOPTIONS:
raise ValueError(f"{superanimal_name} not available. Available ones are: {MODELOPTIONS}. If you are confident `superanimal_name` is right, try updating `dlclibrary` with `pip install -U dlclibrary`.")
raise ValueError(
f"{superanimal_name} not available. Available ones are: {MODELOPTIONS}. If you are confident `superanimal_name` is right, try updating `dlclibrary` with `pip install -U dlclibrary`."
)

dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()

Expand Down
4 changes: 3 additions & 1 deletion deeplabcut/pose_estimation_3d/camera_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
matplotlib_axes_logger.setLevel("ERROR")


def calibrate_cameras(config, cbrow=8, cbcol=6, calibrate=False, alpha=0.4, search_window_size=(11, 11)):
def calibrate_cameras(
config, cbrow=8, cbcol=6, calibrate=False, alpha=0.4, search_window_size=(11, 11)
):
"""This function extracts the corners points from the calibration images, calibrates the camera and stores the calibration files in the project folder (defined in the config file).
Make sure you have around 20-60 pairs of calibration images. The function should be used iteratively to select the right set of calibration images.
Expand Down
22 changes: 7 additions & 15 deletions deeplabcut/pose_estimation_tensorflow/core/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,9 +867,7 @@ def evaluate_network(
pose = predict.argmax_pose_predict(
scmap, locref, test_pose_cfg["stride"]
)
PredicteData[
imageindex, :
] = (
PredicteData[imageindex, :] = (
pose.flatten()
) # NOTE: thereby cfg_test['all_joints_names'] should be same order as bodyparts!

Expand Down Expand Up @@ -971,10 +969,7 @@ def evaluate_network(
print("Plotting...")
foldername = os.path.join(
str(evaluationfolder),
"LabeledImages_"
+ DLCscorer
+ "_"
+ snapshot_name,
"LabeledImages_" + DLCscorer + "_" + snapshot_name,
)
auxiliaryfunctions.attempt_to_make_folder(foldername)
Plotting(
Expand All @@ -997,10 +992,7 @@ def evaluate_network(
).T
foldername = os.path.join(
str(evaluationfolder),
"LabeledImages_"
+ DLCscorer
+ "_"
+ snapshot_name,
"LabeledImages_" + DLCscorer + "_" + snapshot_name,
)
if not os.path.exists(foldername):
print(
Expand Down Expand Up @@ -1104,14 +1096,14 @@ def get_available_requested_snapshots(


def get_snapshots_by_index(
idx: Union[int, str], available_snapshots: List[str],
idx: Union[int, str],
available_snapshots: List[str],
) -> List[str]:
"""
Assume available_snapshots is ordered in ascending order. Returns snapshot names.
"""
if (
isinstance(idx, int)
and -len(available_snapshots) <= idx < len(available_snapshots)
if isinstance(idx, int) and -len(available_snapshots) <= idx < len(
available_snapshots
):
return [available_snapshots[idx]]
elif idx == "all":
Expand Down
25 changes: 8 additions & 17 deletions deeplabcut/pose_estimation_tensorflow/core/evaluate_multianimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,7 @@ def evaluate_multianimal_full(
coords_pred = pred["coordinates"][0]
probs_pred = pred["confidence"]
for bpt, xy_gt in df.groupby(level="bodyparts"):
inds_gt = np.flatnonzero(
np.all(~np.isnan(xy_gt), axis=1)
)
inds_gt = np.flatnonzero(np.all(~np.isnan(xy_gt), axis=1))
n_joint = joints.index(bpt)
xy = coords_pred[n_joint]
if inds_gt.size and xy.size:
Expand All @@ -422,9 +420,9 @@ def evaluate_multianimal_full(

if plotting == "bodypart":
temp_xy = GT.unstack("bodyparts")[joints].values
gt = temp_xy.reshape(
(-1, 2, temp_xy.shape[1])
).T.swapaxes(1, 2)
gt = temp_xy.reshape((-1, 2, temp_xy.shape[1])).T.swapaxes(
1, 2
)
h, w, _ = np.shape(frame)
fig.set_size_inches(w / 100, h / 100)
ax.set_xlim(0, w)
Expand Down Expand Up @@ -477,8 +475,7 @@ def evaluate_multianimal_full(
# Calculate overall prediction error
error = df_joint.xs("rmse", level="metrics", axis=1)
mask = (
df_joint.xs("conf", level="metrics", axis=1)
>= cfg["pcutoff"]
df_joint.xs("conf", level="metrics", axis=1) >= cfg["pcutoff"]
)
error_masked = error[mask]
error_train = np.nanmean(error.iloc[trainIndices])
Expand All @@ -505,9 +502,7 @@ def evaluate_multianimal_full(
testIndices,
)
kpt_filename = DLCscorer + "-keypoint-results.csv"
df_keypoint_error.to_csv(
Path(evaluationfolder) / kpt_filename
)
df_keypoint_error.to_csv(Path(evaluationfolder) / kpt_filename)

if show_errors:
string = (
Expand Down Expand Up @@ -666,12 +661,8 @@ def evaluate_multianimal_full(
df.loc(axis=0)[("mAR_train", "mean")] = [
d[0]["mAR"] for d in results[2]
]
df.loc(axis=0)[("mAP_test", "mean")] = [
d[1]["mAP"] for d in results[2]
]
df.loc(axis=0)[("mAR_test", "mean")] = [
d[1]["mAR"] for d in results[2]
]
df.loc(axis=0)[("mAP_test", "mean")] = [d[1]["mAP"] for d in results[2]]
df.loc(axis=0)[("mAR_test", "mean")] = [d[1]["mAR"] for d in results[2]]
with open(data_path.replace("_full.", "_map."), "wb") as file:
pickle.dump((df, paf_scores), file)

Expand Down
Loading

0 comments on commit 165462f

Please sign in to comment.